0.4.9-b
All checks were successful
Deploy on push / deploy (push) Successful in 2m38s

This commit is contained in:
Untone 2025-02-09 22:26:50 +03:00
parent 37a9a284ef
commit 4a835bbfba
17 changed files with 1082 additions and 36 deletions

View File

@ -2,7 +2,11 @@
- `Shout.draft` field added - `Shout.draft` field added
- `Draft` entity added - `Draft` entity added
- `create_draft`, `update_draft`, `delete_draft` mutations and resolvers added - `create_draft`, `update_draft`, `delete_draft` mutations and resolvers added
- `get_shout_drafts` resolver updated - `create_shout`, `update_shout`, `delete_shout` mutations removed from GraphQL API
- `load_drafts` resolver implemented
- `publish_` and `unpublish_` mutations and resolvers added
- `create_`, `update_`, `delete_` mutations and resolvers added for `Draft` entity
- tests with pytest for auth, shouts, drafts
#### [0.4.8] - 2025-02-03 #### [0.4.8] - 2025-02-03
- `Reaction.deleted_at` filter on `update_reaction` resolver added - `Reaction.deleted_at` filter on `update_reaction` resolver added

View File

@ -2,13 +2,13 @@ from binascii import hexlify
from hashlib import sha256 from hashlib import sha256
# from base.exceptions import InvalidPassword, InvalidToken # from base.exceptions import InvalidPassword, InvalidToken
from base.orm import local_session from services.db import local_session
from jwt import DecodeError, ExpiredSignatureError from auth.exceptions import ExpiredToken, InvalidToken
from passlib.hash import bcrypt from passlib.hash import bcrypt
from auth.jwtcodec import JWTCodec from auth.jwtcodec import JWTCodec
from auth.tokenstorage import TokenStorage from auth.tokenstorage import TokenStorage
from orm import User from orm.user import User
class Password: class Password:
@ -79,10 +79,10 @@ class Identity:
if not await TokenStorage.exist(f"{payload.user_id}-{payload.username}-{token}"): if not await TokenStorage.exist(f"{payload.user_id}-{payload.username}-{token}"):
# raise InvalidToken("Login token has expired, please login again") # raise InvalidToken("Login token has expired, please login again")
return {"error": "Token has expired"} return {"error": "Token has expired"}
except ExpiredSignatureError: except ExpiredToken:
# raise InvalidToken("Login token has expired, please try again") # raise InvalidToken("Login token has expired, please try again")
return {"error": "Token has expired"} return {"error": "Token has expired"}
except DecodeError: except InvalidToken:
# raise InvalidToken("token format error") from e # raise InvalidToken("token format error") from e
return {"error": "Token format error"} return {"error": "Token format error"}
with local_session() as session: with local_session() as session:

View File

@ -1,7 +1,7 @@
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from base.redis import redis from services.redis import redis
from validations.auth import AuthInput from auth.validations import AuthInput
from auth.jwtcodec import JWTCodec from auth.jwtcodec import JWTCodec
from settings import ONETIME_TOKEN_LIFE_SPAN, SESSION_TOKEN_LIFE_SPAN from settings import ONETIME_TOKEN_LIFE_SPAN, SESSION_TOKEN_LIFE_SPAN

103
auth/validations.py Normal file
View File

