From d55448398dc91d8770181b154bf34b5bdab7684a Mon Sep 17 00:00:00 2001
From: Stepan Vladovskiy
Date: Wed, 5 Mar 2025 20:08:21 +0000
Subject: [PATCH] feat(search.py): change to txtai server with AI model, and
 fix granian workers

---
 .gitea/workflows/main.yml |  12 +-
 .gitignore                |   3 +-
 main.py                   |   9 +-
 requirements.txt          |   4 +
 server.py                 |  12 +-
 services/search.py        | 353 +++++++++++++++++++-------------------
 6 files changed, 211 insertions(+), 182 deletions(-)

diff --git a/.gitea/workflows/main.yml b/.gitea/workflows/main.yml
index 18730b95..86ae2988 100644
--- a/.gitea/workflows/main.yml
+++ b/.gitea/workflows/main.yml
@@ -29,7 +29,17 @@ jobs:
       if: github.ref == 'refs/heads/dev'
       uses: dokku/github-action@master
       with:
-        branch: 'dev'
+        branch: 'main'
         force: true
         git_remote_url: 'ssh://dokku@v2.discours.io:22/core'
         ssh_private_key: ${{ secrets.SSH_PRIVATE_KEY }}
+
+    - name: Push to dokku for staging branch
+      if: github.ref == 'refs/heads/staging'
+      uses: dokku/github-action@master
+      with:
+        branch: 'main'
+        force: true
+        git_remote_url: 'ssh://dokku@staging.discours.io:22/core'
+        ssh_private_key: ${{ secrets.SSH_PRIVATE_KEY }}
+        git_push_flags: '--force'
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 4db9e7e4..502d180d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -161,4 +161,5 @@ views.json
 *.key
 *.crt
 *cache.json
-.cursor
\ No newline at end of file
+.cursor
+.devcontainer/
diff --git a/main.py b/main.py
index ff64c974..7c4a722f 100644
--- a/main.py
+++ b/main.py
@@ -17,7 +17,8 @@ from cache.revalidator import revalidation_manager
 from services.exception import ExceptionHandlerMiddleware
 from services.redis import redis
 from services.schema import create_all_tables, resolvers
-from services.search import search_service
+#from services.search import search_service
+from services.search import search_service, initialize_search_index
 from services.viewed import ViewedStorage
 from services.webhook import WebhookEndpoint, create_webhook_endpoint
 from settings import DEV_SERVER_PID_FILE_NAME, MODE
@@ -47,6 +48,12 @@ async def lifespan(_app):
             start(),
             revalidation_manager.start(),
         )
+
+        # After basic initialization is complete, fetch shouts and initialize search
+        from services.db import fetch_all_shouts  # Import your database access function
+        all_shouts = await fetch_all_shouts()  # Replace with your actual function
+        await initialize_search_index(all_shouts)
+
         yield
     finally:
         tasks = [redis.disconnect(), ViewedStorage.stop(), revalidation_manager.stop()]
diff --git a/requirements.txt b/requirements.txt
index 56b09175..ccab19f3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -17,6 +17,10 @@ gql
 ariadne
 granian
 
+# NLP and search
+txtai[embeddings]
+sentence-transformers
+
 pydantic
 fakeredis
 pytest
diff --git a/server.py b/server.py
index 30009c89..e34609b1 100644
--- a/server.py
+++ b/server.py
@@ -3,7 +3,8 @@ from pathlib import Path
 
 from granian.constants import Interfaces
 from granian.log import LogLevels
-from granian.server import Granian
+from granian.server import Server
+from sentence_transformers import SentenceTransformer
 
 from settings import PORT
 from utils.logger import root_logger as logger
@@ -11,12 +12,17 @@ from utils.logger import root_logger as logger
 if __name__ == "__main__":
     logger.info("started")
     try:
