This commit is contained in:
@@ -6,7 +6,17 @@ import warnings
|
||||
from typing import Any, Callable, Dict, TypeVar
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy import JSON, Column, Engine, Integer, create_engine, event, exc, func, inspect
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
Column,
|
||||
Engine,
|
||||
Integer,
|
||||
create_engine,
|
||||
event,
|
||||
exc,
|
||||
func,
|
||||
inspect,
|
||||
)
|
||||
from sqlalchemy.orm import Session, configure_mappers, declarative_base
|
||||
from sqlalchemy.sql.schema import Table
|
||||
|
||||
|
@@ -1,8 +1,11 @@
|
||||
import concurrent.futures
|
||||
from typing import Dict, Tuple, List
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from txtai.embeddings import Embeddings
|
||||
|
||||
from services.logger import root_logger as logger
|
||||
|
||||
|
||||
class TopicClassifier:
|
||||
def __init__(self, shouts_by_topic: Dict[str, str], publications: List[Dict[str, str]]):
|
||||
"""
|
||||
@@ -32,27 +35,21 @@ class TopicClassifier:
|
||||
Подготавливает векторные представления для тем и поиска.
|
||||
"""
|
||||
logger.info("Начинается подготовка векторных представлений...")
|
||||
|
||||
|
||||
# Модель для русского языка
|
||||
# TODO: model local caching
|
||||
model_path = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
|
||||
|
||||
|
||||
# Инициализируем embeddings для классификации тем
|
||||
self.topic_embeddings = Embeddings(path=model_path)
|
||||
topic_documents = [
|
||||
(topic, text)
|
||||
for topic, text in self.shouts_by_topic.items()
|
||||
]
|
||||
topic_documents = [(topic, text) for topic, text in self.shouts_by_topic.items()]
|
||||
self.topic_embeddings.index(topic_documents)
|
||||
|
||||
|
||||
# Инициализируем embeddings для поиска публикаций
|
||||
self.search_embeddings = Embeddings(path=model_path)
|
||||
search_documents = [
|
||||
(str(pub['id']), f"{pub['title']} {pub['text']}")
|
||||
for pub in self.publications
|
||||
]
|
||||
search_documents = [(str(pub["id"]), f"{pub['title']} {pub['text']}") for pub in self.publications]
|
||||
self.search_embeddings.index(search_documents)
|
||||
|
||||
|
||||
logger.info("Подготовка векторных представлений завершена.")
|
||||
|
||||
def predict_topic(self, text: str) -> Tuple[float, str]:
|
||||
@@ -66,13 +63,13 @@ class TopicClassifier:
|
||||
if not self.is_ready():
|
||||
logger.error("Векторные представления не готовы. Вызовите initialize() и дождитесь завершения.")
|
||||
return 0.0, "unknown"
|
||||
|
||||
|
||||
try:
|
||||
# Ищем наиболее похожую тему
|
||||
results = self.topic_embeddings.search(text, 1)
|
||||
if not results:
|
||||
return 0.0, "unknown"
|
||||
|
||||
|
||||
score, topic = results[0]
|
||||
return float(score), topic
|
||||
|
||||
@@ -92,25 +89,19 @@ class TopicClassifier:
|
||||
if not self.is_ready():
|
||||
logger.error("Векторные представления не готовы. Вызовите initialize() и дождитесь завершения.")
|
||||
return []
|
||||
|
||||
|
||||
try:
|
||||
# Ищем похожие публикации
|
||||
results = self.search_embeddings.search(query, limit)
|
||||
|
||||
|
||||
# Формируем результаты
|
||||
found_publications = []
|
||||
for score, pub_id in results:
|
||||
# Находим публикацию по id
|
||||
publication = next(
|
||||
(pub for pub in self.publications if str(pub['id']) == pub_id),
|
||||
None
|
||||
)
|
||||
publication = next((pub for pub in self.publications if str(pub["id"]) == pub_id), None)
|
||||
if publication:
|
||||
found_publications.append({
|
||||
**publication,
|
||||
'relevance': float(score)
|
||||
})
|
||||
|
||||
found_publications.append({**publication, "relevance": float(score)})
|
||||
|
||||
return found_publications
|
||||
|
||||
except Exception as e:
|
||||
@@ -137,6 +128,7 @@ class TopicClassifier:
|
||||
if self._executor:
|
||||
self._executor.shutdown(wait=False)
|
||||
|
||||
|
||||
# Пример использования:
|
||||
"""
|
||||
shouts_by_topic = {
|
||||
@@ -176,4 +168,3 @@ for pub in similar_publications:
|
||||
print(f"Заголовок: {pub['title']}")
|
||||
print(f"Текст: {pub['text'][:100]}...")
|
||||
"""
|
||||
|
||||
|
@@ -43,7 +43,6 @@ async def request_graphql_data(gql, url=AUTH_URL, headers=None):
|
||||
return None
|
||||
|
||||
|
||||
|
||||
def create_all_tables():
|
||||
"""Create all database tables in the correct order."""
|
||||
from orm import author, community, draft, notification, reaction, shout, topic, user
|
||||
@@ -54,26 +53,21 @@ def create_all_tables():
|
||||
author.Author, # Базовая таблица
|
||||
community.Community, # Базовая таблица
|
||||
topic.Topic, # Базовая таблица
|
||||
|
||||
# Связи для базовых таблиц
|
||||
author.AuthorFollower, # Зависит от Author
|
||||
community.CommunityFollower, # Зависит от Community
|
||||
topic.TopicFollower, # Зависит от Topic
|
||||
|
||||
# Черновики (теперь без зависимости от Shout)
|
||||
draft.Draft, # Зависит только от Author
|
||||
draft.DraftAuthor, # Зависит от Draft и Author
|
||||
draft.DraftTopic, # Зависит от Draft и Topic
|
||||
|
||||
# Основные таблицы контента
|
||||
shout.Shout, # Зависит от Author и Draft
|
||||
shout.ShoutAuthor, # Зависит от Shout и Author
|
||||
shout.ShoutTopic, # Зависит от Shout и Topic
|
||||
|
||||
# Реакции
|
||||
reaction.Reaction, # Зависит от Author и Shout
|
||||
shout.ShoutReactionsFollower, # Зависит от Shout и Reaction
|
||||
|
||||
# Дополнительные таблицы
|
||||
author.AuthorRating, # Зависит от Author
|
||||
notification.Notification, # Зависит от Author
|
||||
@@ -87,7 +81,7 @@ def create_all_tables():
|
||||
for model in models_in_order:
|
||||
try:
|
||||
create_table_if_not_exists(session.get_bind(), model)
|
||||
logger.info(f"Created or verified table: {model.__tablename__}")
|
||||
# logger.info(f"Created or verified table: {model.__tablename__}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating table {model.__tablename__}: {e}")
|
||||
raise
|
||||
raise
|
||||
|
@@ -7,7 +7,12 @@ from typing import Dict
|
||||
|
||||
# ga
|
||||
from google.analytics.data_v1beta import BetaAnalyticsDataClient
|
||||
from google.analytics.data_v1beta.types import DateRange, Dimension, Metric, RunReportRequest
|
||||
from google.analytics.data_v1beta.types import (
|
||||
DateRange,
|
||||
Dimension,
|
||||
Metric,
|
||||
RunReportRequest,
|
||||
)
|
||||
from google.analytics.data_v1beta.types import Filter as GAFilter
|
||||
|
||||
from orm.author import Author
|
||||
|
Reference in New Issue
Block a user