@ -0,0 +1,103 @@
import re
from datetime import datetime
from typing import Dict, List, Optional, Union
from pydantic import BaseModel, Field, field_validator
# RFC 5322 compliant email regex pattern
EMAIL_PATTERN = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
class AuthInput(BaseModel):
"""Base model for authentication input validation"""
user_id: str = Field(description="Unique user identifier")
username: str = Field(min_length=2, max_length=50)
token: str = Field(min_length=32)
@field_validator('user_id')
@classmethod
def validate_user_id(cls, v: str) -> str:
if not v.strip():
raise ValueError("user_id cannot be empty")
return v
class UserRegistrationInput(BaseModel):
"""Validation model for user registration"""
email: str = Field(max_length=254) # Max email length per RFC 5321
password: str = Field(min_length=8, max_length=100)
name: str = Field(min_length=2, max_length=50)
@field_validator('email')
@classmethod
def validate_email(cls, v: str) -> str:
"""Validate email format"""
if not re.match(EMAIL_PATTERN, v):
raise ValueError("Invalid email format")
return v.lower()
@field_validator('password')
@classmethod
def validate_password_strength(cls, v: str) -> str:
"""Validate password meets security requirements"""
if not any(c.isupper() for c in v):
raise ValueError("Password must contain at least one uppercase letter")
if not any(c.islower() for c in v):
raise ValueError("Password must contain at least one lowercase letter")
if not any(c.isdigit() for c in v):
raise ValueError("Password must contain at least one number")
if not any(c in "!@#$%^&*()_+-=[]{}|;:,.<>?" for c in v):
raise ValueError("Password must contain at least one special character")
return v
class UserLoginInput(BaseModel):
"""Validation model for user login"""
email: str = Field(max_length=254)
password: str = Field(min_length=8, max_length=100)
@field_validator('email')
@classmethod
def validate_email(cls, v: str) -> str:
if not re.match(EMAIL_PATTERN, v):
raise ValueError("Invalid email format")
return v.lower()
class TokenPayload(BaseModel):
"""Validation model for JWT token payload"""
user_id: str
username: str
exp: datetime
iat: datetime
scopes: Optional[List[str]] = []
class OAuthInput(BaseModel):
"""Validation model for OAuth input"""
provider: str = Field(pattern='^(google|github|facebook)$')
code: str
redirect_uri: Optional[str] = None
@field_validator('provider')
@classmethod
def validate_provider(cls, v: str) -> str:
valid_providers = ['google', 'github', 'facebook']
if v.lower() not in valid_providers:
raise ValueError(f"Provider must be one of: {', '.join(valid_providers)}")
return v.lower()
class AuthResponse(BaseModel):
"""Validation model for authentication responses"""
success: bool
token: Optional[str] = None
error: Optional[str] = None
user: Optional[Dict[str, Union[str, int, bool]]] = None
@field_validator('error')
@classmethod
def validate_error_if_not_success(cls, v: Optional[str], info) -> Optional[str]:
if not info.data.get('success') and not v:
raise ValueError("Error message required when success is False")
return v
@field_validator('token')
@classmethod
def validate_token_if_success(cls, v: Optional[str], info) -> Optional[str]:
if info.data.get('success') and not v:
raise ValueError("Token required when success is True")
return v

176
orm/rbac.py Normal file
View File

@ -0,0 +1,176 @@
from services.db import REGISTRY, Base, local_session
from utils.logger import root_logger as logger
from sqlalchemy.types import TypeDecorator
from sqlalchemy.types import String
from sqlalchemy import Column, ForeignKey, String, UniqueConstraint
from sqlalchemy.orm import relationship
class ClassType(TypeDecorator):
impl = String
@property
def python_type(self):
return NotImplemented
def process_literal_param(self, value, dialect):
return NotImplemented
def process_bind_param(self, value, dialect):
return value.__name__ if isinstance(value, type) else str(value)
def process_result_value(self, value, dialect):
class_ = REGISTRY.get(value)
if class_ is None:
logger.warn(f"Can't find class <{value}>,find it yourself!", stacklevel=2)
return class_
class Role(Base):
__tablename__ = "role"
name = Column(String, nullable=False, comment="Role Name")
desc = Column(String, nullable=True, comment="Role Description")
community = Column(
ForeignKey("community.id", ondelete="CASCADE"),
nullable=False,
comment="Community",
)
permissions = relationship(lambda: Permission)
@staticmethod
def init_table():
with local_session() as session:
r = session.query(Role).filter(Role.name == "author").first()
if r:
Role.default_role = r
return
r1 = Role.create(
name="author",
desc="Role for an author",
community=1,
)
session.add(r1)
Role.default_role = r1
r2 = Role.create(
name="reader",
desc="Role for a reader",
community=1,
)
session.add(r2)
r3 = Role.create(
name="expert",
desc="Role for an expert",
community=1,
)
session.add(r3)
r4 = Role.create(
name="editor",
desc="Role for an editor",
community=1,
)
session.add(r4)
class Operation(Base):
__tablename__ = "operation"
name = Column(String, nullable=False, unique=True, comment="Operation Name")
@staticmethod
def init_table():
with local_session() as session:
for name in ["create", "update", "delete", "load"]:
"""
* everyone can:
- load shouts
- load topics
- load reactions
- create an account to become a READER
* readers can:
- update and delete their account
- load chats
- load messages
- create reaction of some shout's author allowed kinds
- create shout to become an AUTHOR
* authors can:
- update and delete their shout
- invite other authors to edit shout and chat
- manage allowed reactions for their shout
* pros can:
- create/update/delete their community
- create/update/delete topics for their community
"""
op = session.query(Operation).filter(Operation.name == name).first()
if not op:
op = Operation.create(name=name)
session.add(op)
session.commit()
class Resource(Base):
__tablename__ = "resource"
resourceClass = Column(String, nullable=False, unique=True, comment="Resource class")
name = Column(String, nullable=False, unique=True, comment="Resource name")
# TODO: community = Column(ForeignKey())
@staticmethod
def init_table():
with local_session() as session:
for res in [
"shout",
"topic",
"reaction",
"chat",
"message",
"invite",
"community",
"user",
]:
r = session.query(Resource).filter(Resource.name == res).first()
if not r:
r = Resource.create(name=res, resourceClass=res)
session.add(r)
session.commit()
class Permission(Base):
__tablename__ = "permission"
__table_args__ = (
UniqueConstraint("role", "operation", "resource"),
{"extend_existing": True},
)
role: Column = Column(ForeignKey("role.id", ondelete="CASCADE"), nullable=False, comment="Role")
operation: Column = Column(
ForeignKey("operation.id", ondelete="CASCADE"),
nullable=False,
comment="Operation",
)
resource: Column = Column(
ForeignKey("resource.id", ondelete="CASCADE"),
nullable=False,
comment="Resource",
)
# if __name__ == "__main__":
# Base.metadata.create_all(engine)
# ops = [
# Permission(role=1, operation=1, resource=1),
# Permission(role=1, operation=2, resource=1),
# Permission(role=1, operation=3, resource=1),
# Permission(role=1, operation=4, resource=1),
# Permission(role=2, operation=4, resource=1),
# ]
# global_session.add_all(ops)
# global_session.commit()