-        granian_instance = Granian(
+        # Preload the model before starting the server
+        logger.info("Loading sentence transformer model...")
+        model = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2')
+        logger.info("Model loaded successfully!")
+
+        granian_instance = Server(
             "main:app",
             address="0.0.0.0",
             port=PORT,
             interface=Interfaces.ASGI,
-            threads=4,
+            workers=4,
             websockets=False,
             log_level=LogLevels.debug,
             backlog=2048,
diff --git a/services/search.py b/services/search.py
index 9c9b13e9..b8b97b60 100644
--- a/services/search.py
+++ b/services/search.py
@@ -2,8 +2,9 @@ import asyncio
 import json
 import logging
 import os
+import concurrent.futures
 
-from opensearchpy import OpenSearch
+from txtai.embeddings import Embeddings
 
 from services.redis import redis
 from utils.encoders import CustomJSONEncoder
@@ -12,220 +13,220 @@ from utils.encoders import CustomJSONEncoder
 logger = logging.getLogger("search")
 logger.setLevel(logging.WARNING)
 
-ELASTIC_HOST = os.environ.get("ELASTIC_HOST", "").replace("https://", "")
-ELASTIC_USER = os.environ.get("ELASTIC_USER", "")
-ELASTIC_PASSWORD = os.environ.get("ELASTIC_PASSWORD", "")
-ELASTIC_PORT = os.environ.get("ELASTIC_PORT", 9200)
-ELASTIC_URL = os.environ.get(
-    "ELASTIC_URL",
-    f"https://{ELASTIC_USER}:{ELASTIC_PASSWORD}@{ELASTIC_HOST}:{ELASTIC_PORT}",
-)
 REDIS_TTL = 86400  # 1 день в секундах
 
-index_settings = {
-    "settings": {
-        "index": {"number_of_shards": 1, "auto_expand_replicas": "0-all"},
-        "analysis": {
-            "analyzer": {
-                "ru": {
-                    "tokenizer": "standard",
-                    "filter": ["lowercase", "ru_stop", "ru_stemmer"],
-                }
-            },
-            "filter": {
-                "ru_stemmer": {"type": "stemmer", "language": "russian"},
-                "ru_stop": {"type": "stop", "stopwords": "_russian_"},
-            },
-        },
-    },
-    "mappings": {
-        "properties": {
-            "body": {"type": "text", "analyzer": "ru"},
-            "title": {"type": "text", "analyzer": "ru"},
-            "subtitle": {"type": "text", "analyzer": "ru"},
-            "lead": {"type": "text", "analyzer": "ru"},
-            "media": {"type": "text", "analyzer": "ru"},
-        }
-    },
-}
-
-expected_mapping = index_settings["mappings"]
-
-# Создание цикла событий
-search_loop = asyncio.get_event_loop()
-
-# В начале файла добавим флаг
-SEARCH_ENABLED = bool(os.environ.get("ELASTIC_HOST", ""))
-
-
-def get_indices_stats():
-    indices_stats = search_service.client.cat.indices(format="json")
-    for index_info in indices_stats:
-        index_name = index_info["index"]
-        if not index_name.startswith("."):
-            index_health = index_info["health"]
-            index_status = index_info["status"]
-            pri_shards = index_info["pri"]
-            rep_shards = index_info["rep"]
-            docs_count = index_info["docs.count"]
-            docs_deleted = index_info["docs.deleted"]
-            store_size = index_info["store.size"]
-            pri_store_size = index_info["pri.store.size"]
-
-            logger.info(f"Index: {index_name}")
-            logger.info(f"Health: {index_health}")
-            logger.info(f"Status: {index_status}")
-            logger.info(f"Primary Shards: {pri_shards}")
-            logger.info(f"Replica Shards: {rep_shards}")
-            logger.info(f"Documents Count: {docs_count}")
-            logger.info(f"Deleted Documents: {docs_deleted}")
-            logger.info(f"Store Size: {store_size}")
-            logger.info(f"Primary Store Size: {pri_store_size}")
+# Configuration for txtai search
+SEARCH_ENABLED = bool(os.environ.get("SEARCH_ENABLED", "true").lower() in ["true", "1", "yes"])
 
+# Thread executor for non-blocking initialization
+thread_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
 
 class SearchService:
     def __init__(self, index_name="search_index"):
         logger.info("Инициализируем поиск...")
         self.index_name = index_name
-        self.client = None
-        self.lock = asyncio.Lock()
-
-        # Инициализация клиента OpenSearch только если поиск включен
-        if SEARCH_ENABLED:
-            try:
-                self.client = OpenSearch(
-                    hosts=[{"host": ELASTIC_HOST, "port": ELASTIC_PORT}],
-                    http_compress=True,
-                    http_auth=(ELASTIC_USER, ELASTIC_PASSWORD),
-                    use_ssl=True,
-                    verify_certs=False,
-                    ssl_assert_hostname=False,
-                    ssl_show_warn=False,
-                )
-                logger.info("Клиент OpenSearch.org подключен")
-                search_loop.create_task(self.check_index())
-            except Exception as exc:
-                logger.warning(f"Поиск отключен из-за ошибки подключения: {exc}")
-                self.client = None
-        else:
-            logger.info("Поиск отключен (ELASTIC_HOST не установлен)")
+        self.embeddings = None
+        self._initialization_future = None
+        self.available = SEARCH_ENABLED
+
+        if not self.available:
+            logger.info("Поиск отключен (SEARCH_ENABLED = False)")
+            return
+
+        # Initialize embeddings in background thread
+        self._initialization_future = thread_executor.submit(self._init_embeddings)
+
+    def _init_embeddings(self):
+        """Initialize txtai embeddings in a background thread"""
+        try:
+            # Use the same model as in TopicClassifier
+            model_path = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
+
+            # Configure embeddings with content storage and quantization for lower memory usage
+            self.embeddings = Embeddings({
+                "path": model_path,
+                "content": True,
+                "quantize": True
+            })
+            logger.info("txtai embeddings initialized successfully")
+            return True
+        except Exception as e:
+            logger.error(f"Failed to initialize txtai embeddings: {e}")
+            self.available = False
+            return False
 
     async def info(self):
-        if not SEARCH_ENABLED:
+        """Return information about search service"""
+        if not self.available:
             return {"status": "disabled"}
         try:
-            return get_indices_stats()
+            if not self.is_ready():
+                return {"status": "initializing", "model": "paraphrase-multilingual-mpnet-base-v2"}
+
+            return {
+                "status": "active",
+                "count": self.embeddings.count() if self.embeddings else 0,  # count() reports indexed documents
+                "model": "paraphrase-multilingual-mpnet-base-v2"
+            }
         except Exception as e:
             logger.error(f"Failed to get search info: {e}")
             return {"status": "error", "message": str(e)}
 
-    def delete_index(self):
-        if self.client:
-            logger.warning(f"[!!!] Удаляем индекс {self.index_name}")
-            self.client.indices.delete(index=self.index_name, ignore_unavailable=True)
-
-    def create_index(self):
-        if self.client:
-            logger.info(f"Создается индекс: {self.index_name}")
-            self.client.indices.create(index=self.index_name, body=index_settings)
-            logger.info(f"Индекс {self.index_name} создан")
-
-    async def check_index(self):
-        if self.client:
-            logger.info(f"Проверяем индекс {self.index_name}...")
-            if not self.client.indices.exists(index=self.index_name):
-                self.create_index()
-                self.client.indices.put_mapping(index=self.index_name, body=expected_mapping)
-            else:
-                logger.info(f"Найден существующий индекс {self.index_name}")
-                # Проверка и обновление структуры индекса, если необходимо
-                result = self.client.indices.get_mapping(index=self.index_name)
-                if isinstance(result, str):
-                    result = json.loads(result)
-                if isinstance(result, dict):
-                    mapping = result.get(self.index_name, {}).get("mappings")
-                    logger.info(f"Найдена структура индексации: {mapping['properties'].keys()}")
-                    expected_keys = expected_mapping["properties"].keys()
-                    if mapping and mapping["properties"].keys() != expected_keys:
-                        logger.info(f"Ожидаемая структура индексации: {expected_mapping}")
-                        logger.warning("[!!!] Требуется переиндексация всех данных")
-                        self.delete_index()
-                        self.client = None
-        else:
-            logger.error("клиент не инициализован, невозможно проверить индекс")
+    def is_ready(self):
+        """Check if embeddings are fully initialized and ready"""
+        return self.embeddings is not None and self.available
 
     def index(self, shout):
-        if not SEARCH_ENABLED:
+        """Index a single document"""
+        if not self.available:
             return
-        if self.client:
-            logger.info(f"Индексируем пост {shout.id}")
-            index_body = {
-                "body": shout.body,
-                "title": shout.title,
-                "subtitle": shout.subtitle,
-                "lead": shout.lead,
-                "media": shout.media,
-            }
-            asyncio.create_task(self.perform_index(shout, index_body))
+        logger.info(f"Индексируем пост {shout.id}")
+
+        # Start in background to not block
+        asyncio.create_task(self.perform_index(shout))
 
-    async def perform_index(self, shout, index_body):
-        if self.client:
-            try:
-                await asyncio.wait_for(
-                    self.client.index(index=self.index_name, id=str(shout.id), body=index_body), timeout=40.0
-                )
-            except asyncio.TimeoutError:
-                logger.error(f"Indexing timeout for shout {shout.id}")
-            except Exception as e:
-                logger.error(f"Indexing error for shout {shout.id}: {e}")
+    async def perform_index(self, shout):
+        """Actually perform the indexing operation"""
+        if not self.is_ready():
+            # If embeddings not ready, wait for initialization
+            if self._initialization_future and not self._initialization_future.done():
+                try:
+                    # Wait for initialization to complete with timeout
+                    await asyncio.get_event_loop().run_in_executor(
+                        None, lambda: self._initialization_future.result(timeout=30))
+                except Exception as e:
+                    logger.error(f"Embeddings initialization failed: {e}")
+                    return
+
+            if not self.is_ready():
+                logger.error(f"Cannot index shout {shout.id}: embeddings not ready")
+                return
+
+        try:
+            # Combine all text fields
+            text = " ".join(filter(None, [
+                shout.title or "",
+                shout.subtitle or "",
+                shout.lead or "",
+                shout.body or "",
+                shout.media or ""
+            ]))
+
+            # Use upsert for individual documents
+            await asyncio.get_event_loop().run_in_executor(
+                None,
+                lambda: self.embeddings.upsert([(str(shout.id), text, None)])
+            )
+            logger.info(f"Пост {shout.id} успешно индексирован")
+        except Exception as e:
+            logger.error(f"Indexing error for shout {shout.id}: {e}")
+
+    async def bulk_index(self, shouts):
+        """Index multiple documents at once"""
+        if not self.available or not shouts:
+            return
+
+        if not self.is_ready():
+            # Wait for initialization if needed
+            if self._initialization_future and not self._initialization_future.done():
+                try:
+                    await asyncio.get_event_loop().run_in_executor(
+                        None, lambda: self._initialization_future.result(timeout=30))
+                except Exception as e:
+                    logger.error(f"Embeddings initialization failed: {e}")
+                    return
+
+            if not self.is_ready():
+                logger.error("Cannot perform bulk indexing: embeddings not ready")
+                return
+
+        documents = []
+        for shout in shouts:
+            text = " ".join(filter(None, [
+                shout.title or "",
+                shout.subtitle or "",
+                shout.lead or "",
+                shout.body or "",
+                shout.media or ""
+            ]))
+            documents.append((str(shout.id), text, None))
+
+        try:
+            await asyncio.get_event_loop().run_in_executor(
+                None, lambda: self.embeddings.upsert(documents))
+            logger.info(f"Bulk indexed {len(documents)} documents")
+        except Exception as e:
+            logger.error(f"Bulk indexing error: {e}")
 
     async def search(self, text, limit, offset):
-        if not SEARCH_ENABLED:
+        """Search documents"""
+        if not self.available:
             return []
-
+
+        # Check Redis cache first
f"search:{text}:{offset}+{limit}" + cached = await redis.get(redis_key) + if cached: + return json.loads(cached) + logger.info(f"Ищем: {text} {offset}+{limit}") - search_body = { - "query": {"multi_match": {"query": text, "fields": ["title", "lead", "subtitle", "body", "media"]}} - } - - if self.client: - search_response = self.client.search( - index=self.index_name, - body=search_body, - size=limit, - from_=offset, - _source=False, - _source_excludes=["title", "body", "subtitle", "media", "lead", "_index"], - ) - hits = search_response["hits"]["hits"] - results = [{"id": hit["_id"], "score": hit["_score"]} for hit in hits] - - # если результаты не пустые - if results: - # Кэширование в Redis с TTL - redis_key = f"search:{text}:{offset}+{limit}" + + if not self.is_ready(): + # Wait for initialization if needed + if self._initialization_future and not self._initialization_future.done(): + try: + await asyncio.get_event_loop().run_in_executor( + None, lambda: self._initialization_future.result(timeout=30)) + except Exception as e: + logger.error(f"Embeddings initialization failed: {e}") + return [] + + if not self.is_ready(): + logger.error("Cannot search: embeddings not ready") + return [] + + try: + # Search with txtai (need to request more to handle offset) + total = offset + limit + results = await asyncio.get_event_loop().run_in_executor( + None, lambda: self.embeddings.search(text, total)) + + # Apply offset and convert to the expected format + results = results[offset:offset+limit] + formatted_results = [{"id": doc_id, "score": float(score)} for score, doc_id in results] + + # Cache results + if formatted_results: await redis.execute( "SETEX", redis_key, REDIS_TTL, - json.dumps(results, cls=CustomJSONEncoder), + json.dumps(formatted_results, cls=CustomJSONEncoder), ) - return results - return [] + return formatted_results + except Exception as e: + logger.error(f"Search error: {e}") + return [] +# Create the search service singleton search_service = SearchService() +# Keep the API exactly the same to maintain compatibility async def search_text(text: str, limit: int = 50, offset: int = 0): payload = [] - if search_service.client: - # Использование метода search_post из OpenSearchService + if search_service.available: payload = await search_service.search(text, limit, offset) return payload -# Проверить что URL корректный -OPENSEARCH_URL = os.getenv("OPENSEARCH_URL", "rc1a-3n5pi3bhuj9gieel.mdb.yandexcloud.net") +# Function to initialize search with existing data +async def initialize_search_index(shouts_data): + """Initialize search index with existing data during application startup""" + if SEARCH_ENABLED: + logger.info("Initializing search index with existing data...") + await search_service.bulk_index(shouts_data) + logger.info(f"Search index initialized with {len(shouts_data)} documents")