105
orm/user.py Normal file
View File

@ -0,0 +1,105 @@
from sqlalchemy import JSON as JSONType
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, func
from sqlalchemy.orm import relationship
from services.db import Base, local_session
from orm.rbac import Role
class UserRating(Base):
__tablename__ = "user_rating"
id = None
rater: Column = Column(ForeignKey("user.id"), primary_key=True, index=True)
user: Column = Column(ForeignKey("user.id"), primary_key=True, index=True)
value: Column = Column(Integer)
@staticmethod
def init_table():
pass
class UserRole(Base):
__tablename__ = "user_role"
id = None
user = Column(ForeignKey("user.id"), primary_key=True, index=True)
role = Column(ForeignKey("role.id"), primary_key=True, index=True)
class AuthorFollower(Base):
__tablename__ = "author_follower"
id = None
follower: Column = Column(ForeignKey("user.id"), primary_key=True, index=True)
author: Column = Column(ForeignKey("user.id"), primary_key=True, index=True)
createdAt = Column(
DateTime(timezone=True), nullable=False, server_default=func.now(), comment="Created at"
)
auto = Column(Boolean, nullable=False, default=False)
class User(Base):
__tablename__ = "user"
default_user = None
email = Column(String, unique=True, nullable=False, comment="Email")
username = Column(String, nullable=False, comment="Login")
password = Column(String, nullable=True, comment="Password")
bio = Column(String, nullable=True, comment="Bio") # status description
about = Column(String, nullable=True, comment="About") # long and formatted
userpic = Column(String, nullable=True, comment="Userpic")
name = Column(String, nullable=True, comment="Display name")
slug = Column(String, unique=True, comment="User's slug")
muted = Column(Boolean, default=False)
emailConfirmed = Column(Boolean, default=False)
createdAt = Column(
DateTime(timezone=True), nullable=False, server_default=func.now(), comment="Created at"
)
lastSeen = Column(
DateTime(timezone=True), nullable=False, server_default=func.now(), comment="Was online at"
)
deletedAt = Column(DateTime(timezone=True), nullable=True, comment="Deleted at")
links = Column(JSONType, nullable=True, comment="Links")
oauth = Column(String, nullable=True)
ratings = relationship(UserRating, foreign_keys=UserRating.user)
roles = relationship(lambda: Role, secondary=UserRole.__tablename__)
oid = Column(String, nullable=True)
@staticmethod
def init_table():
with local_session() as session:
default = session.query(User).filter(User.slug == "anonymous").first()
if not default:
default_dict = {
"email": "noreply@discours.io",
"username": "noreply@discours.io",
"name": "Аноним",
"slug": "anonymous",
}
default = User.create(**default_dict)
session.add(default)
discours_dict = {
"email": "welcome@discours.io",
"username": "welcome@discours.io",
"name": "Дискурс",
"slug": "discours",
}
discours = User.create(**discours_dict)
session.add(discours)
session.commit()
User.default_user = default
def get_permission(self):
scope = {}
for role in self.roles:
for p in role.permissions:
if p.resource not in scope:
scope[p.resource] = set()
scope[p.resource].add(p.operation)
print(scope)
return scope
# if __name__ == "__main__":
# print(User.get_permission(user_id=1))

View File

@ -26,6 +26,8 @@ fakeredis = "^2.25.1"
pydantic = "^2.9.2" pydantic = "^2.9.2"
jwt = "^1.3.1" jwt = "^1.3.1"
authlib = "^1.3.2" authlib = "^1.3.2"
passlib = "^1.7.4"
bcrypt = "^4.2.1"
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]
@ -34,13 +36,15 @@ isort = "^5.13.2"
pydantic = "^2.9.2" pydantic = "^2.9.2"
pytest = "^8.3.4" pytest = "^8.3.4"
mypy = "^1.15.0" mypy = "^1.15.0"
pytest-asyncio = "^0.23.5"
pytest-cov = "^4.1.0"
[build-system] [build-system]
requires = ["poetry-core>=1.0.0"] requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api" build-backend = "poetry.core.masonry.api"
[tool.pyright] [tool.pyright]
venvPath = "." venvPath = "venv"
venv = "venv" venv = "venv"
[tool.isort] [tool.isort]
@ -49,5 +53,10 @@ include_trailing_comma = true
force_grid_wrap = 0 force_grid_wrap = 0
line_length = 120 line_length = 120
[tool.pytest.ini_options]
testpaths = ["tests"]
pythonpath = ["."]
venv = "venv"
[tool.ruff] [tool.ruff]
line-length = 120 line-length = 120

View File

@ -1,16 +1,21 @@
import time import time
from importlib import invalidate_caches from orm.topic import Topic
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.sql import and_
from cache.cache import invalidate_shout_related_cache, invalidate_shouts_cache from cache.cache import (
cache_author, cache_by_id, cache_topic,
invalidate_shout_related_cache, invalidate_shouts_cache
)
from orm.author import Author from orm.author import Author
from orm.draft import Draft from orm.draft import Draft
from orm.shout import Shout from orm.shout import Shout, ShoutAuthor, ShoutTopic
from services.auth import login_required from services.auth import login_required
from services.db import local_session from services.db import local_session
from services.schema import mutation, query from services.schema import mutation, query
from utils.logger import root_logger as logger from utils.logger import root_logger as logger
from services.notify import notify_shout
from services.search import search_service
@query.field("load_drafts") @query.field("load_drafts")
@ -119,9 +124,10 @@ async def unpublish_draft(_, info, draft_id: int):
@login_required @login_required
async def publish_shout(_, info, shout_id: int, draft=None): async def publish_shout(_, info, shout_id: int, draft=None):
"""Publish draft as a shout or update existing shout. """Publish draft as a shout or update existing shout.
Args: Args:
session: SQLAlchemy session to use for database operations shout_id: ID существующей публикации или 0 для новой
draft: Объект черновика (опционально)
""" """
user_id = info.context.get("user_id") user_id = info.context.get("user_id")
author_dict = info.context.get("author", {}) author_dict = info.context.get("author", {})
@ -130,16 +136,25 @@ async def publish_shout(_, info, shout_id: int, draft=None):
return {"error": "User ID and author ID are required"} return {"error": "User ID and author ID are required"}
try: try:
# Use proper SQLAlchemy query
with local_session() as session: with local_session() as session:
# Находим черновик если не передан
if not draft: if not draft:
find_draft_stmt = select(Draft).where(Draft.shout == shout_id) find_draft_stmt = select(Draft).where(Draft.shout == shout_id)
draft = session.execute(find_draft_stmt).scalar_one_or_none() draft = session.execute(find_draft_stmt).scalar_one_or_none()
if not draft:
return {"error": "Draft not found"}
now = int(time.time()) now = int(time.time())
# Находим существующую публикацию или создаем новую
shout = None
was_published = False
if shout_id:
shout = session.query(Shout).filter(Shout.id == shout_id).first()
was_published = shout and shout.published_at is not None
if not shout: if not shout:
# Create new shout from draft # Создаем новую публикацию
shout = Shout( shout = Shout(
body=draft.body, body=draft.body,
slug=draft.slug, slug=draft.slug,
@ -155,15 +170,11 @@ async def publish_shout(_, info, shout_id: int, draft=None):
seo=draft.seo, seo=draft.seo,
created_by=author_id, created_by=author_id,
community=draft.community, community=draft.community,
authors=draft.authors.copy(), # Create copies of relationships
topics=draft.topics.copy(),
draft=draft.id, draft=draft.id,
deleted_at=None, deleted_at=None,
) )
else: else:
# Update existing shout # Обновляем существующую публикацию
shout.authors = draft.authors.copy()
shout.topics = draft.topics.copy()
shout.draft = draft.id shout.draft = draft.id
shout.created_by = author_id shout.created_by = author_id
shout.title = draft.title shout.title = draft.title
@ -178,24 +189,78 @@ async def publish_shout(_, info, shout_id: int, draft=None):
shout.lang = draft.lang shout.lang = draft.lang
shout.seo = draft.seo shout.seo = draft.seo
# Обновляем временные метки
shout.updated_at = now shout.updated_at = now
shout.published_at = now
# Устанавливаем published_at только если это новая публикация
# или публикация была ранее снята с публикации
if not was_published:
shout.published_at = now
draft.updated_at = now draft.updated_at = now
draft.published_at = now draft.published_at = now
# Обрабатываем связи с авторами
if not session.query(ShoutAuthor).filter(
and_(ShoutAuthor.shout == shout.id, ShoutAuthor.author == author_id)
).first():
sa = ShoutAuthor(shout=shout.id, author=author_id)
session.add(sa)
# Обрабатываем темы
if draft.topics:
for topic in draft.topics:
st = ShoutTopic(
topic=topic.id,
shout=shout.id,
main=topic.main if hasattr(topic, 'main') else False
)
session.add(st)
session.add(shout) session.add(shout)
session.add(draft) session.add(draft)
session.flush()
# Инвалидируем кэш только если это новая публикация или была снята с публикации
if not was_published:
cache_keys = [
"feed",
f"author_{author_id}",
"random_top",
"unrated"
]
# Добавляем ключи для тем
for topic in shout.topics:
cache_keys.append(f"topic_{topic.id}")
cache_keys.append(f"topic_shouts_{topic.id}")
await cache_by_id(Topic, topic.id, cache_topic)
# Инвалидируем кэш
await invalidate_shouts_cache(cache_keys)
await invalidate_shout_related_cache(shout, author_id)
# Обновляем кэш авторов
for author in shout.authors:
await cache_by_id(Author, author.id, cache_author)
# Отправляем уведомление о публикации
await notify_shout(shout.dict(), "published")
# Обновляем поисковый индекс
search_service.index(shout)
else:
# Для уже опубликованных материалов просто отправляем уведомление об обновлении
await notify_shout(shout.dict(), "update")
session.commit() session.commit()
return {"shout": shout}
invalidate_shout_related_cache(shout)
invalidate_shouts_cache()
return {"shout": shout}
except Exception as e: except Exception as e:
import traceback logger.error(f"Failed to publish shout: {e}", exc_info=True)
if 'session' in locals():
logger.error(f"Failed to publish shout: {e}") session.rollback()
logger.error(traceback.format_exc()) return {"error": f"Failed to publish shout: {str(e)}"}
session.rollback()
return {"error": "Failed to publish shout"}
@mutation.field("unpublish_shout") @mutation.field("unpublish_shout")

View File

@ -188,6 +188,8 @@ type Topic {
type CommonResult { type CommonResult {
error: String error: String
drafts: [Draft]
draft: Draft
slugs: [String] slugs: [String]
shout: Shout shout: Shout
shouts: [Shout] shouts: [Shout]

View File

@ -7,8 +7,7 @@ from typing import Any, Callable, Dict, TypeVar
import sqlalchemy import sqlalchemy
from sqlalchemy import JSON, Column, Engine, Integer, create_engine, event, exc, func, inspect from sqlalchemy import JSON, Column, Engine, Integer, create_engine, event, exc, func, inspect
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import Session, configure_mappers, declarative_base
from sqlalchemy.orm import Session, configure_mappers
from sqlalchemy.sql.schema import Table from sqlalchemy.sql.schema import Table
from settings import DB_URL from settings import DB_URL

179
services/pretopic.py Normal file
View File

@ -0,0 +1,179 @@
import concurrent.futures
from typing import Dict, Tuple, List
from txtai.embeddings import Embeddings
from services.logger import root_logger as logger
class TopicClassifier:
def __init__(self, shouts_by_topic: Dict[str, str], publications: List[Dict[str, str]]):
"""
Инициализация классификатора тем и поиска публикаций.
Args:
shouts_by_topic: Словарь {тема: текст_всех_публикаций}
publications: Список публикаций с полями 'id', 'title', 'text'
"""
self.shouts_by_topic = shouts_by_topic
self.topics = list(shouts_by_topic.keys())
self.publications = publications
self.topic_embeddings = None # Для классификации тем
self.search_embeddings = None # Для поиска публикаций
self._initialization_future = None
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
def initialize(self) -> None:
"""
Асинхронная инициализация векторных представлений.
"""
if self._initialization_future is None:
self._initialization_future = self._executor.submit(self._prepare_embeddings)
logger.info("Векторизация текстов начата в фоновом режиме...")
def _prepare_embeddings(self) -> None:
"""
Подготавливает векторные представления для тем и поиска.
"""
logger.info("Начинается подготовка векторных представлений...")
# Модель для русского языка
# TODO: model local caching
model_path = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
# Инициализируем embeddings для классификации тем
self.topic_embeddings = Embeddings(path=model_path)
topic_documents = [
(topic, text)
for topic, text in self.shouts_by_topic.items()
]
self.topic_embeddings.index(topic_documents)
# Инициализируем embeddings для поиска публикаций
self.search_embeddings = Embeddings(path=model_path)
search_documents = [
(str(pub['id']), f"{pub['title']} {pub['text']}")
for pub in self.publications
]
self.search_embeddings.index(search_documents)
logger.info("Подготовка векторных представлений завершена.")
def predict_topic(self, text: str) -> Tuple[float, str]:
"""
Предсказывает тему для заданного текста из известного набора тем.
Args:
text: Текст для классификации
Returns:
Tuple[float, str]: (уверенность, тема)
"""
if not self.is_ready():
logger.error("Векторные представления не готовы. Вызовите initialize() и дождитесь завершения.")
return 0.0, "unknown"
try:
# Ищем наиболее похожую тему
results = self.topic_embeddings.search(text, 1)
if not results:
return 0.0, "unknown"
score, topic = results[0]
return float(score), topic
except Exception as e:
logger.error(f"Ошибка при определении темы: {str(e)}")
return 0.0, "unknown"
def search_similar(self, query: str, limit: int = 5) -> List[Dict[str, any]]:
"""
Ищет публикации похожие на поисковый запрос.
Args:
query: Поисковый запрос
limit: Максимальное количество результатов
Returns:
List[Dict]: Список найденных публикаций с оценкой релевантности
"""
if not self.is_ready():
logger.error("Векторные представления не готовы. Вызовите initialize() и дождитесь завершения.")
return []
try:
# Ищем похожие публикации
results = self.search_embeddings.search(query, limit)
# Формируем результаты
found_publications = []
for score, pub_id in results:
# Находим публикацию по id
publication = next(
(pub for pub in self.publications if str(pub['id']) == pub_id),
None
)
if publication:
found_publications.append({
**publication,
'relevance': float(score)
})
return found_publications
except Exception as e:
logger.error(f"Ошибка при поиске публикаций: {str(e)}")
return []
def is_ready(self) -> bool:
"""
Проверяет, готовы ли векторные представления.
"""
return self.topic_embeddings is not None and self.search_embeddings is not None
def wait_until_ready(self) -> None:
"""
Ожидает завершения подготовки векторных представлений.
"""
if self._initialization_future:
self._initialization_future.result()
def __del__(self):
"""
Очистка ресурсов при удалении объекта.
"""
if self._executor:
self._executor.shutdown(wait=False)
# Пример использования:
"""
shouts_by_topic = {
"Спорт": "... большой текст со всеми спортивными публикациями ...",
"Технологии": "... большой текст со всеми технологическими публикациями ...",
"Политика": "... большой текст со всеми политическими публикациями ..."
}
publications = [
{
'id': 1,
'title': 'Новый процессор AMD',
'text': 'Компания AMD представила новый процессор...'
},
{
'id': 2,
'title': 'Футбольный матч',
'text': 'Вчера состоялся решающий матч...'
}
]
# Создание классификатора
classifier = TopicClassifier(shouts_by_topic, publications)
classifier.initialize()
classifier.wait_until_ready()
# Определение темы текста
text = "Новый процессор показал высокую производительность"
score, topic = classifier.predict_topic(text)
print(f"Тема: {topic} (уверенность: {score:.4f})")
# Поиск похожих публикаций
query = "процессор AMD производительность"
similar_publications = classifier.search_similar(query, limit=3)
for pub in similar_publications:
print(f"\nНайдена публикация (релевантность: {pub['relevance']:.4f}):")
print(f"Заголовок: {pub['title']}")
print(f"Текст: {pub['text'][:100]}...")
"""

View File

@ -5,7 +5,7 @@ PORT = 8000
DB_URL = ( DB_URL = (
environ.get("DATABASE_URL", "").replace("postgres://", "postgresql://") environ.get("DATABASE_URL", "").replace("postgres://", "postgresql://")
or environ.get("DB_URL", "").replace("postgres://", "postgresql://") or environ.get("DB_URL", "").replace("postgres://", "postgresql://")
or "sqlite:///discoursio-db.sqlite3" or "sqlite:///discoursio.db"
) )
REDIS_URL = environ.get("REDIS_URL") or "redis://127.0.0.1" REDIS_URL = environ.get("REDIS_URL") or "redis://127.0.0.1"
AUTH_URL = environ.get("AUTH_URL") or "" AUTH_URL = environ.get("AUTH_URL") or ""
@ -15,3 +15,9 @@ MODE = "development" if "dev" in sys.argv else "production"
ADMIN_SECRET = environ.get("AUTH_SECRET") or "nothing" ADMIN_SECRET = environ.get("AUTH_SECRET") or "nothing"
WEBHOOK_SECRET = environ.get("WEBHOOK_SECRET") or "nothing-else" WEBHOOK_SECRET = environ.get("WEBHOOK_SECRET") or "nothing-else"
# own auth
ONETIME_TOKEN_LIFE_SPAN = 60 * 60 * 24 * 3 # 3 days
SESSION_TOKEN_LIFE_SPAN = 60 * 60 * 24 * 30 # 30 days
JWT_ALGORITHM = "HS256"
JWT_SECRET_KEY = environ.get("JWT_SECRET") or "nothing-else-jwt-secret-matters"

55
tests/conftest.py Normal file
View File

@ -0,0 +1,55 @@
import asyncio
import os
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from starlette.testclient import TestClient
from main import app
from services.db import Base
from services.redis import redis
from settings import DB_URL
# Use SQLite for testing
TEST_DB_URL = "sqlite:///test.db"
@pytest.fixture(scope="session")
def event_loop():
"""Create an instance of the default event loop for the test session."""
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
@pytest.fixture(scope="session")
def test_engine():
"""Create a test database engine."""
engine = create_engine(TEST_DB_URL)
Base.metadata.create_all(engine)
yield engine
Base.metadata.drop_all(engine)
os.remove("test.db")
@pytest.fixture
def db_session(test_engine):
"""Create a new database session for a test."""
connection = test_engine.connect()
transaction = connection.begin()
session = Session(bind=connection)
yield session
session.close()
transaction.rollback()
connection.close()
@pytest.fixture
async def redis_client():
"""Create a test Redis client."""
await redis.connect()
yield redis
await redis.disconnect()
@pytest.fixture
def test_client():
"""Create a TestClient instance."""
return TestClient(app)

95
tests/test_drafts.py Normal file
View File

@ -0,0 +1,95 @@
import pytest
from orm.shout import Shout
from orm.author import Author
@pytest.fixture
def test_author(db_session):
"""Create a test author."""
author = Author(
name="Test Author",
slug="test-author",
user="test-user-id"
)
db_session.add(author)
db_session.commit()
return author
@pytest.fixture
def test_shout(db_session):
"""Create test shout with required fields."""
author = Author(name="Test Author", slug="test-author", user="test-user-id")
db_session.add(author)
db_session.flush()
shout = Shout(
title="Test Shout",
slug="test-shout",
created_by=author.id, # Обязательное поле
body="Test body",
layout="article",
lang="ru"
)
db_session.add(shout)
db_session.commit()
return shout
@pytest.mark.asyncio
async def test_create_shout(test_client, db_session, test_author):
"""Test creating a new shout."""
response = test_client.post(
"/",
json={
"query": """
mutation CreateDraft($input: DraftInput!) {
create_draft(input: $input) {
error
draft {
id
title
body
}
}
}
""",
"variables": {
"input": {
"title": "Test Shout",
"body": "This is a test shout",
}
}
}
)
assert response.status_code == 200
data = response.json()
assert "errors" not in data
assert data["data"]["create_draft"]["draft"]["title"] == "Test Shout"
@pytest.mark.asyncio
async def test_load_drafts(test_client, db_session):
"""Test retrieving a shout."""
response = test_client.post(
"/",
json={
"query": """
query {
load_drafts {
error
drafts {
id
title
body
}
}
}
""",
"variables": {
"slug": "test-shout"
}
}
)
assert response.status_code == 200
data = response.json()
assert "errors" not in data
assert data["data"]["load_drafts"]["drafts"] == []

64
tests/test_reactions.py Normal file
View File

@ -0,0 +1,64 @@
import pytest
from orm.reaction import Reaction, ReactionKind
from orm.shout import Shout
from orm.author import Author
from datetime import datetime
@pytest.fixture
def test_setup(db_session):
"""Set up test data."""
now = int(datetime.now().timestamp())
author = Author(name="Test Author", slug="test-author", user="test-user-id")
db_session.add(author)
db_session.flush()
shout = Shout(
title="Test Shout",
slug="test-shout",
created_by=author.id,
body="This is a test shout",
layout="article",
lang="ru",
community=1,
created_at=now,
updated_at=now
)
db_session.add_all([author, shout])
db_session.commit()
return {"author": author, "shout": shout}
@pytest.mark.asyncio
async def test_create_reaction(test_client, db_session, test_setup):
"""Test creating a reaction on a shout."""
response = test_client.post(
"/",
json={
"query": """
mutation CreateReaction($reaction: ReactionInput!) {
create_reaction(reaction: $reaction) {
error
reaction {
id
kind
body
created_by {
name
}
}
}
}
""",
"variables": {
"reaction": {
"shout": test_setup["shout"].id,
"kind": ReactionKind.LIKE.value,
"body": "Great post!"
}
}
}
)
assert response.status_code == 200
data = response.json()
assert "error" not in data
assert data["data"]["create_reaction"]["reaction"]["kind"] == ReactionKind.LIKE.value

83
tests/test_shouts.py Normal file
View File

@ -0,0 +1,83 @@
import pytest
from orm.author import Author
from orm.shout import Shout
from datetime import datetime
@pytest.fixture
def test_shout(db_session):
"""Create test shout with required fields."""
now = int(datetime.now().timestamp())
author = Author(name="Test Author", slug="test-author", user="test-user-id")
db_session.add(author)
db_session.flush()
now = int(datetime.now().timestamp())
shout = Shout(
title="Test Shout",
slug="test-shout",
created_by=author.id,
body="Test body",
layout="article",
lang="ru",
community=1,
created_at=now,
updated_at=now
)
db_session.add(shout)
db_session.commit()
return shout
@pytest.mark.asyncio
async def test_get_shout(test_client, db_session):
"""Test retrieving a shout."""
# Создаем автора
author = Author(name="Test Author", slug="test-author", user="test-user-id")
db_session.add(author)
db_session.flush()
now = int(datetime.now().timestamp())
# Создаем публикацию со всеми обязательными полями
shout = Shout(
title="Test Shout",
body="This is a test shout",
slug="test-shout",
created_by=author.id,
layout="article",
lang="ru",
community=1,
created_at=now,
updated_at=now
)
db_session.add(shout)
db_session.commit()
response = test_client.post(
"/",
json={
"query": """
query GetShout($slug: String!) {
get_shout(slug: $slug) {
id
title
body
created_at
updated_at
created_by {
id
name
slug
}
}
}
""",
"variables": {
"slug": "test-shout"
}
}
)
data = response.json()
assert response.status_code == 200
assert "errors" not in data
assert data["data"]["get_shout"]["title"] == "Test Shout"

101
tests/test_validations.py Normal file
View File

@ -0,0 +1,101 @@
import pytest
from datetime import datetime, timedelta
from pydantic import ValidationError
from auth.validations import (
AuthInput,
UserRegistrationInput,
UserLoginInput,
TokenPayload,
OAuthInput,
AuthResponse
)
class TestAuthValidations:
def test_auth_input(self):
"""Test basic auth input validation"""
# Valid case
auth = AuthInput(
user_id="123",
username="testuser",
token="1234567890abcdef1234567890abcdef"
)
assert auth.user_id == "123"
assert auth.username == "testuser"
# Invalid cases
with pytest.raises(ValidationError):
AuthInput(user_id="", username="test", token="x" * 32)
with pytest.raises(ValidationError):
AuthInput(user_id="123", username="t", token="x" * 32)
def test_user_registration(self):
"""Test user registration validation"""
# Valid case
user = UserRegistrationInput(
email="test@example.com",
password="SecurePass123!",
name="Test User"
)
assert user.email == "test@example.com"
assert user.name == "Test User"
# Test email validation
with pytest.raises(ValidationError) as exc:
UserRegistrationInput(
email="invalid-email",
password="SecurePass123!",
name="Test"
)
assert "Invalid email format" in str(exc.value)
# Test password validation
with pytest.raises(ValidationError) as exc:
UserRegistrationInput(
email="test@example.com",
password="weak",
name="Test"
)
assert "String should have at least 8 characters" in str(exc.value)
def test_token_payload(self):
"""Test token payload validation"""
now = datetime.utcnow()
exp = now + timedelta(hours=1)
payload = TokenPayload(
user_id="123",
username="testuser",
exp=exp,
iat=now
)
assert payload.user_id == "123"
assert payload.username == "testuser"
assert payload.scopes == [] # Default empty list
def test_auth_response(self):
"""Test auth response validation"""
# Success case
success_resp = AuthResponse(
success=True,
token="valid_token",
user={"id": "123", "name": "Test"}
)
assert success_resp.success is True
assert success_resp.token == "valid_token"
# Error case
error_resp = AuthResponse(
success=False,
error="Invalid credentials"
)
assert error_resp.success is False
assert error_resp.error == "Invalid credentials"
# Invalid case - отсутствует обязательное поле token при success=True
with pytest.raises(ValidationError):
AuthResponse(
success=True,
user={"id": "123", "name": "Test"}
)