diff --git a/auth/__init__.py b/auth/__init__.py index ab33703d..71f3ac08 100644 --- a/auth/__init__.py +++ b/auth/__init__.py @@ -2,25 +2,25 @@ from starlette.requests import Request from starlette.responses import JSONResponse, RedirectResponse from starlette.routing import Route -from auth.sessions import SessionManager from auth.internal import verify_internal_auth from auth.orm import Author +from auth.sessions import SessionManager from services.db import local_session -from utils.logger import root_logger as logger from settings import ( - SESSION_COOKIE_NAME, SESSION_COOKIE_HTTPONLY, - SESSION_COOKIE_SECURE, - SESSION_COOKIE_SAMESITE, SESSION_COOKIE_MAX_AGE, + SESSION_COOKIE_NAME, + SESSION_COOKIE_SAMESITE, + SESSION_COOKIE_SECURE, SESSION_TOKEN_HEADER, ) +from utils.logger import root_logger as logger async def logout(request: Request): """ Выход из системы с удалением сессии и cookie. - + Поддерживает получение токена из: 1. HTTP-only cookie 2. Заголовка Authorization @@ -30,7 +30,7 @@ async def logout(request: Request): if SESSION_COOKIE_NAME in request.cookies: token = request.cookies.get(SESSION_COOKIE_NAME) logger.debug(f"[auth] logout: Получен токен из cookie {SESSION_COOKIE_NAME}") - + # Если токен не найден в cookie, проверяем заголовок if not token: # Сначала проверяем основной заголовок авторизации @@ -42,7 +42,7 @@ async def logout(request: Request): else: token = auth_header.strip() logger.debug(f"[auth] logout: Получен прямой токен из заголовка {SESSION_TOKEN_HEADER}") - + # Если токен не найден в основном заголовке, проверяем стандартный Authorization if not token and "Authorization" in request.headers: auth_header = request.headers.get("Authorization") @@ -74,7 +74,7 @@ async def logout(request: Request): key=SESSION_COOKIE_NAME, secure=SESSION_COOKIE_SECURE, httponly=SESSION_COOKIE_HTTPONLY, - samesite=SESSION_COOKIE_SAMESITE + samesite=SESSION_COOKIE_SAMESITE, ) logger.info("[auth] logout: Cookie успешно удалена") @@ -84,22 +84,22 @@ async def logout(request: Request): async def refresh_token(request: Request): """ Обновление токена аутентификации. - + Поддерживает получение токена из: 1. HTTP-only cookie 2. Заголовка Authorization - + Возвращает новый токен как в HTTP-only cookie, так и в теле ответа. """ token = None source = None - + # Получаем текущий токен из cookie if SESSION_COOKIE_NAME in request.cookies: token = request.cookies.get(SESSION_COOKIE_NAME) source = "cookie" logger.debug(f"[auth] refresh_token: Токен получен из cookie {SESSION_COOKIE_NAME}") - + # Если токен не найден в cookie, проверяем заголовок авторизации if not token: # Проверяем основной заголовок авторизации @@ -113,7 +113,7 @@ async def refresh_token(request: Request): token = auth_header.strip() source = "header" logger.debug(f"[auth] refresh_token: Токен получен из заголовка {SESSION_TOKEN_HEADER} (прямой)") - + # Если токен не найден в основном заголовке, проверяем стандартный Authorization if not token and "Authorization" in request.headers: auth_header = request.headers.get("Authorization") @@ -147,9 +147,7 @@ async def refresh_token(request: Request): if not new_token: logger.error(f"[auth] refresh_token: Не удалось обновить токен для пользователя {user_id}") - return JSONResponse( - {"success": False, "error": "Не удалось обновить токен"}, status_code=500 - ) + return JSONResponse({"success": False, "error": "Не удалось обновить токен"}, status_code=500) # Создаем ответ response = JSONResponse( diff --git a/auth/credentials.py b/auth/credentials.py index 2e8e5a7b..0391ac5c 100644 --- a/auth/credentials.py +++ b/auth/credentials.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Set, Any +from typing import Any, Dict, List, Optional, Set from pydantic import BaseModel, Field diff --git a/auth/decorators.py b/auth/decorators.py index 768a314c..47476148 100644 --- a/auth/decorators.py +++ b/auth/decorators.py @@ -1,19 +1,21 @@ from functools import wraps -from typing import Callable, Any, Dict, Optional +from typing import Any, Callable, Dict, Optional + from graphql import GraphQLError, GraphQLResolveInfo from sqlalchemy import exc from auth.credentials import AuthCredentials -from services.db import local_session -from auth.orm import Author from auth.exceptions import OperationNotAllowed -from utils.logger import root_logger as logger -from settings import ADMIN_EMAILS as ADMIN_EMAILS_LIST, SESSION_TOKEN_HEADER, SESSION_COOKIE_NAME -from auth.sessions import SessionManager -from auth.jwtcodec import JWTCodec, InvalidToken, ExpiredToken -from auth.tokenstorage import TokenStorage -from services.redis import redis from auth.internal import authenticate +from auth.jwtcodec import ExpiredToken, InvalidToken, JWTCodec +from auth.orm import Author +from auth.sessions import SessionManager +from auth.tokenstorage import TokenStorage +from services.db import local_session +from services.redis import redis +from settings import ADMIN_EMAILS as ADMIN_EMAILS_LIST +from settings import SESSION_COOKIE_NAME, SESSION_TOKEN_HEADER +from utils.logger import root_logger as logger ADMIN_EMAILS = ADMIN_EMAILS_LIST.split(",") @@ -21,10 +23,10 @@ ADMIN_EMAILS = ADMIN_EMAILS_LIST.split(",") def get_safe_headers(request: Any) -> Dict[str, str]: """ Безопасно получает заголовки запроса. - + Args: request: Объект запроса - + Returns: Dict[str, str]: Словарь заголовков """ @@ -34,12 +36,9 @@ def get_safe_headers(request: Any) -> Dict[str, str]: if hasattr(request, "scope") and isinstance(request.scope, dict): scope_headers = request.scope.get("headers", []) if scope_headers: - headers.update({ - k.decode("utf-8").lower(): v.decode("utf-8") - for k, v in scope_headers - }) + headers.update({k.decode("utf-8").lower(): v.decode("utf-8") for k, v in scope_headers}) logger.debug(f"[decorators] Получены заголовки из request.scope: {len(headers)}") - + # Второй приоритет: метод headers() или атрибут headers if hasattr(request, "headers"): if callable(request.headers): @@ -55,15 +54,15 @@ def get_safe_headers(request: Any) -> Dict[str, str]: elif isinstance(h, dict): headers.update({k.lower(): v for k, v in h.items()}) logger.debug(f"[decorators] Получены заголовки из request.headers словаря: {len(headers)}") - + # Третий приоритет: атрибут _headers if hasattr(request, "_headers") and request._headers: headers.update({k.lower(): v for k, v in request._headers.items()}) logger.debug(f"[decorators] Получены заголовки из request._headers: {len(headers)}") - + except Exception as e: logger.warning(f"[decorators] Ошибка при доступе к заголовкам: {e}") - + return headers @@ -72,13 +71,13 @@ def get_auth_token(request: Any) -> Optional[str]: Извлекает токен авторизации из запроса. Порядок проверки: 1. Проверяет auth из middleware - 2. Проверяет auth из scope + 2. Проверяет auth из scope 3. Проверяет заголовок Authorization 4. Проверяет cookie с именем auth_token - + Args: request: Объект запроса - + Returns: Optional[str]: Токен авторизации или None """ @@ -100,7 +99,7 @@ def get_auth_token(request: Any) -> Optional[str]: # 3. Проверяем заголовок Authorization headers = get_safe_headers(request) - + # Сначала проверяем основной заголовок авторизации auth_header = headers.get(SESSION_TOKEN_HEADER.lower(), "") if auth_header: @@ -112,7 +111,7 @@ def get_auth_token(request: Any) -> Optional[str]: token = auth_header.strip() logger.debug(f"[decorators] Прямой токен получен из заголовка {SESSION_TOKEN_HEADER}: {len(token)}") return token - + # Затем проверяем стандартный заголовок Authorization, если основной не определен if SESSION_TOKEN_HEADER.lower() != "authorization": auth_header = headers.get("authorization", "") @@ -139,10 +138,10 @@ def get_auth_token(request: Any) -> Optional[str]: async def validate_graphql_context(info: Any) -> None: """ Проверяет валидность GraphQL контекста и проверяет авторизацию. - + Args: info: GraphQL информация о контексте - + Raises: GraphQLError: если контекст невалиден или пользователь не авторизован """ @@ -161,7 +160,7 @@ async def validate_graphql_context(info: Any) -> None: if auth and auth.logged_in: logger.debug(f"[decorators] Пользователь уже авторизован: {auth.author_id}") return - + # Если аутентификации нет в request.auth, пробуем получить ее из scope if hasattr(request, "scope") and "auth" in request.scope: auth_cred = request.scope.get("auth") @@ -170,49 +169,45 @@ async def validate_graphql_context(info: Any) -> None: # Устанавливаем auth в request для дальнейшего использования request.auth = auth_cred return - + # Если авторизации нет ни в auth, ни в scope, пробуем получить и проверить токен token = get_auth_token(request) if not token: # Если токен не найден, возвращаем ошибку авторизации client_info = { "ip": getattr(request.client, "host", "unknown") if hasattr(request, "client") else "unknown", - "headers": get_safe_headers(request) + "headers": get_safe_headers(request), } logger.warning(f"[decorators] Токен авторизации не найден: {client_info}") raise GraphQLError("Unauthorized - please login") - + # Используем единый механизм проверки токена из auth.internal auth_state = await authenticate(request) - + if not auth_state.logged_in: error_msg = auth_state.error or "Invalid or expired token" logger.warning(f"[decorators] Недействительный токен: {error_msg}") raise GraphQLError(f"Unauthorized - {error_msg}") - + # Если все проверки пройдены, создаем AuthCredentials и устанавливаем в request.auth with local_session() as session: try: author = session.query(Author).filter(Author.id == auth_state.author_id).one() # Получаем разрешения из ролей scopes = author.get_permissions() - + # Создаем объект авторизации auth_cred = AuthCredentials( - author_id=author.id, - scopes=scopes, - logged_in=True, - email=author.email, - token=auth_state.token + author_id=author.id, scopes=scopes, logged_in=True, email=author.email, token=auth_state.token ) - + # Устанавливаем auth в request request.auth = auth_cred logger.debug(f"[decorators] Токен успешно проверен и установлен для пользователя {auth_state.author_id}") except exc.NoResultFound: logger.error(f"[decorators] Пользователь с ID {auth_state.author_id} не найден в базе данных") raise GraphQLError("Unauthorized - user not found") - + return @@ -229,18 +224,19 @@ def admin_auth_required(resolver: Callable) -> Callable: Raises: GraphQLError: если пользователь не авторизован или не имеет доступа администратора - + Example: >>> @admin_auth_required ... async def admin_resolver(root, info, **kwargs): ... return "Admin data" """ + @wraps(resolver) async def wrapper(root: Any = None, info: Any = None, **kwargs): try: # Проверяем авторизацию пользователя await validate_graphql_context(info) - + # Получаем объект авторизации auth = info.context["request"].auth if not auth or not auth.logged_in: @@ -255,22 +251,24 @@ def admin_auth_required(resolver: Callable) -> Callable: if not author_id: logger.error(f"[admin_auth_required] ID автора не определен: {auth}") raise GraphQLError("Unauthorized - invalid user ID") - + author = session.query(Author).filter(Author.id == author_id).one() - + # Проверяем, является ли пользователь администратором if author.email in ADMIN_EMAILS: logger.info(f"Admin access granted for {author.email} (ID: {author.id})") return await resolver(root, info, **kwargs) - + # Проверяем роли пользователя - admin_roles = ['admin', 'super'] + admin_roles = ["admin", "super"] user_roles = [role.id for role in author.roles] if author.roles else [] - + if any(role in admin_roles for role in user_roles): - logger.info(f"Admin access granted for {author.email} (ID: {author.id}) with role: {user_roles}") + logger.info( + f"Admin access granted for {author.email} (ID: {author.id}) with role: {user_roles}" + ) return await resolver(root, info, **kwargs) - + logger.warning(f"Admin access denied for {author.email} (ID: {author.id}). Roles: {user_roles}") raise GraphQLError("Unauthorized - not an admin") except exc.NoResultFound: @@ -301,7 +299,7 @@ def permission_required(resource: str, operation: str, func): async def wrap(parent, info: GraphQLResolveInfo, *args, **kwargs): # Сначала проверяем авторизацию await validate_graphql_context(info) - + # Получаем объект авторизации logger.debug(f"[permission_required] Контекст: {info.context}") auth = info.context["request"].auth @@ -324,21 +322,27 @@ def permission_required(resource: str, operation: str, func): if author.email in ADMIN_EMAILS: logger.debug(f"[permission_required] Администратор {author.email} имеет все разрешения") return await func(parent, info, *args, **kwargs) - + # Проверяем роли пользователя - admin_roles = ['admin', 'super'] + admin_roles = ["admin", "super"] user_roles = [role.id for role in author.roles] if author.roles else [] - + if any(role in admin_roles for role in user_roles): - logger.debug(f"[permission_required] Пользователь с ролью администратора {author.email} имеет все разрешения") + logger.debug( + f"[permission_required] Пользователь с ролью администратора {author.email} имеет все разрешения" + ) return await func(parent, info, *args, **kwargs) # Проверяем разрешение if not author.has_permission(resource, operation): - logger.warning(f"[permission_required] У пользователя {author.email} нет разрешения {operation} на {resource}") + logger.warning( + f"[permission_required] У пользователя {author.email} нет разрешения {operation} на {resource}" + ) raise OperationNotAllowed(f"No permission for {operation} on {resource}") - - logger.debug(f"[permission_required] Пользователь {author.email} имеет разрешение {operation} на {resource}") + + logger.debug( + f"[permission_required] Пользователь {author.email} имеет разрешение {operation} на {resource}" + ) return await func(parent, info, *args, **kwargs) except exc.NoResultFound: logger.error(f"[permission_required] Пользователь с ID {auth.author_id} не найден в базе данных") @@ -349,14 +353,15 @@ def permission_required(resource: str, operation: str, func): def login_accepted(func): """ - Декоратор для резолверов, которые могут работать как с авторизованными, + Декоратор для резолверов, которые могут работать как с авторизованными, так и с неавторизованными пользователями. - + Добавляет информацию о пользователе в контекст, если пользователь авторизован. - + Args: func: Декорируемая функция """ + @wraps(func) async def wrap(parent, info: GraphQLResolveInfo, *args, **kwargs): try: @@ -366,10 +371,10 @@ def login_accepted(func): except GraphQLError: # Игнорируем ошибку авторизации pass - + # Получаем объект авторизации auth = getattr(info.context["request"], "auth", None) - + if auth and auth.logged_in: # Если пользователь авторизован, добавляем информацию о нем в контекст with local_session() as session: diff --git a/auth/handler.py b/auth/handler.py index 813655d1..f2677ab3 100644 --- a/auth/handler.py +++ b/auth/handler.py @@ -1,46 +1,48 @@ from ariadne.asgi.handlers import GraphQLHTTPHandler from starlette.requests import Request -from starlette.responses import Response, JSONResponse +from starlette.responses import JSONResponse, Response + from auth.middleware import auth_middleware from utils.logger import root_logger as logger + class EnhancedGraphQLHTTPHandler(GraphQLHTTPHandler): """ Улучшенный GraphQL HTTP обработчик с поддержкой cookie и авторизации. - + Расширяет стандартный GraphQLHTTPHandler для: 1. Создания расширенного контекста запроса с авторизационными данными 2. Корректной обработки ответов с cookie и headers 3. Интеграции с AuthMiddleware """ - + async def get_context_for_request(self, request: Request, data: dict) -> dict: """ Расширяем контекст для GraphQL запросов. - + Добавляет к стандартному контексту: - Объект response для установки cookie - Интеграцию с AuthMiddleware - Расширения для управления авторизацией - + Args: request: Starlette Request объект data: данные запроса - + Returns: dict: контекст с дополнительными данными для авторизации и cookie """ # Получаем стандартный контекст от базового класса context = await super().get_context_for_request(request, data) - + # Создаем объект ответа для установки cookie response = JSONResponse({}) context["response"] = response - + # Интегрируем с AuthMiddleware auth_middleware.set_context(context) context["extensions"] = auth_middleware - + # Добавляем данные авторизации только если они доступны # Без проверки hasattr, так как это вызывает ошибку до обработки AuthenticationMiddleware if hasattr(request, "auth") and request.auth: @@ -48,7 +50,7 @@ class EnhancedGraphQLHTTPHandler(GraphQLHTTPHandler): context["auth"] = request.auth # Безопасно логируем информацию о типе объекта auth logger.debug(f"[graphql] Добавлены данные авторизации в контекст: {type(request.auth).__name__}") - + logger.debug(f"[graphql] Подготовлен расширенный контекст для запроса") - + return context diff --git a/auth/identity.py b/auth/identity.py index 4222c3d2..be32fbb2 100644 --- a/auth/identity.py +++ b/auth/identity.py @@ -1,13 +1,12 @@ from binascii import hexlify from hashlib import sha256 -from typing import Any, Dict, TypeVar, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, TypeVar from passlib.hash import bcrypt -from auth.exceptions import ExpiredToken, InvalidToken, InvalidPassword +from auth.exceptions import ExpiredToken, InvalidPassword, InvalidToken from auth.jwtcodec import JWTCodec from auth.tokenstorage import TokenStorage - from services.db import local_session # Для типизации @@ -86,9 +85,7 @@ class Identity: # Проверим исходный пароль в orm_author if not orm_author.password: - logger.warning( - f"[auth.identity] Пароль в исходном объекте автора пуст: email={orm_author.email}" - ) + logger.warning(f"[auth.identity] Пароль в исходном объекте автора пуст: email={orm_author.email}") raise InvalidPassword("Пароль не установлен для данного пользователя") # Проверяем пароль напрямую, не используя dict() diff --git a/auth/internal.py b/auth/internal.py index 16f1e187..d21b5dca 100644 --- a/auth/internal.py +++ b/auth/internal.py @@ -1,22 +1,22 @@ -from typing import Optional, Tuple import time -from typing import Any +from typing import Any, Optional, Tuple from sqlalchemy.orm import exc from starlette.authentication import AuthenticationBackend, BaseUser, UnauthenticatedUser from starlette.requests import HTTPConnection from auth.credentials import AuthCredentials +from auth.exceptions import ExpiredToken, InvalidToken +from auth.jwtcodec import JWTCodec from auth.orm import Author from auth.sessions import SessionManager -from services.db import local_session -from settings import SESSION_TOKEN_HEADER, SESSION_COOKIE_NAME, ADMIN_EMAILS as ADMIN_EMAILS_LIST -from utils.logger import root_logger as logger -from auth.jwtcodec import JWTCodec -from auth.exceptions import ExpiredToken, InvalidToken from auth.state import AuthState from auth.tokenstorage import TokenStorage +from services.db import local_session from services.redis import redis +from settings import ADMIN_EMAILS as ADMIN_EMAILS_LIST +from settings import SESSION_COOKIE_NAME, SESSION_TOKEN_HEADER +from utils.logger import root_logger as logger ADMIN_EMAILS = ADMIN_EMAILS_LIST.split(",") @@ -24,13 +24,9 @@ ADMIN_EMAILS = ADMIN_EMAILS_LIST.split(",") class AuthenticatedUser(BaseUser): """Аутентифицированный пользователь для Starlette""" - def __init__(self, - user_id: str, - username: str = "", - roles: list = None, - permissions: dict = None, - token: str = None - ): + def __init__( + self, user_id: str, username: str = "", roles: list = None, permissions: dict = None, token: str = None + ): self.user_id = user_id self.username = username self.roles = roles or [] @@ -56,17 +52,17 @@ class InternalAuthentication(AuthenticationBackend): async def authenticate(self, request: HTTPConnection): """ Аутентифицирует пользователя по токену из заголовка или cookie. - + Порядок поиска токена: 1. Проверяем заголовок SESSION_TOKEN_HEADER (может быть установлен middleware) - 2. Проверяем scope/auth в request, куда middleware мог сохранить токен + 2. Проверяем scope/auth в request, куда middleware мог сохранить токен 3. Проверяем cookie Возвращает: tuple: (AuthCredentials, BaseUser) """ token = None - + # 1. Проверяем заголовок if SESSION_TOKEN_HEADER in request.headers: token_header = request.headers.get(SESSION_TOKEN_HEADER) @@ -77,19 +73,19 @@ class InternalAuthentication(AuthenticationBackend): else: token = token_header.strip() logger.debug(f"[auth.authenticate] Извлечен прямой токен из заголовка {SESSION_TOKEN_HEADER}") - + # 2. Проверяем scope/auth, который мог быть установлен middleware if not token and hasattr(request, "scope") and "auth" in request.scope: auth_data = request.scope.get("auth", {}) if isinstance(auth_data, dict) and "token" in auth_data: token = auth_data["token"] logger.debug(f"[auth.authenticate] Извлечен токен из request.scope['auth']") - + # 3. Проверяем cookie if not token and hasattr(request, "cookies") and SESSION_COOKIE_NAME in request.cookies: token = request.cookies.get(SESSION_COOKIE_NAME) logger.debug(f"[auth.authenticate] Извлечен токен из cookie {SESSION_COOKIE_NAME}") - + # Если токен не найден, возвращаем неаутентифицированного пользователя if not token: logger.debug("[auth.authenticate] Токен не найден") @@ -112,9 +108,7 @@ class InternalAuthentication(AuthenticationBackend): if author.is_locked(): logger.debug(f"[auth.authenticate] Аккаунт заблокирован: {author.id}") - return AuthCredentials( - scopes={}, error_message="Account is locked" - ), UnauthenticatedUser() + return AuthCredentials(scopes={}, error_message="Account is locked"), UnauthenticatedUser() # Получаем разрешения из ролей scopes = author.get_permissions() @@ -128,11 +122,7 @@ class InternalAuthentication(AuthenticationBackend): # Создаем объекты авторизации с сохранением токена credentials = AuthCredentials( - author_id=author.id, - scopes=scopes, - logged_in=True, - email=author.email, - token=token + author_id=author.id, scopes=scopes, logged_in=True, email=author.email, token=token ) user = AuthenticatedUser( @@ -140,7 +130,7 @@ class InternalAuthentication(AuthenticationBackend): username=author.slug or author.email or "", roles=roles, permissions=scopes, - token=token + token=token, ) logger.debug(f"[auth.authenticate] Успешная аутентификация: {author.email}") @@ -163,7 +153,7 @@ async def verify_internal_auth(token: str) -> Tuple[str, list, bool]: tuple: (user_id, roles, is_admin) """ logger.debug(f"[verify_internal_auth] Проверка токена: {token[:10]}...") - + # Обработка формата "Bearer " (если токен не был обработан ранее) if token and token.startswith("Bearer "): token = token.replace("Bearer ", "", 1).strip() @@ -188,11 +178,13 @@ async def verify_internal_auth(token: str) -> Tuple[str, list, bool]: # Получаем роли roles = [role.id for role in author.roles] logger.debug(f"[verify_internal_auth] Роли пользователя: {roles}") - + # Определяем, является ли пользователь администратором - is_admin = any(role in ['admin', 'super'] for role in roles) or author.email in ADMIN_EMAILS - logger.debug(f"[verify_internal_auth] Пользователь {author.id} {'является' if is_admin else 'не является'} администратором") - + is_admin = any(role in ["admin", "super"] for role in roles) or author.email in ADMIN_EMAILS + logger.debug( + f"[verify_internal_auth] Пользователь {author.id} {'является' if is_admin else 'не является'} администратором" + ) + return str(author.id), roles, is_admin except exc.NoResultFound: logger.warning(f"[verify_internal_auth] Пользователь с ID {payload.user_id} не найден в БД или не активен") @@ -257,7 +249,7 @@ async def authenticate(request: Any) -> AuthState: headers = dict(request.headers()) else: headers = dict(request.headers) - + auth_header = headers.get(SESSION_TOKEN_HEADER, "") if auth_header and auth_header.startswith("Bearer "): token = auth_header[7:].strip() @@ -285,13 +277,13 @@ async def authenticate(request: Any) -> AuthState: logger.warning(f"[auth.authenticate] Токен не валиден: не найдена сессия") state.error = "Invalid or expired token" return state - + # Создаем успешное состояние авторизации state.logged_in = True state.author_id = payload.user_id state.token = token state.username = payload.username - + # Если запрос имеет атрибут auth, устанавливаем в него авторизационные данные if hasattr(request, "auth") or hasattr(request, "__setattr__"): try: @@ -301,22 +293,20 @@ async def authenticate(request: Any) -> AuthState: if author: # Получаем разрешения из ролей scopes = author.get_permissions() - + # Создаем объект авторизации auth_cred = AuthCredentials( - author_id=author.id, - scopes=scopes, - logged_in=True, - email=author.email, - token=token + author_id=author.id, scopes=scopes, logged_in=True, email=author.email, token=token ) - + # Устанавливаем auth в request setattr(request, "auth", auth_cred) - logger.debug(f"[auth.authenticate] Авторизационные данные установлены в request.auth для {payload.user_id}") + logger.debug( + f"[auth.authenticate] Авторизационные данные установлены в request.auth для {payload.user_id}" + ) except Exception as e: logger.error(f"[auth.authenticate] Ошибка при установке auth в request: {e}") - + logger.info(f"[auth.authenticate] Успешная аутентификация пользователя {state.author_id}") - + return state diff --git a/auth/jwtcodec.py b/auth/jwtcodec.py index 1c1612c7..abca8bf7 100644 --- a/auth/jwtcodec.py +++ b/auth/jwtcodec.py @@ -1,12 +1,13 @@ -from datetime import datetime, timezone, timedelta +from datetime import datetime, timedelta, timezone +from typing import Optional import jwt from pydantic import BaseModel -from typing import Optional -from utils.logger import root_logger as logger from auth.exceptions import ExpiredToken, InvalidToken from settings import JWT_ALGORITHM, JWT_SECRET_KEY +from utils.logger import root_logger as logger + class TokenPayload(BaseModel): user_id: str @@ -28,14 +29,14 @@ class JWTCodec: # Для объектов с атрибутами user_id = str(getattr(user, "id", "")) username = getattr(user, "slug", "") or getattr(user, "email", "") or getattr(user, "phone", "") or "" - + logger.debug(f"[JWTCodec.encode] Кодирование токена для user_id={user_id}, username={username}") - + # Если время истечения не указано, установим срок годности на 30 дней if exp is None: exp = datetime.now(tz=timezone.utc) + timedelta(days=30) logger.debug(f"[JWTCodec.encode] Время истечения не указано, устанавливаем срок: {exp}") - + # Важно: убедимся, что exp всегда является либо datetime, либо целым числом от timestamp if isinstance(exp, datetime): # Преобразуем datetime в timestamp чтобы гарантировать правильный формат @@ -44,7 +45,7 @@ class JWTCodec: # Если передано что-то другое, установим значение по умолчанию logger.warning(f"[JWTCodec.encode] Некорректный формат exp: {exp}, используем значение по умолчанию") exp_timestamp = int((datetime.now(tz=timezone.utc) + timedelta(days=30)).timestamp()) - + payload = { "user_id": user_id, "username": username, @@ -52,9 +53,9 @@ class JWTCodec: "iat": datetime.now(tz=timezone.utc), "iss": "discours", } - + logger.debug(f"[JWTCodec.encode] Сформирован payload: {payload}") - + try: token = jwt.encode(payload, JWT_SECRET_KEY, JWT_ALGORITHM) logger.debug(f"[JWTCodec.encode] Токен успешно создан, длина: {len(token) if token else 0}") @@ -66,11 +67,11 @@ class JWTCodec: @staticmethod def decode(token: str, verify_exp: bool = True): logger.debug(f"[JWTCodec.decode] Начало декодирования токена длиной {len(token) if token else 0}") - + if not token: logger.error("[JWTCodec.decode] Пустой токен") return None - + try: payload = jwt.decode( token, @@ -83,21 +84,23 @@ class JWTCodec: issuer="discours", ) logger.debug(f"[JWTCodec.decode] Декодирован payload: {payload}") - + # Убедимся, что exp существует (добавим обработку если exp отсутствует) if "exp" not in payload: logger.warning(f"[JWTCodec.decode] В токене отсутствует поле exp") # Добавим exp по умолчанию, чтобы избежать ошибки при создании TokenPayload payload["exp"] = int((datetime.now(tz=timezone.utc) + timedelta(days=30)).timestamp()) - + try: r = TokenPayload(**payload) - logger.debug(f"[JWTCodec.decode] Создан объект TokenPayload: user_id={r.user_id}, username={r.username}") + logger.debug( + f"[JWTCodec.decode] Создан объект TokenPayload: user_id={r.user_id}, username={r.username}" + ) return r except Exception as e: logger.error(f"[JWTCodec.decode] Ошибка при создании TokenPayload: {e}") return None - + except jwt.InvalidIssuedAtError: logger.error("[JWTCodec.decode] Недействительное время выпуска токена") return None diff --git a/auth/middleware.py b/auth/middleware.py index e67187ce..49c6b7ab 100644 --- a/auth/middleware.py +++ b/auth/middleware.py @@ -1,19 +1,29 @@ """ Middleware для обработки авторизации в GraphQL запросах """ + from typing import Any, Dict + +from starlette.datastructures import Headers from starlette.requests import Request from starlette.responses import JSONResponse, Response -from starlette.datastructures import Headers -from starlette.types import ASGIApp, Scope, Receive, Send +from starlette.types import ASGIApp, Receive, Scope, Send + +from settings import ( + SESSION_COOKIE_HTTPONLY, + SESSION_COOKIE_MAX_AGE, + SESSION_COOKIE_NAME, + SESSION_COOKIE_SAMESITE, + SESSION_COOKIE_SECURE, + SESSION_TOKEN_HEADER, +) from utils.logger import root_logger as logger -from settings import SESSION_COOKIE_HTTPONLY, SESSION_COOKIE_MAX_AGE, SESSION_COOKIE_SAMESITE, SESSION_COOKIE_SECURE, SESSION_TOKEN_HEADER, SESSION_COOKIE_NAME class AuthMiddleware: """ Универсальный middleware для обработки авторизации и управления cookies. - + Основные функции: 1. Извлечение Bearer токена из заголовка Authorization или cookie 2. Добавление токена в заголовки запроса для обработки AuthenticationMiddleware @@ -23,7 +33,7 @@ class AuthMiddleware: def __init__(self, app: ASGIApp): self.app = app self._context = None - + async def __call__(self, scope: Scope, receive: Receive, send: Send): """Обработка ASGI запроса""" if scope["type"] != "http": @@ -93,33 +103,29 @@ class AuthMiddleware: scope["headers"] = new_headers # Также добавляем информацию о типе аутентификации для дальнейшего использования - scope["auth"] = { - "type": "bearer", - "token": token, - "source": token_source - } + scope["auth"] = {"type": "bearer", "token": token, "source": token_source} logger.debug(f"[middleware] Токен добавлен в scope для аутентификации из источника: {token_source}") else: logger.debug(f"[middleware] Токен не найден ни в заголовке, ни в cookie") await self.app(scope, receive, send) - + def set_context(self, context): """Сохраняет ссылку на контекст GraphQL запроса""" self._context = context logger.debug(f"[middleware] Установлен контекст GraphQL: {bool(context)}") - + def set_cookie(self, key, value, **options): """ Устанавливает cookie в ответе - + Args: key: Имя cookie value: Значение cookie **options: Дополнительные параметры (httponly, secure, max_age, etc.) """ success = False - + # Способ 1: Через response if self._context and "response" in self._context and hasattr(self._context["response"], "set_cookie"): try: @@ -128,7 +134,7 @@ class AuthMiddleware: success = True except Exception as e: logger.error(f"[middleware] Ошибка при установке cookie {key} через response: {str(e)}") - + # Способ 2: Через собственный response в контексте if not success and hasattr(self, "_response") and self._response and hasattr(self._response, "set_cookie"): try: @@ -137,20 +143,20 @@ class AuthMiddleware: success = True except Exception as e: logger.error(f"[middleware] Ошибка при установке cookie {key} через _response: {str(e)}") - + if not success: logger.error(f"[middleware] Не удалось установить cookie {key}: объекты response недоступны") def delete_cookie(self, key, **options): """ Удаляет cookie из ответа - + Args: key: Имя cookie для удаления **options: Дополнительные параметры """ success = False - + # Способ 1: Через response if self._context and "response" in self._context and hasattr(self._context["response"], "delete_cookie"): try: @@ -159,7 +165,7 @@ class AuthMiddleware: success = True except Exception as e: logger.error(f"[middleware] Ошибка при удалении cookie {key} через response: {str(e)}") - + # Способ 2: Через собственный response в контексте if not success and hasattr(self, "_response") and self._response and hasattr(self._response, "delete_cookie"): try: @@ -168,7 +174,7 @@ class AuthMiddleware: success = True except Exception as e: logger.error(f"[middleware] Ошибка при удалении cookie {key} через _response: {str(e)}") - + if not success: logger.error(f"[middleware] Не удалось удалить cookie {key}: объекты response недоступны") @@ -180,38 +186,41 @@ class AuthMiddleware: try: # Получаем доступ к контексту запроса context = info.context - + # Сохраняем ссылку на контекст self.set_context(context) - + # Добавляем себя как объект, содержащий утилитные методы context["extensions"] = self - + # Проверяем наличие response в контексте if "response" not in context or not context["response"]: from starlette.responses import JSONResponse + context["response"] = JSONResponse({}) logger.debug("[middleware] Создан новый response объект в контексте GraphQL") - - logger.debug(f"[middleware] GraphQL resolve: контекст подготовлен, добавлены расширения для работы с cookie") - + + logger.debug( + f"[middleware] GraphQL resolve: контекст подготовлен, добавлены расширения для работы с cookie" + ) + return await next(root, info, *args, **kwargs) except Exception as e: logger.error(f"[AuthMiddleware] Ошибка в GraphQL resolve: {str(e)}") raise - + async def process_result(self, request: Request, result: Any) -> Response: """ Обрабатывает результат GraphQL запроса, поддерживая установку cookie - + Args: request: Starlette Request объект result: результат GraphQL запроса (dict или Response) - + Returns: Response: HTTP-ответ с результатом и cookie (если необходимо) """ - + # Проверяем, является ли result уже объектом Response if isinstance(result, Response): response = result @@ -220,19 +229,20 @@ class AuthMiddleware: if isinstance(result, JSONResponse): try: import json - result_data = json.loads(result.body.decode('utf-8')) + + result_data = json.loads(result.body.decode("utf-8")) except Exception as e: logger.error(f"[process_result] Не удалось извлечь данные из JSONResponse: {str(e)}") else: response = JSONResponse(result) result_data = result - + # Проверяем, был ли токен в запросе или ответе if request.method == "POST": try: data = await request.json() op_name = data.get("operationName", "").lower() - + # Если это операция логина или обновления токена, и в ответе есть токен if op_name in ["login", "refreshtoken"]: token = None @@ -243,32 +253,35 @@ class AuthMiddleware: op_result = data_obj.get(op_name, {}) if isinstance(op_result, dict) and "token" in op_result: token = op_result.get("token") - + if token: # Устанавливаем cookie с токеном response.set_cookie( key=SESSION_COOKIE_NAME, value=token, - httponly=SESSION_COOKIE_HTTPONLY, + httponly=SESSION_COOKIE_HTTPONLY, secure=SESSION_COOKIE_SECURE, samesite=SESSION_COOKIE_SAMESITE, max_age=SESSION_COOKIE_MAX_AGE, ) - logger.debug(f"[graphql_handler] Установлена cookie {SESSION_COOKIE_NAME} для операции {op_name}") - + logger.debug( + f"[graphql_handler] Установлена cookie {SESSION_COOKIE_NAME} для операции {op_name}" + ) + # Если это операция logout, удаляем cookie elif op_name == "logout": response.delete_cookie( key=SESSION_COOKIE_NAME, secure=SESSION_COOKIE_SECURE, httponly=SESSION_COOKIE_HTTPONLY, - samesite=SESSION_COOKIE_SAMESITE + samesite=SESSION_COOKIE_SAMESITE, ) logger.debug(f"[graphql_handler] Удалена cookie {SESSION_COOKIE_NAME} для операции {op_name}") except Exception as e: logger.error(f"[process_result] Ошибка при обработке POST запроса: {str(e)}") - + return response - + + # Создаем единый экземпляр AuthMiddleware для использования с GraphQL -auth_middleware = AuthMiddleware(lambda scope, receive, send: None) \ No newline at end of file +auth_middleware = AuthMiddleware(lambda scope, receive, send: None) diff --git a/auth/oauth.py b/auth/oauth.py index f91e9f96..410557d1 100644 --- a/auth/oauth.py +++ b/auth/oauth.py @@ -1,11 +1,12 @@ +import time +from secrets import token_urlsafe + from authlib.integrations.starlette_client import OAuth from authlib.oauth2.rfc7636 import create_s256_code_challenge -from starlette.responses import RedirectResponse, JSONResponse -from secrets import token_urlsafe -import time +from starlette.responses import JSONResponse, RedirectResponse -from auth.tokenstorage import TokenStorage from auth.orm import Author +from auth.tokenstorage import TokenStorage from services.db import local_session from settings import FRONTEND_URL, OAUTH_CLIENTS @@ -129,9 +130,7 @@ async def oauth_callback(request): return JSONResponse({"error": "Provider not configured"}, status_code=400) # Получаем токен с PKCE verifier - token = await client.authorize_access_token( - request, code_verifier=request.session.get("code_verifier") - ) + token = await client.authorize_access_token(request, code_verifier=request.session.get("code_verifier")) # Получаем профиль пользователя profile = await get_user_profile(provider, client, token) diff --git a/auth/orm.py b/auth/orm.py index 2599c4d8..b812a0d7 100644 --- a/auth/orm.py +++ b/auth/orm.py @@ -1,5 +1,6 @@ import time from typing import Dict, Set + from sqlalchemy import JSON, Boolean, Column, ForeignKey, Index, Integer, String from sqlalchemy.orm import relationship @@ -180,7 +181,7 @@ class Author(Base): # ) # Список защищенных полей, которые видны только владельцу и администраторам - _protected_fields = ['email', 'password', 'provider_access_token', 'provider_refresh_token'] + _protected_fields = ["email", "password", "provider_access_token", "provider_refresh_token"] @property def is_authenticated(self) -> bool: @@ -241,27 +242,27 @@ class Author(Base): def dict(self, access=False) -> Dict: """ Сериализует объект Author в словарь с учетом прав доступа. - + Args: access (bool, optional): Флаг, указывающий, доступны ли защищенные поля - + Returns: dict: Словарь с атрибутами Author, отфильтрованный по правам доступа """ # Получаем все атрибуты объекта result = {c.name: getattr(self, c.name) for c in self.__table__.columns} - + # Добавляем роли как список идентификаторов и названий - if hasattr(self, 'roles'): - result['roles'] = [] + if hasattr(self, "roles"): + result["roles"] = [] for role in self.roles: if isinstance(role, dict): - result['roles'].append(role.get('id')) - + result["roles"].append(role.get("id")) + # скрываем защищенные поля if not access: for field in self._protected_fields: if field in result: result[field] = None - + return result diff --git a/auth/permissions.py b/auth/permissions.py index 5ce227ad..bcb18792 100644 --- a/auth/permissions.py +++ b/auth/permissions.py @@ -9,9 +9,9 @@ from typing import List, Union from sqlalchemy.orm import Session -from auth.orm import Author, Role, RolePermission, Permission -from settings import ADMIN_EMAILS as ADMIN_EMAILS_LIST +from auth.orm import Author, Permission, Role, RolePermission from orm.community import Community, CommunityFollower, CommunityRole +from settings import ADMIN_EMAILS as ADMIN_EMAILS_LIST ADMIN_EMAILS = ADMIN_EMAILS_LIST.split(",") @@ -110,9 +110,7 @@ class ContextualPermissionCheck: return has_permission @staticmethod - def get_user_community_roles( - session: Session, author_id: int, community_slug: str - ) -> List[CommunityRole]: + def get_user_community_roles(session: Session, author_id: int, community_slug: str) -> List[CommunityRole]: """ Получает список ролей пользователя в сообществе. diff --git a/auth/sessions.py b/auth/sessions.py index 694ee38e..84692752 100644 --- a/auth/sessions.py +++ b/auth/sessions.py @@ -1,9 +1,10 @@ from datetime import datetime, timedelta, timezone -from typing import Optional, Dict, Any, List +from typing import Any, Dict, List, Optional from pydantic import BaseModel -from services.redis import redis + from auth.jwtcodec import JWTCodec, TokenPayload +from services.redis import redis from settings import SESSION_TOKEN_LIFE_SPAN from utils.logger import root_logger as logger @@ -28,11 +29,11 @@ class SessionManager: def _make_session_key(user_id: str, token: str) -> str: """ Создаёт ключ для сессии в Redis. - + Args: user_id: ID пользователя token: JWT токен сессии - + Returns: str: Ключ сессии """ @@ -44,10 +45,10 @@ class SessionManager: def _make_user_sessions_key(user_id: str) -> str: """ Создаёт ключ для списка активных сессий пользователя. - + Args: user_id: ID пользователя - + Returns: str: Ключ списка сессий """ @@ -57,12 +58,12 @@ class SessionManager: async def create_session(cls, user_id: str, username: str, device_info: Optional[dict] = None) -> str: """ Создаёт новую сессию. - + Args: user_id: ID пользователя username: Имя пользователя device_info: Информация об устройстве (опционально) - + Returns: str: JWT токен сессии """ @@ -96,37 +97,37 @@ class SessionManager: # Устанавливаем время жизни ключей (30 дней) pipeline.expire(session_key, 30 * 24 * 60 * 60) pipeline.expire(user_sessions_key, 30 * 24 * 60 * 60) - + # Также создаем ключ в формате, совместимом с TokenStorage для обратной совместимости token_key = f"{user_id}-{username}-{token}" pipeline.hset(token_key, mapping={"user_id": user_id, "username": username}) pipeline.expire(token_key, 30 * 24 * 60 * 60) - + result = await pipeline.execute() logger.info(f"[SessionManager.create_session] Сессия успешно создана для пользователя {user_id}") - + return token @classmethod async def verify_session(cls, token: str) -> Optional[TokenPayload]: """ Проверяет сессию по токену. - + Args: token: JWT токен - + Returns: Optional[TokenPayload]: Данные токена или None, если сессия недействительна """ logger.debug(f"[SessionManager.verify_session] Проверка сессии для токена: {token[:20]}...") - + # Декодируем токен для получения payload try: payload = JWTCodec.decode(token) if not payload: logger.error("[SessionManager.verify_session] Не удалось декодировать токен") return None - + logger.debug(f"[SessionManager.verify_session] Успешно декодирован токен, user_id={payload.user_id}") except Exception as e: logger.error(f"[SessionManager.verify_session] Ошибка при декодировании токена: {str(e)}") @@ -134,69 +135,71 @@ class SessionManager: # Получаем данные из payload user_id = payload.user_id - + # Формируем ключ сессии session_key = cls._make_session_key(user_id, token) logger.debug(f"[SessionManager.verify_session] Сформирован ключ сессии: {session_key}") - + # Проверяем существование сессии в Redis exists = await redis.exists(session_key) if not exists: logger.warning(f"[SessionManager.verify_session] Сессия не найдена: {user_id}. Ключ: {session_key}") - + # Проверяем также ключ в старом формате TokenStorage для обратной совместимости token_key = f"{user_id}-{payload.username}-{token}" old_format_exists = await redis.exists(token_key) - + if old_format_exists: logger.info(f"[SessionManager.verify_session] Найдена сессия в старом формате: {token_key}") - + # Миграция: создаем запись в новом формате session_data = { "user_id": user_id, "username": payload.username, } - + # Копируем сессию в новый формат pipeline = redis.pipeline() pipeline.hset(session_key, mapping=session_data) pipeline.expire(session_key, 30 * 24 * 60 * 60) pipeline.sadd(cls._make_user_sessions_key(user_id), token) await pipeline.execute() - + logger.info(f"[SessionManager.verify_session] Сессия мигрирована в новый формат: {session_key}") return payload - + # Если сессия не найдена ни в новом, ни в старом формате, проверяем все ключи в Redis keys = await redis.keys("session:*") logger.debug(f"[SessionManager.verify_session] Все ключи сессий в Redis: {keys}") - + # Проверяем, можно ли доверять токену напрямую # Если токен валидный и не истек, мы можем доверять ему даже без записи в Redis if payload and payload.exp and payload.exp > datetime.now(tz=timezone.utc): logger.info(f"[SessionManager.verify_session] Токен валиден по JWT, создаем сессию для {user_id}") - + # Создаем сессию на основе валидного токена session_data = { "user_id": user_id, "username": payload.username, "created_at": datetime.now(tz=timezone.utc).isoformat(), - "expires_at": payload.exp.isoformat() if isinstance(payload.exp, datetime) else datetime.fromtimestamp(payload.exp, tz=timezone.utc).isoformat(), + "expires_at": payload.exp.isoformat() + if isinstance(payload.exp, datetime) + else datetime.fromtimestamp(payload.exp, tz=timezone.utc).isoformat(), } - + # Сохраняем сессию в Redis pipeline = redis.pipeline() pipeline.hset(session_key, mapping=session_data) pipeline.expire(session_key, 30 * 24 * 60 * 60) pipeline.sadd(cls._make_user_sessions_key(user_id), token) await pipeline.execute() - + logger.info(f"[SessionManager.verify_session] Создана новая сессия для валидного токена: {session_key}") return payload - + # Если сессии нет, возвращаем None return None - + # Если сессия найдена, возвращаем payload logger.debug(f"[SessionManager.verify_session] Сессия найдена для пользователя {user_id}") return payload @@ -205,89 +208,89 @@ class SessionManager: async def get_user_sessions(cls, user_id: str) -> List[Dict[str, Any]]: """ Получает список активных сессий пользователя. - + Args: user_id: ID пользователя - + Returns: List[Dict[str, Any]]: Список сессий """ user_sessions_key = cls._make_user_sessions_key(user_id) tokens = await redis.smembers(user_sessions_key) - + sessions = [] for token in tokens: session_key = cls._make_session_key(user_id, token) session_data = await redis.hgetall(session_key) - + if session_data: session = dict(session_data) session["token"] = token sessions.append(session) - + return sessions @classmethod async def delete_session(cls, user_id: str, token: str) -> bool: """ Удаляет сессию. - + Args: user_id: ID пользователя token: JWT токен - + Returns: bool: True, если сессия успешно удалена """ session_key = cls._make_session_key(user_id, token) user_sessions_key = cls._make_user_sessions_key(user_id) - + # Удаляем данные сессии и токен из списка сессий пользователя pipeline = redis.pipeline() pipeline.delete(session_key) pipeline.srem(user_sessions_key, token) - + # Также удаляем ключ в формате TokenStorage для полной очистки token_payload = JWTCodec.decode(token) if token_payload: token_key = f"{user_id}-{token_payload.username}-{token}" pipeline.delete(token_key) - + results = await pipeline.execute() - + return bool(results[0]) or bool(results[1]) @classmethod async def delete_all_sessions(cls, user_id: str) -> int: """ Удаляет все сессии пользователя. - + Args: user_id: ID пользователя - + Returns: int: Количество удаленных сессий """ user_sessions_key = cls._make_user_sessions_key(user_id) tokens = await redis.smembers(user_sessions_key) - + count = 0 for token in tokens: session_key = cls._make_session_key(user_id, token) - + # Удаляем данные сессии deleted = await redis.delete(session_key) count += deleted - + # Также удаляем ключ в формате TokenStorage token_payload = JWTCodec.decode(token) if token_payload: token_key = f"{user_id}-{token_payload.username}-{token}" await redis.delete(token_key) - + # Очищаем список токенов await redis.delete(user_sessions_key) - + return count @classmethod diff --git a/auth/state.py b/auth/state.py index ecb638a7..6a9c7157 100644 --- a/auth/state.py +++ b/auth/state.py @@ -2,12 +2,13 @@ Классы состояния авторизации """ + class AuthState: """ Класс для хранения информации о состоянии авторизации пользователя. Используется в аутентификационных middleware и функциях. """ - + def __init__(self): self.logged_in = False self.author_id = None @@ -16,7 +17,7 @@ class AuthState: self.is_admin = False self.is_editor = False self.error = None - + def __bool__(self): """Возвращает True если пользователь авторизован""" - return self.logged_in \ No newline at end of file + return self.logged_in diff --git a/auth/tokenstorage.py b/auth/tokenstorage.py index b1895bfe..969e4668 100644 --- a/auth/tokenstorage.py +++ b/auth/tokenstorage.py @@ -1,7 +1,7 @@ -from datetime import datetime, timedelta, timezone import json import time -from typing import Dict, Any, Optional, Tuple, List +from datetime import datetime, timedelta, timezone +from typing import Any, Dict, List, Optional, Tuple from auth.jwtcodec import JWTCodec from auth.validations import AuthInput @@ -81,7 +81,7 @@ class TokenStorage: # Формируем ключи для Redis token_key = cls._make_token_key(user_id, username, token) logger.debug(f"[TokenStorage.create_session] Сформированы ключи: token_key={token_key}") - + # Формируем ключи в новом формате SessionManager для совместимости session_key = cls._make_session_key(user_id, token) user_sessions_key = cls._make_user_sessions_key(user_id) @@ -91,25 +91,25 @@ class TokenStorage: "user_id": user_id, "username": username, "created_at": time.time(), - "expires_at": time.time() + 30 * 24 * 60 * 60 # 30 дней + "expires_at": time.time() + 30 * 24 * 60 * 60, # 30 дней } - + if device_info: token_data.update(device_info) - + logger.debug(f"[TokenStorage.create_session] Сформированы данные сессии: {token_data}") # Сохраняем в Redis старый формат pipeline = redis.pipeline() pipeline.hset(token_key, mapping=token_data) pipeline.expire(token_key, 30 * 24 * 60 * 60) # 30 дней - + # Также сохраняем в новом формате SessionManager для обеспечения совместимости pipeline.hset(session_key, mapping=token_data) pipeline.expire(session_key, 30 * 24 * 60 * 60) # 30 дней pipeline.sadd(user_sessions_key, token) pipeline.expire(user_sessions_key, 30 * 24 * 60 * 60) # 30 дней - + results = await pipeline.execute() logger.info(f"[TokenStorage.create_session] Сессия успешно создана для пользователя {user_id}") @@ -146,39 +146,39 @@ class TokenStorage: if not payload: logger.warning(f"[TokenStorage.validate_token] Токен не валиден (не удалось декодировать)") return False, None - + user_id = payload.user_id username = payload.username - + # Формируем ключи для Redis в обоих форматах token_key = cls._make_token_key(user_id, username, token) session_key = cls._make_session_key(user_id, token) - + # Проверяем в обоих форматах для совместимости old_exists = await redis.exists(token_key) new_exists = await redis.exists(session_key) - + if old_exists or new_exists: logger.info(f"[TokenStorage.validate_token] Токен валиден для пользователя {user_id}") - + # Получаем данные токена из актуального хранилища if new_exists: token_data = await redis.hgetall(session_key) else: token_data = await redis.hgetall(token_key) - + # Если найден только в старом формате, создаем запись в новом формате if not new_exists: logger.info(f"[TokenStorage.validate_token] Миграция токена в новый формат: {session_key}") await redis.hset(session_key, mapping=token_data) await redis.expire(session_key, 30 * 24 * 60 * 60) await redis.sadd(cls._make_user_sessions_key(user_id), token) - + return True, token_data else: logger.warning(f"[TokenStorage.validate_token] Токен не найден в Redis: {token_key}") return False, None - + except Exception as e: logger.error(f"[TokenStorage.validate_token] Ошибка при проверке токена: {e}") return False, None @@ -200,30 +200,30 @@ class TokenStorage: if not payload: logger.warning(f"[TokenStorage.invalidate_token] Токен не валиден (не удалось декодировать)") return False - + user_id = payload.user_id username = payload.username - + # Формируем ключи для Redis в обоих форматах token_key = cls._make_token_key(user_id, username, token) session_key = cls._make_session_key(user_id, token) user_sessions_key = cls._make_user_sessions_key(user_id) - + # Удаляем токен из Redis в обоих форматах pipeline = redis.pipeline() pipeline.delete(token_key) pipeline.delete(session_key) pipeline.srem(user_sessions_key, token) results = await pipeline.execute() - + success = any(results) if success: logger.info(f"[TokenStorage.invalidate_token] Токен успешно инвалидирован для пользователя {user_id}") else: logger.warning(f"[TokenStorage.invalidate_token] Токен не найден: {token_key}") - + return success - + except Exception as e: logger.error(f"[TokenStorage.invalidate_token] Ошибка при инвалидации токена: {e}") return False @@ -243,11 +243,11 @@ class TokenStorage: # Получаем список сессий пользователя user_sessions_key = cls._make_user_sessions_key(user_id) tokens = await redis.smembers(user_sessions_key) - + if not tokens: logger.warning(f"[TokenStorage.invalidate_all_tokens] Нет активных сессий пользователя {user_id}") return 0 - + count = 0 for token in tokens: # Декодируем JWT токен @@ -255,28 +255,28 @@ class TokenStorage: payload = JWTCodec.decode(token) if payload: username = payload.username - + # Формируем ключи для Redis token_key = cls._make_token_key(user_id, username, token) session_key = cls._make_session_key(user_id, token) - + # Удаляем токен из Redis pipeline = redis.pipeline() pipeline.delete(token_key) pipeline.delete(session_key) results = await pipeline.execute() - + count += 1 except Exception as e: logger.error(f"[TokenStorage.invalidate_all_tokens] Ошибка при обработке токена: {e}") continue - + # Удаляем список сессий пользователя await redis.delete(user_sessions_key) - + logger.info(f"[TokenStorage.invalidate_all_tokens] Инвалидировано {count} токенов пользователя {user_id}") return count - + except Exception as e: logger.error(f"[TokenStorage.invalidate_all_tokens] Ошибка при инвалидации всех токенов: {e}") return 0 diff --git a/cache/precache.py b/cache/precache.py index 8871be7f..76f1c1a3 100644 --- a/cache/precache.py +++ b/cache/precache.py @@ -3,8 +3,8 @@ import json from sqlalchemy import and_, join, select -from cache.cache import cache_author, cache_topic from auth.orm import Author, AuthorFollower +from cache.cache import cache_author, cache_topic from orm.shout import Shout, ShoutAuthor, ShoutReactionsFollower, ShoutTopic from orm.topic import Topic, TopicFollower from resolvers.stat import get_with_stat @@ -29,9 +29,7 @@ async def precache_authors_followers(author_id, session): async def precache_authors_follows(author_id, session): follows_topics_query = select(TopicFollower.topic).where(TopicFollower.follower == author_id) follows_authors_query = select(AuthorFollower.author).where(AuthorFollower.follower == author_id) - follows_shouts_query = select(ShoutReactionsFollower.shout).where( - ShoutReactionsFollower.follower == author_id - ) + follows_shouts_query = select(ShoutReactionsFollower.shout).where(ShoutReactionsFollower.follower == author_id) follows_topics = {row[0] for row in session.execute(follows_topics_query) if row[0]} follows_authors = {row[0] for row in session.execute(follows_authors_query) if row[0]} diff --git a/cache/triggers.py b/cache/triggers.py index 647acc91..22e451d8 100644 --- a/cache/triggers.py +++ b/cache/triggers.py @@ -1,7 +1,7 @@ from sqlalchemy import event -from cache.revalidator import revalidation_manager from auth.orm import Author, AuthorFollower +from cache.revalidator import revalidation_manager from orm.reaction import Reaction, ReactionKind from orm.shout import Shout, ShoutAuthor, ShoutReactionsFollower from orm.topic import Topic, TopicFollower diff --git a/dev.py b/dev.py index 984769ba..05da3979 100644 --- a/dev.py +++ b/dev.py @@ -1,17 +1,19 @@ import os import subprocess from pathlib import Path -from utils.logger import root_logger as logger + from granian import Granian +from utils.logger import root_logger as logger + def check_mkcert_installed(): """ Проверяет, установлен ли инструмент mkcert в системе - + Returns: bool: True если mkcert установлен, иначе False - + >>> check_mkcert_installed() # doctest: +SKIP True """ @@ -21,18 +23,19 @@ def check_mkcert_installed(): except FileNotFoundError: return False + def generate_certificates(domain="localhost", cert_file="localhost.pem", key_file="localhost-key.pem"): """ Генерирует сертификаты с использованием mkcert - + Args: domain: Домен для сертификата cert_file: Имя файла сертификата key_file: Имя файла ключа - + Returns: tuple: (cert_file, key_file) пути к созданным файлам - + >>> generate_certificates() # doctest: +SKIP ('localhost.pem', 'localhost-key.pem') """ @@ -40,7 +43,7 @@ def generate_certificates(domain="localhost", cert_file="localhost.pem", key_fil if os.path.exists(cert_file) and os.path.exists(key_file): logger.info(f"Сертификаты уже существуют: {cert_file}, {key_file}") return cert_file, key_file - + # Проверяем, установлен ли mkcert if not check_mkcert_installed(): logger.error("mkcert не установлен. Установите mkcert с помощью команды:") @@ -49,37 +52,38 @@ def generate_certificates(domain="localhost", cert_file="localhost.pem", key_fil logger.error(" Windows: choco install mkcert") logger.error("После установки выполните: mkcert -install") return None, None - + try: # Запускаем mkcert для создания сертификата logger.info(f"Создание сертификатов для {domain} с помощью mkcert...") result = subprocess.run( ["mkcert", "-cert-file", cert_file, "-key-file", key_file, domain], - stdout=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, - text=True + text=True, ) - + if result.returncode != 0: logger.error(f"Ошибка при создании сертификатов: {result.stderr}") return None, None - + logger.info(f"Сертификаты созданы: {cert_file}, {key_file}") return cert_file, key_file except Exception as e: logger.error(f"Не удалось создать сертификаты: {str(e)}") return None, None + def run_server(host="0.0.0.0", port=8000, workers=1): """ Запускает сервер Granian с поддержкой HTTPS при необходимости - + Args: host: Хост для запуска сервера port: Порт для запуска сервера use_https: Флаг использования HTTPS workers: Количество рабочих процессов - + >>> run_server(use_https=True) # doctest: +SKIP """ # Проблема с многопроцессорным режимом - не поддерживает локальные объекты приложений @@ -87,16 +91,16 @@ def run_server(host="0.0.0.0", port=8000, workers=1): if workers > 1: logger.warning("Многопроцессорный режим может вызвать проблемы сериализации приложения. Использую 1 процесс.") workers = 1 - + # При проблемах с ASGI можно попробовать использовать Uvicorn как запасной вариант try: # Генерируем сертификаты с помощью mkcert cert_file, key_file = generate_certificates() - + if not cert_file or not key_file: logger.error("Не удалось сгенерировать сертификаты для HTTPS") return - + logger.info(f"Запуск HTTPS сервера на https://{host}:{port} с использованием Granian") # Запускаем Granian сервер с явным указанием ASGI server = Granian( @@ -104,7 +108,7 @@ def run_server(host="0.0.0.0", port=8000, workers=1): port=port, workers=workers, interface="asgi", - target="main:app", + target="main:app", ssl_cert=Path(cert_file), ssl_key=Path(key_file), ) @@ -113,5 +117,6 @@ def run_server(host="0.0.0.0", port=8000, workers=1): # В случае проблем с Granian, пробуем запустить через Uvicorn logger.error(f"Ошибка при запуске Granian: {str(e)}") + if __name__ == "__main__": - run_server() \ No newline at end of file + run_server() diff --git a/main.py b/main.py index cb71791e..ae799469 100644 --- a/main.py +++ b/main.py @@ -5,19 +5,18 @@ from os.path import exists, join from ariadne import load_schema_from_path, make_executable_schema from ariadne.asgi import GraphQL +from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.middleware.authentication import AuthenticationMiddleware +from starlette.middleware.cors import CORSMiddleware +from starlette.requests import Request +from starlette.responses import JSONResponse, Response +from starlette.routing import Mount, Route +from starlette.staticfiles import StaticFiles from auth.handler import EnhancedGraphQLHTTPHandler from auth.internal import InternalAuthentication -from auth.middleware import auth_middleware, AuthMiddleware -from starlette.applications import Starlette -from starlette.middleware.cors import CORSMiddleware -from starlette.middleware.authentication import AuthenticationMiddleware -from starlette.middleware import Middleware -from starlette.requests import Request -from starlette.responses import JSONResponse, Response -from starlette.routing import Route, Mount -from starlette.staticfiles import StaticFiles - +from auth.middleware import AuthMiddleware, auth_middleware from cache.precache import precache_data from cache.revalidator import revalidation_manager from services.exception import ExceptionHandlerMiddleware @@ -25,8 +24,8 @@ from services.redis import redis from services.schema import create_all_tables, resolvers from services.search import check_search_service, initialize_search_index_background, search_service from services.viewed import ViewedStorage -from utils.logger import root_logger as logger from settings import DEV_SERVER_PID_FILE_NAME +from utils.logger import root_logger as logger DEVMODE = os.getenv("DOKKU_APP_TYPE", "false").lower() == "false" DIST_DIR = join(os.path.dirname(__file__), "dist") # Директория для собранных файлов @@ -46,14 +45,14 @@ middleware = [ Middleware( CORSMiddleware, allow_origins=[ - "https://localhost:3000", - "https://testing.discours.io", - "https://discours.io", + "https://localhost:3000", + "https://testing.discours.io", + "https://discours.io", "https://new.discours.io", "https://discours.ru", - "https://new.discours.ru" - ], - allow_methods=["GET", "POST", "OPTIONS"], # Явно указываем OPTIONS + "https://new.discours.ru", + ], + allow_methods=["GET", "POST", "OPTIONS"], # Явно указываем OPTIONS allow_headers=["*"], allow_credentials=True, ), @@ -65,33 +64,29 @@ middleware = [ # Создаем экземпляр GraphQL с улучшенным обработчиком -graphql_app = GraphQL( - schema, - debug=DEVMODE, - http_handler=EnhancedGraphQLHTTPHandler() -) +graphql_app = GraphQL(schema, debug=DEVMODE, http_handler=EnhancedGraphQLHTTPHandler()) # Оборачиваем GraphQL-обработчик для лучшей обработки ошибок async def graphql_handler(request: Request): """ Обработчик GraphQL запросов с поддержкой middleware и обработкой ошибок. - + Выполняет: 1. Проверку метода запроса (GET, POST, OPTIONS) 2. Обработку GraphQL запроса через ariadne 3. Применение middleware для корректной обработки cookie и авторизации 4. Обработку исключений и формирование ответа - + Args: request: Starlette Request объект - + Returns: Response: объект ответа (обычно JSONResponse) """ if request.method not in ["GET", "POST", "OPTIONS"]: return JSONResponse({"error": "Method Not Allowed by main.py"}, status_code=405) - + # Проверяем, что все необходимые middleware корректно отработали if not hasattr(request, "scope") or "auth" not in request.scope: logger.warning("[graphql] AuthMiddleware не обработал запрос перед GraphQL обработчиком") @@ -99,7 +94,7 @@ async def graphql_handler(request: Request): try: # Обрабатываем запрос через GraphQL приложение result = await graphql_app.handle_request(request) - + # Применяем middleware для установки cookie # Используем метод process_result из auth_middleware для корректной обработки # cookie на основе результатов операций login/logout @@ -111,6 +106,7 @@ async def graphql_handler(request: Request): logger.error(f"GraphQL error: {str(e)}") # Логируем более подробную информацию для отладки import traceback + logger.debug(f"GraphQL error traceback: {traceback.format_exc()}") return JSONResponse({"error": str(e)}, status_code=500) @@ -127,6 +123,7 @@ async def shutdown(): # Удаляем PID-файл, если он существует from settings import DEV_SERVER_PID_FILE_NAME + if exists(DEV_SERVER_PID_FILE_NAME): os.unlink(DEV_SERVER_PID_FILE_NAME) @@ -134,12 +131,12 @@ async def shutdown(): async def dev_start(): """ Инициализация сервера в DEV режиме. - + Функция: 1. Проверяет наличие DEV режима 2. Создает PID-файл для отслеживания процесса 3. Логирует информацию о старте сервера - + Используется только при запуске сервера с флагом "dev". """ try: @@ -151,6 +148,7 @@ async def dev_start(): old_pid = int(f.read().strip()) # Проверяем, существует ли процесс с таким PID import signal + try: os.kill(old_pid, 0) # Сигнал 0 только проверяет существование процесса print(f"[warning] DEV server already running with PID {old_pid}") @@ -158,7 +156,7 @@ async def dev_start(): print(f"[info] Stale PID file found, previous process {old_pid} not running") except (ValueError, FileNotFoundError): print(f"[warning] Invalid PID file found, recreating") - + # Создаем или перезаписываем PID-файл with open(pid_path, "w", encoding="utf-8") as f: f.write(str(os.getpid())) @@ -172,16 +170,16 @@ async def dev_start(): async def lifespan(_app): """ Функция жизненного цикла приложения. - + Обеспечивает: 1. Инициализацию всех необходимых сервисов и компонентов 2. Предзагрузку кеша данных 3. Подключение к Redis и поисковому сервису 4. Корректное завершение работы при остановке сервера - + Args: _app: экземпляр Starlette приложения - + Yields: None: генератор для управления жизненным циклом """ @@ -213,11 +211,12 @@ async def lifespan(_app): await asyncio.gather(*tasks, return_exceptions=True) print("[lifespan] Shutdown complete") + # Обновляем маршрут в Starlette app = Starlette( routes=[ Route("/graphql", graphql_handler, methods=["GET", "POST", "OPTIONS"]), - Mount("/", app=StaticFiles(directory=DIST_DIR, html=True)) + Mount("/", app=StaticFiles(directory=DIST_DIR, html=True)), ], lifespan=lifespan, middleware=middleware, # Явно указываем список middleware diff --git a/orm/community.py b/orm/community.py index 0aac4172..f7613c2f 100644 --- a/orm/community.py +++ b/orm/community.py @@ -66,11 +66,7 @@ class CommunityStats: def shouts(self): from orm.shout import Shout - return ( - self.community.session.query(func.count(Shout.id)) - .filter(Shout.community == self.community.id) - .scalar() - ) + return self.community.session.query(func.count(Shout.id)).filter(Shout.community == self.community.id).scalar() @property def followers(self): diff --git a/orm/shout.py b/orm/shout.py index 5934d6cb..b98126be 100644 --- a/orm/shout.py +++ b/orm/shout.py @@ -77,7 +77,7 @@ class Shout(Base): slug (str) cover (str) : "Cover image url" cover_caption (str) : "Cover image alt caption" - lead (str) + lead (str) title (str) subtitle (str) layout (str) diff --git a/resolvers/__init__.py b/resolvers/__init__.py index c63cf01a..925093a2 100644 --- a/resolvers/__init__.py +++ b/resolvers/__init__.py @@ -1,4 +1,15 @@ from cache.triggers import events_register +from resolvers.admin import ( + admin_get_roles, + admin_get_users, +) +from resolvers.auth import ( + confirm_email, + get_current_user, + login, + register_by_email, + send_link, +) from resolvers.author import ( # search_authors, get_author, get_author_followers, @@ -16,8 +27,8 @@ from resolvers.draft import ( delete_draft, load_drafts, publish_draft, - update_draft, unpublish_draft, + update_draft, ) from resolvers.editor import ( unpublish_shout, @@ -62,19 +73,6 @@ from resolvers.topic import ( get_topics_by_community, ) -from resolvers.auth import ( - get_current_user, - confirm_email, - register_by_email, - send_link, - login, -) - -from resolvers.admin import ( - admin_get_users, - admin_get_roles, -) - events_register() __all__ = [ @@ -84,11 +82,9 @@ __all__ = [ "register_by_email", "send_link", "login", - # admin "admin_get_users", "admin_get_roles", - # author "get_author", "get_author_followers", @@ -100,11 +96,9 @@ __all__ = [ "load_authors_search", "update_author", # "search_authors", - # community "get_community", "get_communities_all", - # topic "get_topic", "get_topics_all", @@ -112,14 +106,12 @@ __all__ = [ "get_topics_by_author", "get_topic_followers", "get_topic_authors", - # reader "get_shout", "load_shouts_by", "load_shouts_random_top", "load_shouts_search", "load_shouts_unrated", - # feed "load_shouts_feed", "load_shouts_coauthored", @@ -127,12 +119,10 @@ __all__ = [ "load_shouts_with_topic", "load_shouts_followed_by", "load_shouts_authored_by", - # follower "follow", "unfollow", "get_shout_followers", - # reaction "create_reaction", "update_reaction", @@ -142,18 +132,15 @@ __all__ = [ "load_shout_ratings", "load_comment_ratings", "load_comments_branch", - # notifier "load_notifications", "notifications_seen_thread", "notifications_seen_after", "notification_mark_seen", - # rating "rate_author", "get_my_rates_comments", "get_my_rates_shouts", - # draft "load_drafts", "create_draft", diff --git a/resolvers/admin.py b/resolvers/admin.py index 03ac1dbc..02c21126 100644 --- a/resolvers/admin.py +++ b/resolvers/admin.py @@ -1,12 +1,13 @@ from math import ceil -from sqlalchemy import or_, cast, String + from graphql.error import GraphQLError +from sqlalchemy import String, cast, or_ from auth.decorators import admin_auth_required +from auth.orm import Author, AuthorRole, Role from services.db import local_session -from services.schema import query, mutation -from auth.orm import Author, Role, AuthorRole from services.env import EnvManager, EnvVariable +from services.schema import mutation, query from utils.logger import root_logger as logger @@ -64,11 +65,9 @@ async def admin_get_users(_, info, limit=10, offset=0, search=None): "email": user.email, "name": user.name, "slug": user.slug, - "roles": [role.id for role in user.roles] - if hasattr(user, "roles") and user.roles - else [], + "roles": [role.id for role in user.roles] if hasattr(user, "roles") and user.roles else [], "created_at": user.created_at, - "last_seen": user.last_seen + "last_seen": user.last_seen, } for user in users ], @@ -81,6 +80,7 @@ async def admin_get_users(_, info, limit=10, offset=0, search=None): return result except Exception as e: import traceback + logger.error(f"Ошибка при получении списка пользователей: {str(e)}") logger.error(traceback.format_exc()) raise GraphQLError(f"Не удалось получить список пользователей: {str(e)}") @@ -126,20 +126,20 @@ async def admin_get_roles(_, info): async def get_env_variables(_, info): """ Получает список переменных окружения, сгруппированных по секциям - + Args: info: Контекст GraphQL запроса - + Returns: Список секций с переменными окружения """ try: # Создаем экземпляр менеджера переменных окружения env_manager = EnvManager() - + # Получаем все переменные sections = env_manager.get_all_variables() - + # Преобразуем к формату GraphQL API result = [ { @@ -154,11 +154,11 @@ async def get_env_variables(_, info): "isSecret": var.is_secret, } for var in section.variables - ] + ], } for section in sections ] - + return result except Exception as e: logger.error(f"Ошибка при получении переменных окружения: {str(e)}") @@ -170,27 +170,27 @@ async def get_env_variables(_, info): async def update_env_variable(_, info, key, value): """ Обновляет значение переменной окружения - + Args: info: Контекст GraphQL запроса key: Ключ переменной value: Новое значение - + Returns: Boolean: результат операции """ try: # Создаем экземпляр менеджера переменных окружения env_manager = EnvManager() - + # Обновляем переменную result = env_manager.update_variable(key, value) - + if result: logger.info(f"Переменная окружения '{key}' успешно обновлена") else: logger.error(f"Не удалось обновить переменную окружения '{key}'") - + return result except Exception as e: logger.error(f"Ошибка при обновлении переменной окружения: {str(e)}") @@ -202,36 +202,32 @@ async def update_env_variable(_, info, key, value): async def update_env_variables(_, info, variables): """ Массовое обновление переменных окружения - + Args: info: Контекст GraphQL запроса variables: Список переменных для обновления - + Returns: Boolean: результат операции """ try: # Создаем экземпляр менеджера переменных окружения env_manager = EnvManager() - + # Преобразуем входные данные в формат для менеджера env_variables = [ - EnvVariable( - key=var.get("key", ""), - value=var.get("value", ""), - type=var.get("type", "string") - ) + EnvVariable(key=var.get("key", ""), value=var.get("value", ""), type=var.get("type", "string")) for var in variables ] - + # Обновляем переменные result = env_manager.update_variables(env_variables) - + if result: logger.info(f"Переменные окружения успешно обновлены ({len(variables)} шт.)") else: logger.error(f"Не удалось обновить переменные окружения") - + return result except Exception as e: logger.error(f"Ошибка при массовом обновлении переменных окружения: {str(e)}") @@ -243,90 +239,78 @@ async def update_env_variables(_, info, variables): async def admin_update_user(_, info, user): """ Обновляет роли пользователя - + Args: info: Контекст GraphQL запроса user: Данные для обновления пользователя (содержит id и roles) - + Returns: Boolean: результат операции или объект с ошибкой """ try: user_id = user.get("id") roles = user.get("roles", []) - + if not roles: logger.warning(f"Пользователю {user_id} не назначено ни одной роли. Доступ в систему будет заблокирован.") - + with local_session() as session: # Получаем пользователя из базы данных author = session.query(Author).filter(Author.id == user_id).first() - + if not author: error_msg = f"Пользователь с ID {user_id} не найден" logger.error(error_msg) - return { - "success": False, - "error": error_msg - } - + return {"success": False, "error": error_msg} + # Получаем ID сообщества по умолчанию default_community_id = 1 # Используем значение по умолчанию из модели AuthorRole - + try: # Очищаем текущие роли пользователя через ORM session.query(AuthorRole).filter(AuthorRole.author == user_id).delete() session.flush() - + # Получаем все существующие роли, которые указаны для обновления role_objects = session.query(Role).filter(Role.id.in_(roles)).all() - + # Проверяем, все ли запрошенные роли найдены found_role_ids = [role.id for role in role_objects] missing_roles = set(roles) - set(found_role_ids) - + if missing_roles: warning_msg = f"Некоторые роли не найдены в базе: {', '.join(missing_roles)}" logger.warning(warning_msg) - + # Создаем новые записи в таблице author_role с указанием community for role in role_objects: # Используем ORM для создания новых записей - author_role = AuthorRole( - community=default_community_id, - author=user_id, - role=role.id - ) + author_role = AuthorRole(community=default_community_id, author=user_id, role=role.id) session.add(author_role) - + # Сохраняем изменения в базе данных session.commit() - + # Проверяем, добавлена ли пользователю роль reader - has_reader = 'reader' in [role.id for role in role_objects] + has_reader = "reader" in [role.id for role in role_objects] if not has_reader: - logger.warning(f"Пользователю {author.email or author.id} не назначена роль 'reader'. Доступ в систему будет ограничен.") - + logger.warning( + f"Пользователю {author.email or author.id} не назначена роль 'reader'. Доступ в систему будет ограничен." + ) + logger.info(f"Роли пользователя {author.email or author.id} обновлены: {', '.join(found_role_ids)}") - - return { - "success": True - } + + return {"success": True} except Exception as e: # Обработка вложенных исключений session.rollback() error_msg = f"Ошибка при изменении ролей: {str(e)}" logger.error(error_msg) - return { - "success": False, - "error": error_msg - } + return {"success": False, "error": error_msg} except Exception as e: import traceback + error_msg = f"Ошибка при обновлении ролей пользователя: {str(e)}" logger.error(error_msg) logger.error(traceback.format_exc()) - return { - "success": False, - "error": error_msg - } + return {"success": False, "error": error_msg} diff --git a/resolvers/auth.py b/resolvers/auth.py index 6eaf2327..0608e04c 100644 --- a/resolvers/auth.py +++ b/resolvers/auth.py @@ -1,46 +1,48 @@ # -*- coding: utf-8 -*- import time import traceback -from utils.logger import root_logger as logger from graphql.type import GraphQLResolveInfo -# import asyncio # Убираем, так как резолвер будет синхронным -from services.auth import login_required from auth.credentials import AuthCredentials from auth.email import send_auth_email from auth.exceptions import InvalidToken, ObjectNotExist from auth.identity import Identity, Password +from auth.internal import verify_internal_auth from auth.jwtcodec import JWTCodec -from auth.tokenstorage import TokenStorage from auth.orm import Author, Role +from auth.sessions import SessionManager +from auth.tokenstorage import TokenStorage + +# import asyncio # Убираем, так как резолвер будет синхронным +from services.auth import login_required from services.db import local_session from services.schema import mutation, query from settings import ( ADMIN_EMAILS, - SESSION_TOKEN_HEADER, - SESSION_COOKIE_NAME, - SESSION_COOKIE_SECURE, - SESSION_COOKIE_SAMESITE, - SESSION_COOKIE_MAX_AGE, SESSION_COOKIE_HTTPONLY, + SESSION_COOKIE_MAX_AGE, + SESSION_COOKIE_NAME, + SESSION_COOKIE_SAMESITE, + SESSION_COOKIE_SECURE, + SESSION_TOKEN_HEADER, ) from utils.generate_slug import generate_unique_slug -from auth.sessions import SessionManager -from auth.internal import verify_internal_auth +from utils.logger import root_logger as logger + @mutation.field("getSession") @login_required async def get_current_user(_, info): """ Получает информацию о текущем пользователе. - + Требует авторизации через декоратор login_required. - + Args: _: Родительский объект (не используется) info: Контекст GraphQL запроса - + Returns: dict: Объект с токеном и данными автора с добавленной статистикой """ @@ -49,68 +51,73 @@ async def get_current_user(_, info): if not author_id: logger.error("[getSession] Пользователь не авторизован") from graphql.error import GraphQLError + raise GraphQLError("Требуется авторизация") - + # Получаем токен из заголовка req = info.context.get("request") token = req.headers.get(SESSION_TOKEN_HEADER) if token and token.startswith("Bearer "): token = token.split("Bearer ")[-1].strip() - + # Получаем данные автора author = info.context.get("author") - + # Если автор не найден в контексте, пробуем получить из БД с добавлением статистики if not author: - logger.debug(f"[getSession] Автор не найден в контексте для пользователя {user_id}, получаем из БД") - + logger.debug(f"[getSession] Автор не найден в контексте для пользователя {author_id}, получаем из БД") + try: # Используем функцию get_with_stat для получения автора со статистикой from sqlalchemy import select + from resolvers.stat import get_with_stat - - q = select(Author).where(Author.id == user_id) + + q = select(Author).where(Author.id == author_id) authors_with_stat = get_with_stat(q) - + if authors_with_stat and len(authors_with_stat) > 0: author = authors_with_stat[0] - + # Обновляем last_seen отдельной транзакцией with local_session() as session: - author_db = session.query(Author).filter(Author.id == user_id).first() + author_db = session.query(Author).filter(Author.id == author_id).first() if author_db: author_db.last_seen = int(time.time()) session.commit() else: - logger.error(f"[getSession] Автор с ID {user_id} не найден в БД") + logger.error(f"[getSession] Автор с ID {author_id} не найден в БД") from graphql.error import GraphQLError + raise GraphQLError("Пользователь не найден") - + except Exception as e: logger.error(f"[getSession] Ошибка при получении автора из БД: {e}", exc_info=True) from graphql.error import GraphQLError + raise GraphQLError("Ошибка при получении данных пользователя") else: # Если автор уже есть в контексте, добавляем статистику try: from sqlalchemy import select + from resolvers.stat import get_with_stat - - q = select(Author).where(Author.id == user_id) + + q = select(Author).where(Author.id == author_id) authors_with_stat = get_with_stat(q) - + if authors_with_stat and len(authors_with_stat) > 0: # Обновляем только статистику author.stat = authors_with_stat[0].stat except Exception as e: logger.warning(f"[getSession] Не удалось добавить статистику к автору: {e}") - + # Возвращаем данные сессии - logger.info(f"[getSession] Успешно получена сессия для пользователя {user_id}") - return {"token": token or '', "author": author} + logger.info(f"[getSession] Успешно получена сессия для пользователя {author_id}") + return {"token": token or "", "author": author} -@mutation.field("confirmEmail") +@mutation.field("confirmEmail") async def confirm_email(_, info, token): """confirm owning email address""" try: @@ -118,26 +125,26 @@ async def confirm_email(_, info, token): payload = JWTCodec.decode(token) user_id = payload.user_id username = payload.username - + # Если TokenStorage.get асинхронный, это нужно будет переделать или вызывать синхронно # Для теста пока оставим, но это потенциальная точка отказа в синхронном резолвере token_key = f"{user_id}-{username}-{token}" await TokenStorage.get(token_key) - + with local_session() as session: user = session.query(Author).where(Author.id == user_id).first() if not user: logger.warning(f"[auth] confirmEmail: Пользователь с ID {user_id} не найден.") return {"success": False, "token": None, "author": None, "error": "Пользователь не найден"} - + # Создаем сессионный токен с новым форматом вызова и явным временем истечения device_info = {"email": user.email} if hasattr(user, "email") else None session_token = await TokenStorage.create_session( user_id=str(user_id), username=user.username or user.email or user.slug or username, - device_info=device_info + device_info=device_info, ) - + user.email_verified = True user.last_seen = int(time.time()) session.add(user) @@ -155,7 +162,7 @@ async def confirm_email(_, info, token): "token": None, "author": None, "error": f"Ошибка подтверждения email: {str(e)}", - } + } def create_user(user_dict): @@ -231,9 +238,7 @@ async def register_by_email(_, _info, email: str, password: str = "", name: str try: # Если auth_send_link асинхронный... await send_link(_, _info, email) - logger.info( - f"[auth] registerUser: Пользователь {email} зарегистрирован, ссылка для подтверждения отправлена." - ) + logger.info(f"[auth] registerUser: Пользователь {email} зарегистрирован, ссылка для подтверждения отправлена.") # При регистрации возвращаем данные самому пользователю, поэтому не фильтруем return { "success": True, @@ -306,7 +311,7 @@ async def login(_, info, email: str, password: str): logger.info( f"[auth] login: Найден автор {email}, id={author.id}, имя={author.name}, пароль есть: {bool(author.password)}" ) - + # Проверяем наличие роли reader has_reader_role = False if hasattr(author, "roles") and author.roles: @@ -314,12 +319,12 @@ async def login(_, info, email: str, password: str): if role.id == "reader": has_reader_role = True break - + # Если у пользователя нет роли reader и он не админ, запрещаем вход if not has_reader_role: # Проверяем, есть ли роль admin или super is_admin = author.email in ADMIN_EMAILS.split(",") - + if not is_admin: logger.warning(f"[auth] login: У пользователя {email} нет роли 'reader', в доступе отказано") return { @@ -365,9 +370,7 @@ async def login(_, info, email: str, password: str): or not hasattr(valid_author, "username") and not hasattr(valid_author, "email") ): - logger.error( - f"[auth] login: Объект автора не содержит необходимых атрибутов: {valid_author}" - ) + logger.error(f"[auth] login: Объект автора не содержит необходимых атрибутов: {valid_author}") return { "success": False, "token": None, @@ -380,7 +383,7 @@ async def login(_, info, email: str, password: str): token = await TokenStorage.create_session( user_id=str(valid_author.id), username=valid_author.username or valid_author.email or valid_author.slug or "", - device_info={"email": valid_author.email} if hasattr(valid_author, "email") else None + device_info={"email": valid_author.email} if hasattr(valid_author, "email") else None, ) logger.info(f"[auth] login: токен успешно создан, длина: {len(token) if token else 0}") @@ -390,7 +393,7 @@ async def login(_, info, email: str, password: str): # Устанавливаем httponly cookie различными способами для надежности cookie_set = False - + # Метод 1: GraphQL контекст через extensions try: if hasattr(info.context, "extensions") and hasattr(info.context.extensions, "set_cookie"): @@ -406,7 +409,7 @@ async def login(_, info, email: str, password: str): cookie_set = True except Exception as e: logger.error(f"[auth] login: Ошибка при установке cookie через extensions: {str(e)}") - + # Метод 2: GraphQL контекст через response if not cookie_set: try: @@ -423,11 +426,12 @@ async def login(_, info, email: str, password: str): cookie_set = True except Exception as e: logger.error(f"[auth] login: Ошибка при установке cookie через response: {str(e)}") - + # Если ни один способ не сработал, создаем response в контексте if not cookie_set and hasattr(info.context, "request") and not hasattr(info.context, "response"): try: from starlette.responses import JSONResponse + response = JSONResponse({}) response.set_cookie( key=SESSION_COOKIE_NAME, @@ -442,12 +446,12 @@ async def login(_, info, email: str, password: str): cookie_set = True except Exception as e: logger.error(f"[auth] login: Ошибка при создании response и установке cookie: {str(e)}") - + if not cookie_set: logger.warning(f"[auth] login: Не удалось установить cookie никаким способом") - + # Возвращаем успешный результат с данными для клиента - # Для ответа клиенту используем dict() с параметром access=True, + # Для ответа клиенту используем dict() с параметром access=True, # чтобы получить полный доступ к данным для самого пользователя logger.info(f"[auth] login: Успешный вход для {email}") author_dict = valid_author.dict(access=True) @@ -485,7 +489,7 @@ async def is_email_used(_, _info, email): async def logout_resolver(_, info: GraphQLResolveInfo): """ Выход из системы через GraphQL с удалением сессии и cookie. - + Returns: dict: Результат операции выхода """ @@ -500,7 +504,7 @@ async def logout_resolver(_, info: GraphQLResolveInfo): success = False message = "" - + # Если токен найден, отзываем его if token: try: @@ -544,12 +548,12 @@ async def logout_resolver(_, info: GraphQLResolveInfo): async def refresh_token_resolver(_, info: GraphQLResolveInfo): """ Обновление токена аутентификации через GraphQL. - + Returns: AuthResult с данными пользователя и обновленным токеном или сообщением об ошибке """ request = info.context["request"] - + # Получаем текущий токен из cookie или заголовка token = request.cookies.get(SESSION_COOKIE_NAME) if not token: @@ -617,12 +621,7 @@ async def refresh_token_resolver(_, info: GraphQLResolveInfo): logger.debug(traceback.format_exc()) logger.info(f"[auth] refresh_token_resolver: Токен успешно обновлен для пользователя {user_id}") - return { - "success": True, - "token": new_token, - "author": author, - "error": None - } + return {"success": True, "token": new_token, "author": author, "error": None} except Exception as e: logger.error(f"[auth] refresh_token_resolver: Ошибка при обновлении токена: {e}") diff --git a/resolvers/author.py b/resolvers/author.py index 3adea6d8..53d9ba0b 100644 --- a/resolvers/author.py +++ b/resolvers/author.py @@ -1,9 +1,10 @@ import asyncio import time -from typing import Optional, List, Dict, Any +from typing import Any, Dict, List, Optional from sqlalchemy import select, text +from auth.orm import Author from cache.cache import ( cache_author, cached_query, @@ -13,7 +14,6 @@ from cache.cache import ( get_cached_follower_topics, invalidate_cache_by_prefix, ) -from auth.orm import Author from resolvers.stat import get_with_stat from services.auth import login_required from services.db import local_session @@ -74,27 +74,26 @@ async def get_authors_with_stats(limit=50, offset=0, by: Optional[str] = None, c # Функция для получения авторов из БД async def fetch_authors_with_stats(): - logger.debug( - f"Выполняем запрос на получение авторов со статистикой: limit={limit}, offset={offset}, by={by}" - ) + logger.debug(f"Выполняем запрос на получение авторов со статистикой: limit={limit}, offset={offset}, by={by}") with local_session() as session: # Базовый запрос для получения авторов base_query = select(Author).where(Author.deleted_at.is_(None)) # Применяем сортировку - + # vars for statistics sorting stats_sort_field = None stats_sort_direction = "desc" - + if by: if isinstance(by, dict): logger.debug(f"Processing dict-based sorting: {by}") # Обработка словаря параметров сортировки from sqlalchemy import asc, desc, func - from orm.shout import ShoutAuthor + from auth.orm import AuthorFollower + from orm.shout import ShoutAuthor # Checking for order field in the dictionary if "order" in by: @@ -135,50 +134,40 @@ async def get_authors_with_stats(limit=50, offset=0, by: Optional[str] = None, c # If sorting by statistics, modify the query if stats_sort_field == "shouts": # Sorting by the number of shouts - from sqlalchemy import func, and_ + from sqlalchemy import and_, func + from orm.shout import Shout, ShoutAuthor - + subquery = ( - select( - ShoutAuthor.author, - func.count(func.distinct(Shout.id)).label("shouts_count") - ) + select(ShoutAuthor.author, func.count(func.distinct(Shout.id)).label("shouts_count")) .select_from(ShoutAuthor) .join(Shout, ShoutAuthor.shout == Shout.id) - .where( - and_( - Shout.deleted_at.is_(None), - Shout.published_at.is_not(None) - ) - ) + .where(and_(Shout.deleted_at.is_(None), Shout.published_at.is_not(None))) .group_by(ShoutAuthor.author) .subquery() ) - - base_query = ( - base_query - .outerjoin(subquery, Author.id == subquery.c.author) - .order_by(desc(func.coalesce(subquery.c.shouts_count, 0))) + + base_query = base_query.outerjoin(subquery, Author.id == subquery.c.author).order_by( + desc(func.coalesce(subquery.c.shouts_count, 0)) ) elif stats_sort_field == "followers": # Sorting by the number of followers from sqlalchemy import func + from auth.orm import AuthorFollower - + subquery = ( select( AuthorFollower.author, - func.count(func.distinct(AuthorFollower.follower)).label("followers_count") + func.count(func.distinct(AuthorFollower.follower)).label("followers_count"), ) .select_from(AuthorFollower) .group_by(AuthorFollower.author) .subquery() ) - - base_query = ( - base_query - .outerjoin(subquery, Author.id == subquery.c.author) - .order_by(desc(func.coalesce(subquery.c.followers_count, 0))) + + base_query = base_query.outerjoin(subquery, Author.id == subquery.c.author).order_by( + desc(func.coalesce(subquery.c.followers_count, 0)) ) # Применяем лимит и смещение @@ -219,7 +208,7 @@ async def get_authors_with_stats(limit=50, offset=0, by: Optional[str] = None, c "shouts": shouts_stats.get(author.id, 0), "followers": followers_stats.get(author.id, 0), } - + result.append(author_dict) # Кешируем каждого автора отдельно для использования в других функциях @@ -299,7 +288,7 @@ async def update_author(_, info, profile): # Кэшируем полную версию для админов author_dict = author_with_stat.dict(access=is_admin) asyncio.create_task(cache_author(author_dict)) - + # Возвращаем обычную полную версию, т.к. это владелец return {"error": None, "author": author} except Exception as exc: @@ -328,16 +317,16 @@ async def get_authors_all(_, info): async def get_author(_, info, slug="", author_id=0): # Получаем ID текущего пользователя и флаг админа из контекста is_admin = info.context.get("is_admin", False) - + author_dict = None try: author_id = get_author_id_from(slug=slug, user="", author_id=author_id) if not author_id: raise ValueError("cant find") - + # Получаем данные автора из кэша (полные данные) cached_author = await get_cached_author(int(author_id), get_with_stat) - + # Применяем фильтрацию на стороне клиента, так как в кэше хранится полная версия if cached_author: # Создаем объект автора для использования метода dict @@ -361,7 +350,7 @@ async def get_author(_, info, slug="", author_id=0): # Кэшируем полные данные для админов original_dict = author_with_stat.dict(access=True) asyncio.create_task(cache_author(original_dict)) - + # Возвращаем отфильтрованную версию author_dict = author_with_stat.dict(access=is_admin) # Добавляем статистику @@ -393,11 +382,12 @@ async def load_authors_by(_, info, by, limit, offset): # Получаем ID текущего пользователя и флаг админа из контекста viewer_id = info.context.get("author", {}).get("id") is_admin = info.context.get("is_admin", False) - + # Используем оптимизированную функцию для получения авторов return await get_authors_with_stats(limit, offset, by, viewer_id, is_admin) except Exception as exc: import traceback + logger.error(f"{exc}:\n{traceback.format_exc()}") return [] @@ -413,7 +403,7 @@ async def load_authors_search(_, info, text: str, limit: int = 10, offset: int = Returns: list: List of authors matching the search criteria """ - + # Get author IDs from search engine (already sorted by relevance) search_results = await search_service.search_authors(text, limit, offset) @@ -429,13 +419,13 @@ async def load_authors_search(_, info, text: str, limit: int = 10, offset: int = # Simple query to get authors by IDs - no need for stats here authors_query = select(Author).filter(Author.id.in_(author_ids)) db_authors = session.execute(authors_query).scalars().all() - + if not db_authors: return [] # Create a dictionary for quick lookup authors_dict = {str(author.id): author for author in db_authors} - + # Keep the order from search results (maintains the relevance sorting) ordered_authors = [authors_dict[author_id] for author_id in author_ids if author_id in authors_dict] @@ -468,7 +458,7 @@ async def get_author_follows(_, info, slug="", user=None, author_id=0): # Получаем ID текущего пользователя и флаг админа из контекста viewer_id = info.context.get("author", {}).get("id") is_admin = info.context.get("is_admin", False) - + logger.debug(f"getting follows for @{slug}") author_id = get_author_id_from(slug=slug, user=user, author_id=author_id) if not author_id: @@ -477,7 +467,7 @@ async def get_author_follows(_, info, slug="", user=None, author_id=0): # Получаем данные из кэша followed_authors_raw = await get_cached_follower_authors(author_id) followed_topics = await get_cached_follower_topics(author_id) - + # Фильтруем чувствительные данные авторов followed_authors = [] for author_data in followed_authors_raw: @@ -517,15 +507,14 @@ async def get_author_follows_authors(_, info, slug="", user=None, author_id=None # Получаем ID текущего пользователя и флаг админа из контекста viewer_id = info.context.get("author", {}).get("id") is_admin = info.context.get("is_admin", False) - + logger.debug(f"getting followed authors for @{slug}") if not author_id: return [] - # Получаем данные из кэша followed_authors_raw = await get_cached_follower_authors(author_id) - + # Фильтруем чувствительные данные авторов followed_authors = [] for author_data in followed_authors_raw: @@ -540,7 +529,7 @@ async def get_author_follows_authors(_, info, slug="", user=None, author_id=None # is_admin - булево значение, является ли текущий пользователь админом has_access = is_admin or (viewer_id is not None and str(viewer_id) == str(temp_author.id)) followed_authors.append(temp_author.dict(access=has_access)) - + return followed_authors @@ -562,15 +551,15 @@ async def get_author_followers(_, info, slug: str = "", user: str = "", author_i # Получаем ID текущего пользователя и флаг админа из контекста viewer_id = info.context.get("author", {}).get("id") is_admin = info.context.get("is_admin", False) - + logger.debug(f"getting followers for author @{slug} or ID:{author_id}") author_id = get_author_id_from(slug=slug, user=user, author_id=author_id) if not author_id: return [] - + # Получаем данные из кэша followers_raw = await get_cached_author_followers(author_id) - + # Фильтруем чувствительные данные авторов followers = [] for follower_data in followers_raw: @@ -585,5 +574,5 @@ async def get_author_followers(_, info, slug: str = "", user: str = "", author_i # is_admin - булево значение, является ли текущий пользователь админом has_access = is_admin or (viewer_id is not None and str(viewer_id) == str(temp_author.id)) followers.append(temp_author.dict(access=has_access)) - + return followers diff --git a/resolvers/bookmark.py b/resolvers/bookmark.py index 8d21a4ef..901b8fc0 100644 --- a/resolvers/bookmark.py +++ b/resolvers/bookmark.py @@ -72,9 +72,7 @@ def toggle_bookmark_shout(_, info, slug: str) -> CommonResult: if existing_bookmark: db.execute( - delete(AuthorBookmark).where( - AuthorBookmark.author == author_id, AuthorBookmark.shout == shout.id - ) + delete(AuthorBookmark).where(AuthorBookmark.author == author_id, AuthorBookmark.shout == shout.id) ) result = False else: diff --git a/resolvers/community.py b/resolvers/community.py index 12256e2b..7fa1bbad 100644 --- a/resolvers/community.py +++ b/resolvers/community.py @@ -74,9 +74,9 @@ async def update_community(_, info, community_data): if slug: with local_session() as session: try: - session.query(Community).where( - Community.created_by == author_id, Community.slug == slug - ).update(community_data) + session.query(Community).where(Community.created_by == author_id, Community.slug == slug).update( + community_data + ) session.commit() except Exception as e: return {"ok": False, "error": str(e)} @@ -90,9 +90,7 @@ async def delete_community(_, info, slug: str): author_id = author_dict.get("id") with local_session() as session: try: - session.query(Community).where( - Community.slug == slug, Community.created_by == author_id - ).delete() + session.query(Community).where(Community.slug == slug, Community.created_by == author_id).delete() session.commit() return {"ok": True} except Exception as e: diff --git a/resolvers/draft.py b/resolvers/draft.py index 8ee52bd3..a6b28f67 100644 --- a/resolvers/draft.py +++ b/resolvers/draft.py @@ -1,11 +1,12 @@ import time + from sqlalchemy.orm import joinedload +from auth.orm import Author from cache.cache import ( invalidate_shout_related_cache, invalidate_shouts_cache, ) -from auth.orm import Author from orm.draft import Draft, DraftAuthor, DraftTopic from orm.shout import Shout, ShoutAuthor, ShoutTopic from services.auth import login_required @@ -449,15 +450,15 @@ async def publish_draft(_, info, draft_id: int): # Добавляем темы for topic in draft.topics or []: - st = ShoutTopic( - topic=topic.id, shout=shout.id, main=topic.main if hasattr(topic, "main") else False - ) + st = ShoutTopic(topic=topic.id, shout=shout.id, main=topic.main if hasattr(topic, "main") else False) session.add(st) session.commit() # Инвалидируем кеш - cache_keys = [f"shouts:{shout.id}", ] + cache_keys = [ + f"shouts:{shout.id}", + ] await invalidate_shouts_cache(cache_keys) await invalidate_shout_related_cache(shout, author_id) @@ -482,67 +483,59 @@ async def publish_draft(_, info, draft_id: int): async def unpublish_draft(_, info, draft_id: int): """ Снимает с публикации черновик, обновляя связанный Shout. - + Args: draft_id (int): ID черновика, публикацию которого нужно снять - + Returns: dict: Результат операции с информацией о черновике или сообщением об ошибке """ author_dict = info.context.get("author", {}) author_id = author_dict.get("id") - + if author_id: return {"error": "Author ID is required"} - + try: with local_session() as session: # Загружаем черновик со связанной публикацией draft = ( session.query(Draft) - .options( - joinedload(Draft.publication), - joinedload(Draft.authors), - joinedload(Draft.topics) - ) + .options(joinedload(Draft.publication), joinedload(Draft.authors), joinedload(Draft.topics)) .filter(Draft.id == draft_id) .first() ) - + if not draft: return {"error": "Draft not found"} - + # Проверяем, есть ли публикация if not draft.publication: return {"error": "This draft is not published yet"} - + shout = draft.publication - + # Снимаем с публикации shout.published_at = None shout.updated_at = int(time.time()) shout.updated_by = author_id - + session.commit() - + # Инвалидируем кэш cache_keys = [f"shouts:{shout.id}"] await invalidate_shouts_cache(cache_keys) await invalidate_shout_related_cache(shout, author_id) - + # Формируем результат draft_dict = draft.dict() # Добавляем информацию о публикации - draft_dict["publication"] = { - "id": shout.id, - "slug": shout.slug, - "published_at": None - } - + draft_dict["publication"] = {"id": shout.id, "slug": shout.slug, "published_at": None} + logger.info(f"Successfully unpublished shout #{shout.id} for draft #{draft_id}") - + return {"draft": draft_dict} - + except Exception as e: logger.error(f"Failed to unpublish draft {draft_id}: {e}", exc_info=True) return {"error": f"Failed to unpublish draft: {str(e)}"} diff --git a/resolvers/editor.py b/resolvers/editor.py index 47b99dc9..41539c7d 100644 --- a/resolvers/editor.py +++ b/resolvers/editor.py @@ -5,13 +5,13 @@ from sqlalchemy import and_, desc, select from sqlalchemy.orm import joinedload, selectinload from sqlalchemy.sql.functions import coalesce +from auth.orm import Author from cache.cache import ( cache_author, cache_topic, invalidate_shout_related_cache, invalidate_shouts_cache, ) -from auth.orm import Author from orm.draft import Draft from orm.shout import Shout, ShoutAuthor, ShoutTopic from orm.topic import Topic @@ -179,9 +179,7 @@ async def create_shout(_, info, inp): lead = inp.get("lead", "") body_text = extract_text(body) lead_text = extract_text(lead) - seo = inp.get( - "seo", lead_text.strip() or body_text.strip()[:300].split(". ")[:-1].join(". ") - ) + seo = inp.get("seo", lead_text.strip() or body_text.strip()[:300].split(". ")[:-1].join(". ")) new_shout = Shout( slug=slug, body=body, @@ -278,9 +276,7 @@ def patch_main_topic(session, main_topic_slug, shout): with session.begin(): # Получаем текущий главный топик old_main = ( - session.query(ShoutTopic) - .filter(and_(ShoutTopic.shout == shout.id, ShoutTopic.main.is_(True))) - .first() + session.query(ShoutTopic).filter(and_(ShoutTopic.shout == shout.id, ShoutTopic.main.is_(True))).first() ) if old_main: logger.info(f"Found current main topic: {old_main.topic.slug}") @@ -314,9 +310,7 @@ def patch_main_topic(session, main_topic_slug, shout): session.flush() logger.info(f"Main topic updated for shout#{shout.id}") else: - logger.warning( - f"No changes needed for main topic (old={old_main is not None}, new={new_main is not None})" - ) + logger.warning(f"No changes needed for main topic (old={old_main is not None}, new={new_main is not None})") def patch_topics(session, shout, topics_input): @@ -410,9 +404,7 @@ async def update_shout(_, info, shout_id: int, shout_input=None, publish=False): logger.info(f"Processing update for shout#{shout_id} by author #{author_id}") shout_by_id = ( session.query(Shout) - .options( - joinedload(Shout.topics).joinedload(ShoutTopic.topic), joinedload(Shout.authors) - ) + .options(joinedload(Shout.topics).joinedload(ShoutTopic.topic), joinedload(Shout.authors)) .filter(Shout.id == shout_id) .first() ) @@ -441,10 +433,7 @@ async def update_shout(_, info, shout_id: int, shout_input=None, publish=False): shout_input["slug"] = slug logger.info(f"shout#{shout_id} slug patched") - if ( - filter(lambda x: x.id == author_id, [x for x in shout_by_id.authors]) - or "editor" in roles - ): + if filter(lambda x: x.id == author_id, [x for x in shout_by_id.authors]) or "editor" in roles: logger.info(f"Author #{author_id} has permission to edit shout#{shout_id}") # topics patch @@ -558,9 +547,7 @@ async def update_shout(_, info, shout_id: int, shout_input=None, publish=False): # Получаем полные данные шаута со связями shout_with_relations = ( session.query(Shout) - .options( - joinedload(Shout.topics).joinedload(ShoutTopic.topic), joinedload(Shout.authors) - ) + .options(joinedload(Shout.topics).joinedload(ShoutTopic.topic), joinedload(Shout.authors)) .filter(Shout.id == shout_id) .first() ) diff --git a/resolvers/feed.py b/resolvers/feed.py index 69a5fa55..9f8ebd79 100644 --- a/resolvers/feed.py +++ b/resolvers/feed.py @@ -71,9 +71,7 @@ def shouts_by_follower(info, follower_id: int, options): q = query_with_stat(info) reader_followed_authors = select(AuthorFollower.author).where(AuthorFollower.follower == follower_id) reader_followed_topics = select(TopicFollower.topic).where(TopicFollower.follower == follower_id) - reader_followed_shouts = select(ShoutReactionsFollower.shout).where( - ShoutReactionsFollower.follower == follower_id - ) + reader_followed_shouts = select(ShoutReactionsFollower.shout).where(ShoutReactionsFollower.follower == follower_id) followed_subquery = ( select(Shout.id) .join(ShoutAuthor, ShoutAuthor.shout == Shout.id) @@ -142,9 +140,7 @@ async def load_shouts_authored_by(_, info, slug: str, options) -> List[Shout]: q = ( query_with_stat(info) if has_field(info, "stat") - else select(Shout).filter( - and_(Shout.published_at.is_not(None), Shout.deleted_at.is_(None)) - ) + else select(Shout).filter(and_(Shout.published_at.is_not(None), Shout.deleted_at.is_(None))) ) q = q.filter(Shout.authors.any(id=author_id)) q, limit, offset = apply_options(q, options, author_id) @@ -173,9 +169,7 @@ async def load_shouts_with_topic(_, info, slug: str, options) -> List[Shout]: q = ( query_with_stat(info) if has_field(info, "stat") - else select(Shout).filter( - and_(Shout.published_at.is_not(None), Shout.deleted_at.is_(None)) - ) + else select(Shout).filter(and_(Shout.published_at.is_not(None), Shout.deleted_at.is_(None))) ) q = q.filter(Shout.topics.any(id=topic_id)) q, limit, offset = apply_options(q, options) diff --git a/resolvers/follower.py b/resolvers/follower.py index 16e8c4cf..c6dcb403 100644 --- a/resolvers/follower.py +++ b/resolvers/follower.py @@ -4,13 +4,13 @@ from graphql import GraphQLError from sqlalchemy import select from sqlalchemy.sql import and_ +from auth.orm import Author, AuthorFollower from cache.cache import ( cache_author, cache_topic, get_cached_follower_authors, get_cached_follower_topics, ) -from auth.orm import Author, AuthorFollower from orm.community import Community, CommunityFollower from orm.reaction import Reaction from orm.shout import Shout, ShoutReactionsFollower @@ -65,14 +65,14 @@ async def follow(_, info, what, slug="", entity_id=0): return {"error": f"{what.lower()} not found"} if not entity_id and entity: entity_id = entity.id - + # Если это автор, учитываем фильтрацию данных if what == "AUTHOR": # Полная версия для кэширования entity_dict = entity.dict(is_admin=True) else: - entity_dict = entity.dict() - + entity_dict = entity.dict() + logger.debug(f"entity_id: {entity_id}, entity_dict: {entity_dict}") if entity_id: @@ -87,9 +87,7 @@ async def follow(_, info, what, slug="", entity_id=0): .first() ) if existing_sub: - logger.info( - f"Пользователь {follower_id} уже подписан на {what.lower()} с ID {entity_id}" - ) + logger.info(f"Пользователь {follower_id} уже подписан на {what.lower()} с ID {entity_id}") else: logger.debug("Добавление новой записи в базу данных") sub = follower_class(follower=follower_id, **{entity_type: entity_id}) @@ -105,12 +103,12 @@ async def follow(_, info, what, slug="", entity_id=0): if get_cached_follows_method: logger.debug("Получение подписок из кэша") existing_follows = await get_cached_follows_method(follower_id) - + # Если это авторы, получаем безопасную версию if what == "AUTHOR": # Получаем ID текущего пользователя и фильтруем данные follows_filtered = [] - + for author_data in existing_follows: # Создаем объект автора для использования метода dict temp_author = Author() @@ -119,7 +117,7 @@ async def follow(_, info, what, slug="", entity_id=0): setattr(temp_author, key, value) # Добавляем отфильтрованную версию follows_filtered.append(temp_author.dict(viewer_id, False)) - + if not existing_sub: # Создаем объект автора для entity_dict temp_author = Author() @@ -132,7 +130,7 @@ async def follow(_, info, what, slug="", entity_id=0): follows = follows_filtered else: follows = [*existing_follows, entity_dict] if not existing_sub else existing_follows - + logger.debug("Обновлен список подписок") if what == "AUTHOR" and not existing_sub: @@ -214,20 +212,20 @@ async def unfollow(_, info, what, slug="", entity_id=0): await cache_method(entity.dict(is_admin=True)) else: await cache_method(entity.dict()) - + if get_cached_follows_method: logger.debug("Получение подписок из кэша") existing_follows = await get_cached_follows_method(follower_id) - + # Если это авторы, получаем безопасную версию if what == "AUTHOR": # Получаем ID текущего пользователя и фильтруем данные follows_filtered = [] - + for author_data in existing_follows: if author_data["id"] == entity_id: continue - + # Создаем объект автора для использования метода dict temp_author = Author() for key, value in author_data.items(): @@ -235,11 +233,11 @@ async def unfollow(_, info, what, slug="", entity_id=0): setattr(temp_author, key, value) # Добавляем отфильтрованную версию follows_filtered.append(temp_author.dict(viewer_id, False)) - + follows = follows_filtered else: follows = [item for item in existing_follows if item["id"] != entity_id] - + logger.debug("Обновлен список подписок") if what == "AUTHOR": diff --git a/resolvers/notifier.py b/resolvers/notifier.py index 9fe4d08f..08b7f2d4 100644 --- a/resolvers/notifier.py +++ b/resolvers/notifier.py @@ -66,9 +66,7 @@ def query_notifications(author_id: int, after: int = 0) -> Tuple[int, int, List[ return total, unread, notifications -def group_notification( - thread, authors=None, shout=None, reactions=None, entity="follower", action="follow" -): +def group_notification(thread, authors=None, shout=None, reactions=None, entity="follower", action="follow"): reactions = reactions or [] authors = authors or [] return { diff --git a/resolvers/proposals.py b/resolvers/proposals.py index f541732a..25218add 100644 --- a/resolvers/proposals.py +++ b/resolvers/proposals.py @@ -14,11 +14,7 @@ def handle_proposing(kind: ReactionKind, reply_to: int, shout_id: int): session.query(Reaction).filter(Reaction.id == reply_to, Reaction.shout == shout_id).first() ) - if ( - replied_reaction - and replied_reaction.kind is ReactionKind.PROPOSE.value - and replied_reaction.quote - ): + if replied_reaction and replied_reaction.kind is ReactionKind.PROPOSE.value and replied_reaction.quote: # patch all the proposals' quotes proposals = ( session.query(Reaction) diff --git a/resolvers/rating.py b/resolvers/rating.py index 397e8eac..a869c13f 100644 --- a/resolvers/rating.py +++ b/resolvers/rating.py @@ -186,9 +186,7 @@ def count_author_shouts_rating(session, author_id) -> int: def get_author_rating_old(session, author: Author): likes_count = ( - session.query(AuthorRating) - .filter(and_(AuthorRating.author == author.id, AuthorRating.plus.is_(True))) - .count() + session.query(AuthorRating).filter(and_(AuthorRating.author == author.id, AuthorRating.plus.is_(True))).count() ) dislikes_count = ( session.query(AuthorRating) diff --git a/resolvers/reaction.py b/resolvers/reaction.py index e2c4db56..0728ac4f 100644 --- a/resolvers/reaction.py +++ b/resolvers/reaction.py @@ -334,9 +334,7 @@ async def create_reaction(_, info, reaction): with local_session() as session: authors = session.query(ShoutAuthor.author).filter(ShoutAuthor.shout == shout_id).scalar() is_author = ( - bool(list(filter(lambda x: x == int(author_id), authors))) - if isinstance(authors, list) - else False + bool(list(filter(lambda x: x == int(author_id), authors))) if isinstance(authors, list) else False ) reaction_input["created_by"] = author_id kind = reaction_input.get("kind") diff --git a/resolvers/reader.py b/resolvers/reader.py index ac3fda6d..b45882eb 100644 --- a/resolvers/reader.py +++ b/resolvers/reader.py @@ -138,9 +138,7 @@ def query_with_stat(info): select( ShoutTopic.shout, json_array_builder( - json_builder( - "id", Topic.id, "title", Topic.title, "slug", Topic.slug, "is_main", ShoutTopic.main - ) + json_builder("id", Topic.id, "title", Topic.title, "slug", Topic.slug, "is_main", ShoutTopic.main) ).label("topics"), ) .outerjoin(Topic, ShoutTopic.topic == Topic.id) @@ -227,7 +225,7 @@ def get_shouts_with_links(info, q, limit=20, offset=0): "slug": a.slug, "pic": a.pic, } - + # Обработка поля updated_by if has_field(info, "updated_by"): if shout_dict.get("updated_by"): @@ -246,7 +244,7 @@ def get_shouts_with_links(info, q, limit=20, offset=0): else: # Если updated_by не указан, устанавливаем поле в null shout_dict["updated_by"] = None - + # Обработка поля deleted_by if has_field(info, "deleted_by"): if shout_dict.get("deleted_by"): @@ -287,9 +285,7 @@ def get_shouts_with_links(info, q, limit=20, offset=0): if hasattr(row, "main_topic"): # logger.debug(f"Raw main_topic for shout#{shout_id}: {row.main_topic}") main_topic = ( - orjson.loads(row.main_topic) - if isinstance(row.main_topic, str) - else row.main_topic + orjson.loads(row.main_topic) if isinstance(row.main_topic, str) else row.main_topic ) # logger.debug(f"Parsed main_topic for shout#{shout_id}: {main_topic}") @@ -325,9 +321,7 @@ def get_shouts_with_links(info, q, limit=20, offset=0): media_data = orjson.loads(media_data) except orjson.JSONDecodeError: media_data = [] - shout_dict["media"] = ( - [media_data] if isinstance(media_data, dict) else media_data - ) + shout_dict["media"] = [media_data] if isinstance(media_data, dict) else media_data shouts.append(shout_dict) @@ -415,9 +409,7 @@ def apply_sorting(q, options): """ order_str = options.get("order_by") if order_str in ["rating", "comments_count", "last_commented_at"]: - query_order_by = ( - desc(text(order_str)) if options.get("order_by_desc", True) else asc(text(order_str)) - ) + query_order_by = desc(text(order_str)) if options.get("order_by_desc", True) else asc(text(order_str)) q = q.distinct(text(order_str), Shout.id).order_by( # DISTINCT ON включает поле сортировки nulls_last(query_order_by), Shout.id ) @@ -513,15 +505,11 @@ async def load_shouts_unrated(_, info, options): q = select(Shout).where(and_(Shout.published_at.is_not(None), Shout.deleted_at.is_(None))) q = q.join(Author, Author.id == Shout.created_by) q = q.add_columns( - json_builder("id", Author.id, "name", Author.name, "slug", Author.slug, "pic", Author.pic).label( - "main_author" - ) + json_builder("id", Author.id, "name", Author.name, "slug", Author.slug, "pic", Author.pic).label("main_author") ) q = q.join(ShoutTopic, and_(ShoutTopic.shout == Shout.id, ShoutTopic.main.is_(True))) q = q.join(Topic, Topic.id == ShoutTopic.topic) - q = q.add_columns( - json_builder("id", Topic.id, "title", Topic.title, "slug", Topic.slug).label("main_topic") - ) + q = q.add_columns(json_builder("id", Topic.id, "title", Topic.title, "slug", Topic.slug).label("main_topic")) q = q.where(Shout.id.not_in(rated_shouts)) q = q.order_by(func.random()) diff --git a/resolvers/stat.py b/resolvers/stat.py index 35e8327c..c8bb448a 100644 --- a/resolvers/stat.py +++ b/resolvers/stat.py @@ -3,8 +3,8 @@ import asyncio from sqlalchemy import and_, distinct, func, join, select from sqlalchemy.orm import aliased -from cache.cache import cache_author from auth.orm import Author, AuthorFollower +from cache.cache import cache_author from orm.reaction import Reaction, ReactionKind from orm.shout import Shout, ShoutAuthor, ShoutTopic from orm.topic import Topic, TopicFollower @@ -177,9 +177,7 @@ def get_topic_comments_stat(topic_id: int) -> int: .subquery() ) # Запрос для суммирования количества комментариев по теме - q = select(func.coalesce(func.sum(sub_comments.c.comments_count), 0)).filter( - ShoutTopic.topic == topic_id - ) + q = select(func.coalesce(func.sum(sub_comments.c.comments_count), 0)).filter(ShoutTopic.topic == topic_id) q = q.outerjoin(sub_comments, ShoutTopic.shout == sub_comments.c.shout_id) with local_session() as session: result = session.execute(q).first() @@ -239,9 +237,7 @@ def get_author_followers_stat(author_id: int) -> int: :return: Количество уникальных подписчиков автора. """ aliased_followers = aliased(AuthorFollower) - q = select(func.count(distinct(aliased_followers.follower))).filter( - aliased_followers.author == author_id - ) + q = select(func.count(distinct(aliased_followers.follower))).filter(aliased_followers.author == author_id) with local_session() as session: result = session.execute(q).first() return result[0] if result else 0 @@ -293,9 +289,7 @@ def get_with_stat(q): stat["shouts"] = cols[1] # Статистика по публикациям stat["followers"] = cols[2] # Статистика по подписчикам if is_author: - stat["authors"] = get_author_authors_stat( - entity.id - ) # Статистика по подпискам на авторов + stat["authors"] = get_author_authors_stat(entity.id) # Статистика по подпискам на авторов stat["comments"] = get_author_comments_stat(entity.id) # Статистика по комментариям else: stat["authors"] = get_topic_authors_stat(entity.id) # Статистика по авторам темы diff --git a/resolvers/topic.py b/resolvers/topic.py index 22705ff5..8e2fa642 100644 --- a/resolvers/topic.py +++ b/resolvers/topic.py @@ -1,5 +1,6 @@ from sqlalchemy import desc, select, text +from auth.orm import Author from cache.cache import ( cache_topic, cached_query, @@ -8,9 +9,8 @@ from cache.cache import ( get_cached_topic_followers, invalidate_cache_by_prefix, ) -from auth.orm import Author -from orm.topic import Topic from orm.reaction import ReactionKind +from orm.topic import Topic from resolvers.stat import get_with_stat from services.auth import login_required from services.db import local_session diff --git a/services/auth.py b/services/auth.py index e523f85b..04ae1c86 100644 --- a/services/auth.py +++ b/services/auth.py @@ -1,16 +1,16 @@ from functools import wraps from typing import Tuple +from sqlalchemy import exc from starlette.requests import Request +from auth.internal import verify_internal_auth +from auth.orm import Author, Role from cache.cache import get_cached_author_by_id from resolvers.stat import get_with_stat -from utils.logger import root_logger as logger -from auth.internal import verify_internal_auth -from sqlalchemy import exc from services.db import local_session -from auth.orm import Author, Role from settings import SESSION_TOKEN_HEADER +from utils.logger import root_logger as logger # Список разрешенных заголовков ALLOWED_HEADERS = ["Authorization", "Content-Type"] @@ -31,21 +31,21 @@ async def check_auth(req: Request) -> Tuple[str, list[str], bool]: - is_admin: bool - Флаг наличия у пользователя административных прав """ logger.debug(f"[check_auth] Проверка авторизации...") - + # Получаем заголовок авторизации token = None - + # Проверяем заголовок с учетом регистра headers_dict = dict(req.headers.items()) logger.debug(f"[check_auth] Все заголовки: {headers_dict}") - + # Ищем заголовок Authorization независимо от регистра for header_name, header_value in headers_dict.items(): if header_name.lower() == SESSION_TOKEN_HEADER.lower(): token = header_value logger.debug(f"[check_auth] Найден заголовок {header_name}: {token[:10]}...") break - + if not token: logger.debug(f"[check_auth] Токен не найден в заголовках") return "", [], False @@ -57,8 +57,10 @@ async def check_auth(req: Request) -> Tuple[str, list[str], bool]: # Проверяем авторизацию внутренним механизмом logger.debug("[check_auth] Вызов verify_internal_auth...") user_id, user_roles, is_admin = await verify_internal_auth(token) - logger.debug(f"[check_auth] Результат verify_internal_auth: user_id={user_id}, roles={user_roles}, is_admin={is_admin}") - + logger.debug( + f"[check_auth] Результат verify_internal_auth: user_id={user_id}, roles={user_roles}, is_admin={is_admin}" + ) + # Если в ролях нет админа, но есть ID - проверяем в БД if user_id and not is_admin: try: @@ -71,16 +73,19 @@ async def check_auth(req: Request) -> Tuple[str, list[str], bool]: else: # Проверяем наличие админских прав через БД from auth.orm import AuthorRole - admin_role = session.query(AuthorRole).filter( - AuthorRole.author == user_id_int, - AuthorRole.role.in_(["admin", "super"]) - ).first() + + admin_role = ( + session.query(AuthorRole) + .filter(AuthorRole.author == user_id_int, AuthorRole.role.in_(["admin", "super"])) + .first() + ) is_admin = admin_role is not None except Exception as e: logger.error(f"Ошибка при проверке прав администратора: {e}") - + return user_id, user_roles, is_admin + async def add_user_role(user_id: str, roles: list[str] = None): """ Добавление ролей пользователю в локальной БД. @@ -131,32 +136,32 @@ def login_required(f): info = args[1] req = info.context.get("request") - + logger.debug(f"[login_required] Проверка авторизации для запроса: {req.method} {req.url.path}") logger.debug(f"[login_required] Заголовки: {req.headers}") - + user_id, user_roles, is_admin = await check_auth(req) - + if not user_id: logger.debug(f"[login_required] Пользователь не авторизован, {dict(req)}, {info}") raise GraphQLError("Требуется авторизация") - + # Проверяем наличие роли reader - if 'reader' not in user_roles: + if "reader" not in user_roles: logger.error(f"Пользователь {user_id} не имеет роли 'reader'") raise GraphQLError("У вас нет необходимых прав для доступа") - + logger.info(f"Авторизован пользователь {user_id} с ролями: {user_roles}") info.context["roles"] = user_roles - + # Проверяем права администратора info.context["is_admin"] = is_admin - + author = await get_cached_author_by_id(user_id, get_with_stat) if not author: logger.error(f"Профиль автора не найден для пользователя {user_id}") info.context["author"] = author - + return await f(*args, **kwargs) return decorated_function @@ -177,7 +182,7 @@ def login_accepted(f): if user_id and user_roles: logger.info(f"login_accepted: Пользователь авторизован: {user_id} с ролями {user_roles}") info.context["roles"] = user_roles - + # Проверяем права администратора info.context["is_admin"] = is_admin diff --git a/services/db.py b/services/db.py index 844ee891..9c579c7d 100644 --- a/services/db.py +++ b/services/db.py @@ -200,9 +200,7 @@ class Base(declarative_base()): data[column_name] = value else: # Пропускаем атрибут, если его нет в объекте (может быть добавлен после миграции) - logger.debug( - f"Skipping missing attribute '{column_name}' for {self.__class__.__name__}" - ) + logger.debug(f"Skipping missing attribute '{column_name}' for {self.__class__.__name__}") except AttributeError as e: logger.warning(f"Attribute error for column '{column_name}': {e}") # Добавляем синтетическое поле .stat если оно существует @@ -223,9 +221,7 @@ class Base(declarative_base()): # Функция для вывода полного трейсбека при предупреждениях -def warning_with_traceback( - message: Warning | str, category, filename: str, lineno: int, file=None, line=None -): +def warning_with_traceback(message: Warning | str, category, filename: str, lineno: int, file=None, line=None): tb = traceback.format_stack() tb_str = "".join(tb) return f"{message} ({filename}, {lineno}): {category.__name__}\n{tb_str}" @@ -302,22 +298,22 @@ json_builder, json_array_builder, json_cast = get_json_builder() # Fetch all shouts, with authors preloaded # This function is used for search indexing + async def fetch_all_shouts(session=None): """Fetch all published shouts for search indexing with authors preloaded""" from orm.shout import Shout - + close_session = False if session is None: session = local_session() close_session = True - + try: # Fetch only published and non-deleted shouts with authors preloaded - query = session.query(Shout).options( - joinedload(Shout.authors) - ).filter( - Shout.published_at.is_not(None), - Shout.deleted_at.is_(None) + query = ( + session.query(Shout) + .options(joinedload(Shout.authors)) + .filter(Shout.published_at.is_not(None), Shout.deleted_at.is_(None)) ) shouts = query.all() return shouts @@ -326,4 +322,4 @@ async def fetch_all_shouts(session=None): return [] finally: if close_session: - session.close() \ No newline at end of file + session.close() diff --git a/services/env.py b/services/env.py index e83110a2..017cf0b7 100644 --- a/services/env.py +++ b/services/env.py @@ -1,9 +1,11 @@ -from typing import Dict, List, Optional, Set -from dataclasses import dataclass import os import re +from dataclasses import dataclass from pathlib import Path +from typing import Dict, List, Optional, Set + from redis import Redis + from settings import REDIS_URL, ROOT_DIR from utils.logger import root_logger as logger @@ -31,12 +33,37 @@ class EnvManager: # Стандартные переменные окружения, которые следует исключить EXCLUDED_ENV_VARS: Set[str] = { - "PATH", "SHELL", "USER", "HOME", "PWD", "TERM", "LANG", - "PYTHONPATH", "_", "TMPDIR", "TERM_PROGRAM", "TERM_SESSION_ID", - "XPC_SERVICE_NAME", "XPC_FLAGS", "SHLVL", "SECURITYSESSIONID", - "LOGNAME", "OLDPWD", "ZSH", "PAGER", "LESS", "LC_CTYPE", "LSCOLORS", - "SSH_AUTH_SOCK", "DISPLAY", "COLORTERM", "EDITOR", "VISUAL", - "PYTHONDONTWRITEBYTECODE", "VIRTUAL_ENV", "PYTHONUNBUFFERED" + "PATH", + "SHELL", + "USER", + "HOME", + "PWD", + "TERM", + "LANG", + "PYTHONPATH", + "_", + "TMPDIR", + "TERM_PROGRAM", + "TERM_SESSION_ID", + "XPC_SERVICE_NAME", + "XPC_FLAGS", + "SHLVL", + "SECURITYSESSIONID", + "LOGNAME", + "OLDPWD", + "ZSH", + "PAGER", + "LESS", + "LC_CTYPE", + "LSCOLORS", + "SSH_AUTH_SOCK", + "DISPLAY", + "COLORTERM", + "EDITOR", + "VISUAL", + "PYTHONDONTWRITEBYTECODE", + "VIRTUAL_ENV", + "PYTHONUNBUFFERED", } # Секции для группировки переменных @@ -44,57 +71,67 @@ class EnvManager: "AUTH": { "pattern": r"^(JWT|AUTH|SESSION|OAUTH|GITHUB|GOOGLE|FACEBOOK)_", "name": "Авторизация", - "description": "Настройки системы авторизации" + "description": "Настройки системы авторизации", }, "DATABASE": { "pattern": r"^(DB|DATABASE|POSTGRES|MYSQL|SQL)_", "name": "База данных", - "description": "Настройки подключения к базам данных" + "description": "Настройки подключения к базам данных", }, "CACHE": { "pattern": r"^(REDIS|CACHE|MEMCACHED)_", "name": "Кэширование", - "description": "Настройки систем кэширования" + "description": "Настройки систем кэширования", }, "SEARCH": { "pattern": r"^(ELASTIC|SEARCH|OPENSEARCH)_", "name": "Поиск", - "description": "Настройки поисковых систем" + "description": "Настройки поисковых систем", }, "APP": { "pattern": r"^(APP|PORT|HOST|DEBUG|DOMAIN|ENVIRONMENT|ENV|FRONTEND)_", "name": "Общие настройки", - "description": "Общие настройки приложения" + "description": "Общие настройки приложения", }, "LOGGING": { "pattern": r"^(LOG|LOGGING|SENTRY|GLITCH|GLITCHTIP)_", "name": "Мониторинг", - "description": "Настройки логирования и мониторинга" + "description": "Настройки логирования и мониторинга", }, "EMAIL": { "pattern": r"^(MAIL|EMAIL|SMTP|IMAP|POP3|POST)_", "name": "Электронная почта", - "description": "Настройки отправки электронной почты" + "description": "Настройки отправки электронной почты", }, "ANALYTICS": { "pattern": r"^(GA|GOOGLE_ANALYTICS|ANALYTICS)_", "name": "Аналитика", - "description": "Настройки систем аналитики" + "description": "Настройки систем аналитики", }, } # Переменные, которые следует всегда помечать как секретные SECRET_VARS_PATTERNS = [ - r".*TOKEN.*", r".*SECRET.*", r".*PASSWORD.*", r".*KEY.*", - r".*PWD.*", r".*PASS.*", r".*CRED.*", r".*_DSN.*", - r".*JWT.*", r".*SESSION.*", r".*OAUTH.*", - r".*GITHUB.*", r".*GOOGLE.*", r".*FACEBOOK.*" + r".*TOKEN.*", + r".*SECRET.*", + r".*PASSWORD.*", + r".*KEY.*", + r".*PWD.*", + r".*PASS.*", + r".*CRED.*", + r".*_DSN.*", + r".*JWT.*", + r".*SESSION.*", + r".*OAUTH.*", + r".*GITHUB.*", + r".*GOOGLE.*", + r".*FACEBOOK.*", ] def __init__(self): self.redis = Redis.from_url(REDIS_URL) self.prefix = "env:" - self.env_file_path = os.path.join(ROOT_DIR, '.env') + self.env_file_path = os.path.join(ROOT_DIR, ".env") def get_all_variables(self) -> List[EnvSection]: """ @@ -142,15 +179,15 @@ class EnvManager: env_vars = {} if os.path.exists(self.env_file_path): try: - with open(self.env_file_path, 'r') as f: + with open(self.env_file_path, "r") as f: for line in f: line = line.strip() # Пропускаем пустые строки и комментарии - if not line or line.startswith('#'): + if not line or line.startswith("#"): continue # Разделяем строку на ключ и значение - if '=' in line: - key, value = line.split('=', 1) + if "=" in line: + key, value = line.split("=", 1) key = key.strip() value = value.strip() # Удаляем кавычки, если они есть @@ -207,17 +244,17 @@ class EnvManager: """ Определяет тип переменной на основе ее значения """ - if value.lower() in ('true', 'false'): + if value.lower() in ("true", "false"): return "boolean" if value.isdigit(): return "integer" if re.match(r"^\d+\.\d+$", value): return "float" # Проверяем на JSON объект или массив - if (value.startswith('{') and value.endswith('}')) or (value.startswith('[') and value.endswith(']')): + if (value.startswith("{") and value.endswith("}")) or (value.startswith("[") and value.endswith("]")): return "json" # Проверяем на URL - if value.startswith(('http://', 'https://', 'redis://', 'postgresql://')): + if value.startswith(("http://", "https://", "redis://", "postgresql://")): return "url" return "string" @@ -233,14 +270,9 @@ class EnvManager: for key, value in variables.items(): is_secret = self._is_secret_variable(key) var_type = self._determine_variable_type(value) - - var = EnvVariable( - key=key, - value=value, - type=var_type, - is_secret=is_secret - ) - + + var = EnvVariable(key=key, value=value, type=var_type, is_secret=is_secret) + # Определяем секцию для переменной placed = False for section_id, section_config in self.SECTIONS.items(): @@ -248,7 +280,7 @@ class EnvManager: sections_dict[section_id].append(var) placed = True break - + # Если переменная не попала ни в одну секцию # if not placed: # other_variables.append(var) @@ -260,22 +292,20 @@ class EnvManager: section_config = self.SECTIONS[section_id] result.append( EnvSection( - name=section_config["name"], - description=section_config["description"], - variables=variables + name=section_config["name"], description=section_config["description"], variables=variables ) ) - + # Добавляем прочие переменные, если они есть if other_variables: result.append( EnvSection( name="Прочие переменные", description="Переменные, не вошедшие в основные категории", - variables=other_variables + variables=other_variables, ) ) - + return result def update_variable(self, key: str, value: str) -> bool: @@ -286,13 +316,13 @@ class EnvManager: # Сохраняем в Redis full_key = f"{self.prefix}{key}" self.redis.set(full_key, value) - + # Обновляем значение в .env файле self._update_dotenv_var(key, value) - + # Обновляем переменную в текущем процессе os.environ[key] = value - + return True except Exception as e: logger.error(f"Ошибка обновления переменной {key}: {e}") @@ -305,20 +335,20 @@ class EnvManager: try: # Если файл .env не существует, создаем его if not os.path.exists(self.env_file_path): - with open(self.env_file_path, 'w') as f: + with open(self.env_file_path, "w") as f: f.write(f"{key}={value}\n") return True - + # Если файл существует, читаем его содержимое lines = [] found = False - - with open(self.env_file_path, 'r') as f: + + with open(self.env_file_path, "r") as f: for line in f: - if line.strip() and not line.strip().startswith('#'): + if line.strip() and not line.strip().startswith("#"): if line.strip().startswith(f"{key}="): # Экранируем значение, если необходимо - if ' ' in value or ',' in value or '"' in value or "'" in value: + if " " in value or "," in value or '"' in value or "'" in value: escaped_value = f'"{value}"' else: escaped_value = value @@ -328,20 +358,20 @@ class EnvManager: lines.append(line) else: lines.append(line) - + # Если переменной не было в файле, добавляем ее if not found: # Экранируем значение, если необходимо - if ' ' in value or ',' in value or '"' in value or "'" in value: + if " " in value or "," in value or '"' in value or "'" in value: escaped_value = f'"{value}"' else: escaped_value = value lines.append(f"{key}={escaped_value}\n") - + # Записываем обновленный файл - with open(self.env_file_path, 'w') as f: + with open(self.env_file_path, "w") as f: f.writelines(lines) - + return True except Exception as e: logger.error(f"Ошибка обновления .env файла: {e}") @@ -358,14 +388,14 @@ class EnvManager: full_key = f"{self.prefix}{var.key}" pipe.set(full_key, var.value) pipe.execute() - + # Обновляем переменные в .env файле for var in variables: self._update_dotenv_var(var.key, var.value) - + # Обновляем переменную в текущем процессе os.environ[var.key] = var.value - + return True except Exception as e: logger.error(f"Ошибка массового обновления переменных: {e}") diff --git a/services/notify.py b/services/notify.py index cf2b2553..c42276e0 100644 --- a/services/notify.py +++ b/services/notify.py @@ -93,9 +93,7 @@ async def notify_draft(draft_data, action: str = "publish"): # Если переданы связанные атрибуты, добавим их if hasattr(draft_data, "topics") and draft_data.topics is not None: - draft_payload["topics"] = [ - {"id": t.id, "name": t.name, "slug": t.slug} for t in draft_data.topics - ] + draft_payload["topics"] = [{"id": t.id, "name": t.name, "slug": t.slug} for t in draft_data.topics] if hasattr(draft_data, "authors") and draft_data.authors is not None: draft_payload["authors"] = [ diff --git a/services/redis.py b/services/redis.py index bcbf4382..8912e6ef 100644 --- a/services/redis.py +++ b/services/redis.py @@ -30,7 +30,7 @@ class RedisService: if self._client is None: await self.connect() logger.info(f"[redis] Автоматически установлено соединение при выполнении команды {command}") - + if self._client: try: logger.debug(f"{command}") # {args[0]}") # {args} {kwargs}") @@ -55,14 +55,14 @@ class RedisService: if self._client is None: # Выбрасываем исключение, так как pipeline нельзя создать до подключения raise Exception("Redis client is not initialized. Call redis.connect() first.") - + return self._client.pipeline() async def subscribe(self, *channels): # Автоматически подключаемся к Redis, если соединение не установлено if self._client is None: await self.connect() - + async with self._client.pubsub() as pubsub: for channel in channels: await pubsub.subscribe(channel) @@ -71,7 +71,7 @@ class RedisService: async def unsubscribe(self, *channels): if self._client is None: return - + async with self._client.pubsub() as pubsub: for channel in channels: await pubsub.unsubscribe(channel) @@ -81,14 +81,14 @@ class RedisService: # Автоматически подключаемся к Redis, если соединение не установлено if self._client is None: await self.connect() - + await self._client.publish(channel, data) async def set(self, key, data, ex=None): # Автоматически подключаемся к Redis, если соединение не установлено if self._client is None: await self.connect() - + # Prepare the command arguments args = [key, data] @@ -104,7 +104,7 @@ class RedisService: # Автоматически подключаемся к Redis, если соединение не установлено if self._client is None: await self.connect() - + return await self.execute("get", key) async def delete(self, *keys): @@ -119,11 +119,11 @@ class RedisService: """ if not keys: return 0 - + # Автоматически подключаемся к Redis, если соединение не установлено if self._client is None: await self.connect() - + return await self._client.delete(*keys) async def hmset(self, key, mapping): @@ -137,7 +137,7 @@ class RedisService: # Автоматически подключаемся к Redis, если соединение не установлено if self._client is None: await self.connect() - + await self._client.hset(key, mapping=mapping) async def expire(self, key, seconds): @@ -151,7 +151,7 @@ class RedisService: # Автоматически подключаемся к Redis, если соединение не установлено if self._client is None: await self.connect() - + await self._client.expire(key, seconds) async def sadd(self, key, *values): @@ -165,7 +165,7 @@ class RedisService: # Автоматически подключаемся к Redis, если соединение не установлено if self._client is None: await self.connect() - + await self._client.sadd(key, *values) async def srem(self, key, *values): @@ -179,7 +179,7 @@ class RedisService: # Автоматически подключаемся к Redis, если соединение не установлено if self._client is None: await self.connect() - + await self._client.srem(key, *values) async def smembers(self, key): @@ -195,9 +195,9 @@ class RedisService: # Автоматически подключаемся к Redis, если соединение не установлено if self._client is None: await self.connect() - + return await self._client.smembers(key) - + async def exists(self, key): """ Проверяет, существует ли ключ в Redis. @@ -210,10 +210,10 @@ class RedisService: """ # Автоматически подключаемся к Redis, если соединение не установлено if self._client is None: - await self.connect() - + await self.connect() + return await self._client.exists(key) - + async def expire(self, key, seconds): """ Устанавливает время жизни ключа. @@ -225,7 +225,7 @@ class RedisService: # Автоматически подключаемся к Redis, если соединение не установлено if self._client is None: await self.connect() - + return await self._client.expire(key, seconds) async def keys(self, pattern): @@ -238,10 +238,8 @@ class RedisService: # Автоматически подключаемся к Redis, если соединение не установлено if self._client is None: await self.connect() - + return await self._client.keys(pattern) - - redis = RedisService() diff --git a/services/schema.py b/services/schema.py index ac167914..9541e95e 100644 --- a/services/schema.py +++ b/services/schema.py @@ -12,7 +12,7 @@ resolvers = [query, mutation, type_draft] def create_all_tables(): """Create all database tables in the correct order.""" - from auth.orm import Author, AuthorFollower, AuthorBookmark, AuthorRating + from auth.orm import Author, AuthorBookmark, AuthorFollower, AuthorRating from orm import community, draft, notification, reaction, shout, topic # Порядок важен - сначала таблицы без внешних ключей, затем зависимые таблицы diff --git a/services/search.py b/services/search.py index f2dd3b46..43543ca2 100644 --- a/services/search.py +++ b/services/search.py @@ -2,9 +2,11 @@ import asyncio import json import logging import os -import httpx -import time import random +import time + +import httpx + from settings import TXTAI_SERVICE_URL # Set up proper logging @@ -15,23 +17,15 @@ logging.getLogger("httpx").setLevel(logging.WARNING) logging.getLogger("httpcore").setLevel(logging.WARNING) # Configuration for search service -SEARCH_ENABLED = bool( - os.environ.get("SEARCH_ENABLED", "true").lower() in ["true", "1", "yes"] -) +SEARCH_ENABLED = bool(os.environ.get("SEARCH_ENABLED", "true").lower() in ["true", "1", "yes"]) MAX_BATCH_SIZE = int(os.environ.get("SEARCH_MAX_BATCH_SIZE", "25")) # Search cache configuration -SEARCH_CACHE_ENABLED = bool( - os.environ.get("SEARCH_CACHE_ENABLED", "true").lower() in ["true", "1", "yes"] -) -SEARCH_CACHE_TTL_SECONDS = int( - os.environ.get("SEARCH_CACHE_TTL_SECONDS", "300") -) # Default: 15 minutes +SEARCH_CACHE_ENABLED = bool(os.environ.get("SEARCH_CACHE_ENABLED", "true").lower() in ["true", "1", "yes"]) +SEARCH_CACHE_TTL_SECONDS = int(os.environ.get("SEARCH_CACHE_TTL_SECONDS", "300")) # Default: 15 minutes SEARCH_PREFETCH_SIZE = int(os.environ.get("SEARCH_PREFETCH_SIZE", "200")) -SEARCH_USE_REDIS = bool( - os.environ.get("SEARCH_USE_REDIS", "true").lower() in ["true", "1", "yes"] -) +SEARCH_USE_REDIS = bool(os.environ.get("SEARCH_USE_REDIS", "true").lower() in ["true", "1", "yes"]) search_offset = 0 @@ -68,9 +62,7 @@ class SearchCache: serialized_results, ex=self.ttl, ) - logger.info( - f"Stored {len(results)} search results for query '{query}' in Redis" - ) + logger.info(f"Stored {len(results)} search results for query '{query}' in Redis") return True except Exception as e: logger.error(f"Error storing search results in Redis: {e}") @@ -83,9 +75,7 @@ class SearchCache: # Store results and update timestamp self.cache[normalized_query] = results self.last_accessed[normalized_query] = time.time() - logger.info( - f"Cached {len(results)} search results for query '{query}' in memory" - ) + logger.info(f"Cached {len(results)} search results for query '{query}' in memory") return True async def get(self, query, limit=10, offset=0): @@ -117,14 +107,10 @@ class SearchCache: # Return paginated subset end_idx = min(offset + limit, len(all_results)) if offset >= len(all_results): - logger.warning( - f"Requested offset {offset} exceeds result count {len(all_results)}" - ) + logger.warning(f"Requested offset {offset} exceeds result count {len(all_results)}") return [] - logger.info( - f"Cache hit for '{query}': serving {offset}:{end_idx} of {len(all_results)} results" - ) + logger.info(f"Cache hit for '{query}': serving {offset}:{end_idx} of {len(all_results)} results") return all_results[offset:end_idx] async def has_query(self, query): @@ -174,11 +160,7 @@ class SearchCache: """Remove oldest entries if memory cache is full""" now = time.time() # First remove expired entries - expired_keys = [ - key - for key, last_access in self.last_accessed.items() - if now - last_access > self.ttl - ] + expired_keys = [key for key, last_access in self.last_accessed.items() if now - last_access > self.ttl] for key in expired_keys: if key in self.cache: @@ -217,9 +199,7 @@ class SearchService: if SEARCH_CACHE_ENABLED: cache_location = "Redis" if SEARCH_USE_REDIS else "Memory" - logger.info( - f"Search caching enabled using {cache_location} cache with TTL={SEARCH_CACHE_TTL_SECONDS}s" - ) + logger.info(f"Search caching enabled using {cache_location} cache with TTL={SEARCH_CACHE_TTL_SECONDS}s") async def info(self): """Return information about search service""" @@ -270,9 +250,7 @@ class SearchService: logger.info( f"Document verification complete: {bodies_missing_count} bodies missing, {titles_missing_count} titles missing" ) - logger.info( - f"Total unique missing documents: {total_missing_count} out of {len(doc_ids)} total" - ) + logger.info(f"Total unique missing documents: {total_missing_count} out of {len(doc_ids)} total") # Return in a backwards-compatible format plus the detailed breakdown return { @@ -308,9 +286,7 @@ class SearchService: # 1. Index title if available if hasattr(shout, "title") and shout.title and isinstance(shout.title, str): title_doc = {"id": str(shout.id), "title": shout.title.strip()} - indexing_tasks.append( - self.index_client.post("/index-title", json=title_doc) - ) + indexing_tasks.append(self.index_client.post("/index-title", json=title_doc)) # 2. Index body content (subtitle, lead, body) body_text_parts = [] @@ -346,9 +322,7 @@ class SearchService: body_text = body_text[:MAX_TEXT_LENGTH] body_doc = {"id": str(shout.id), "body": body_text} - indexing_tasks.append( - self.index_client.post("/index-body", json=body_doc) - ) + indexing_tasks.append(self.index_client.post("/index-body", json=body_doc)) # 3. Index authors authors = getattr(shout, "authors", []) @@ -373,30 +347,22 @@ class SearchService: if name: author_doc = {"id": author_id, "name": name, "bio": combined_bio} - indexing_tasks.append( - self.index_client.post("/index-author", json=author_doc) - ) + indexing_tasks.append(self.index_client.post("/index-author", json=author_doc)) # Run all indexing tasks in parallel if indexing_tasks: - responses = await asyncio.gather( - *indexing_tasks, return_exceptions=True - ) + responses = await asyncio.gather(*indexing_tasks, return_exceptions=True) # Check for errors in responses for i, response in enumerate(responses): if isinstance(response, Exception): logger.error(f"Error in indexing task {i}: {response}") - elif ( - hasattr(response, "status_code") and response.status_code >= 400 - ): + elif hasattr(response, "status_code") and response.status_code >= 400: logger.error( f"Error response in indexing task {i}: {response.status_code}, {await response.text()}" ) - logger.info( - f"Document {shout.id} indexed across {len(indexing_tasks)} endpoints" - ) + logger.info(f"Document {shout.id} indexed across {len(indexing_tasks)} endpoints") else: logger.warning(f"No content to index for shout {shout.id}") @@ -424,24 +390,14 @@ class SearchService: for shout in shouts: try: # 1. Process title documents - if ( - hasattr(shout, "title") - and shout.title - and isinstance(shout.title, str) - ): - title_docs.append( - {"id": str(shout.id), "title": shout.title.strip()} - ) + if hasattr(shout, "title") and shout.title and isinstance(shout.title, str): + title_docs.append({"id": str(shout.id), "title": shout.title.strip()}) # 2. Process body documents (subtitle, lead, body) body_text_parts = [] for field_name in ["subtitle", "lead", "body"]: field_value = getattr(shout, field_name, None) - if ( - field_value - and isinstance(field_value, str) - and field_value.strip() - ): + if field_value and isinstance(field_value, str) and field_value.strip(): body_text_parts.append(field_value.strip()) # Process media content if available @@ -507,9 +463,7 @@ class SearchService: } except Exception as e: - logger.error( - f"Error processing shout {getattr(shout, 'id', 'unknown')} for indexing: {e}" - ) + logger.error(f"Error processing shout {getattr(shout, 'id', 'unknown')} for indexing: {e}") total_skipped += 1 # Convert author dict to list @@ -543,9 +497,7 @@ class SearchService: logger.info(f"Indexing {len(documents)} {doc_type} documents") # Categorize documents by size - small_docs, medium_docs, large_docs = self._categorize_by_size( - documents, doc_type - ) + small_docs, medium_docs, large_docs = self._categorize_by_size(documents, doc_type) # Process each category with appropriate batch sizes batch_sizes = { @@ -561,9 +513,7 @@ class SearchService: ]: if docs: batch_size = batch_sizes[category] - await self._process_batches( - docs, batch_size, endpoint, f"{doc_type}-{category}" - ) + await self._process_batches(docs, batch_size, endpoint, f"{doc_type}-{category}") def _categorize_by_size(self, documents, doc_type): """Categorize documents by size for optimized batch processing""" @@ -599,7 +549,7 @@ class SearchService: """Process document batches with retry logic""" for i in range(0, len(documents), batch_size): batch = documents[i : i + batch_size] - batch_id = f"{batch_prefix}-{i//batch_size + 1}" + batch_id = f"{batch_prefix}-{i // batch_size + 1}" retry_count = 0 max_retries = 3 @@ -607,9 +557,7 @@ class SearchService: while not success and retry_count < max_retries: try: - response = await self.index_client.post( - endpoint, json=batch, timeout=90.0 - ) + response = await self.index_client.post(endpoint, json=batch, timeout=90.0) if response.status_code == 422: error_detail = response.json() @@ -630,13 +578,13 @@ class SearchService: batch[:mid], batch_size // 2, endpoint, - f"{batch_prefix}-{i//batch_size}-A", + f"{batch_prefix}-{i // batch_size}-A", ) await self._process_batches( batch[mid:], batch_size // 2, endpoint, - f"{batch_prefix}-{i//batch_size}-B", + f"{batch_prefix}-{i // batch_size}-B", ) else: logger.error( @@ -649,9 +597,7 @@ class SearchService: def _truncate_error_detail(self, error_detail): """Truncate error details for logging""" - truncated_detail = ( - error_detail.copy() if isinstance(error_detail, dict) else error_detail - ) + truncated_detail = error_detail.copy() if isinstance(error_detail, dict) else error_detail if ( isinstance(truncated_detail, dict) @@ -660,30 +606,22 @@ class SearchService: ): for i, item in enumerate(truncated_detail["detail"]): if isinstance(item, dict) and "input" in item: - if isinstance(item["input"], dict) and any( - k in item["input"] for k in ["documents", "text"] - ): - if "documents" in item["input"] and isinstance( - item["input"]["documents"], list - ): + if isinstance(item["input"], dict) and any(k in item["input"] for k in ["documents", "text"]): + if "documents" in item["input"] and isinstance(item["input"]["documents"], list): for j, doc in enumerate(item["input"]["documents"]): - if ( - "text" in doc - and isinstance(doc["text"], str) - and len(doc["text"]) > 100 - ): - item["input"]["documents"][j][ - "text" - ] = f"{doc['text'][:100]}... [truncated, total {len(doc['text'])} chars]" + if "text" in doc and isinstance(doc["text"], str) and len(doc["text"]) > 100: + item["input"]["documents"][j]["text"] = ( + f"{doc['text'][:100]}... [truncated, total {len(doc['text'])} chars]" + ) if ( "text" in item["input"] and isinstance(item["input"]["text"], str) and len(item["input"]["text"]) > 100 ): - item["input"][ - "text" - ] = f"{item['input']['text'][:100]}... [truncated, total {len(item['input']['text'])} chars]" + item["input"]["text"] = ( + f"{item['input']['text'][:100]}... [truncated, total {len(item['input']['text'])} chars]" + ) return truncated_detail @@ -711,9 +649,9 @@ class SearchService: search_limit = SEARCH_PREFETCH_SIZE else: search_limit = limit - + logger.info(f"Searching for: '{text}' (limit={limit}, offset={offset}, search_limit={search_limit})") - + response = await self.client.post( "/search-combined", json={"text": text, "limit": search_limit}, @@ -767,9 +705,7 @@ class SearchService: logger.info( f"Searching authors for: '{text}' (limit={limit}, offset={offset}, search_limit={search_limit})" ) - response = await self.client.post( - "/search-author", json={"text": text, "limit": search_limit} - ) + response = await self.client.post("/search-author", json={"text": text, "limit": search_limit}) response.raise_for_status() result = response.json() @@ -784,7 +720,7 @@ class SearchService: # Store the full prefetch batch, then page it await self.cache.store(cache_key, author_results) return await self.cache.get(cache_key, limit, offset) - + return author_results[offset : offset + limit] except Exception as e: @@ -802,9 +738,7 @@ class SearchService: result = response.json() if result.get("consistency", {}).get("status") != "ok": - null_count = result.get("consistency", {}).get( - "null_embeddings_count", 0 - ) + null_count = result.get("consistency", {}).get("null_embeddings_count", 0) if null_count > 0: logger.warning(f"Found {null_count} documents with NULL embeddings") @@ -877,14 +811,10 @@ async def initialize_search_index(shouts_data): index_status = await search_service.check_index_status() if index_status.get("status") == "inconsistent": - problem_ids = index_status.get("consistency", {}).get( - "null_embeddings_sample", [] - ) + problem_ids = index_status.get("consistency", {}).get("null_embeddings_sample", []) if problem_ids: - problem_docs = [ - shout for shout in shouts_data if str(shout.id) in problem_ids - ] + problem_docs = [shout for shout in shouts_data if str(shout.id) in problem_ids] if problem_docs: await search_service.bulk_index(problem_docs) @@ -902,9 +832,7 @@ async def initialize_search_index(shouts_data): if isinstance(media, str): try: media_json = json.loads(media) - if isinstance(media_json, dict) and ( - media_json.get("title") or media_json.get("body") - ): + if isinstance(media_json, dict) and (media_json.get("title") or media_json.get("body")): return True except Exception: return True @@ -922,13 +850,9 @@ async def initialize_search_index(shouts_data): if verification.get("status") == "error": return # Only reindex missing docs that actually have body content - missing_ids = [ - mid for mid in verification.get("missing", []) if mid in body_ids - ] + missing_ids = [mid for mid in verification.get("missing", []) if mid in body_ids] if missing_ids: - missing_docs = [ - shout for shout in shouts_with_body if str(shout.id) in missing_ids - ] + missing_docs = [shout for shout in shouts_with_body if str(shout.id) in missing_ids] await search_service.bulk_index(missing_docs) else: pass @@ -955,35 +879,35 @@ async def check_search_service(): print(f"[WARNING] Search service unavailable: {info.get('message', 'unknown reason')}") else: print(f"[INFO] Search service is available: {info}") - + # Initialize search index in the background async def initialize_search_index_background(): """ Запускает индексацию поиска в фоновом режиме с низким приоритетом. - + Эта функция: 1. Загружает все shouts из базы данных 2. Индексирует их в поисковом сервисе 3. Выполняется асинхронно, не блокируя основной поток 4. Обрабатывает возможные ошибки, не прерывая работу приложения - + Индексация запускается с задержкой после инициализации сервера, чтобы не создавать дополнительную нагрузку при запуске. """ try: print("[search] Starting background search indexing process") from services.db import fetch_all_shouts - + # Get total count first (optional) all_shouts = await fetch_all_shouts() total_count = len(all_shouts) if all_shouts else 0 print(f"[search] Fetched {total_count} shouts for background indexing") - + if not all_shouts: print("[search] No shouts found for indexing, skipping search index initialization") return - + # Start the indexing process with the fetched shouts print("[search] Beginning background search index initialization...") await initialize_search_index(all_shouts) diff --git a/services/viewed.py b/services/viewed.py index 66599783..00c0ef4e 100644 --- a/services/viewed.py +++ b/services/viewed.py @@ -80,12 +80,12 @@ class ViewedStorage: # Получаем список всех ключей migrated_views_* и находим самый последний keys = await redis.execute("KEYS", "migrated_views_*") logger.info(f" * Raw Redis result for 'KEYS migrated_views_*': {len(keys)}") - + # Декодируем байтовые строки, если есть if keys and isinstance(keys[0], bytes): - keys = [k.decode('utf-8') for k in keys] + keys = [k.decode("utf-8") for k in keys] logger.info(f" * Decoded keys: {keys}") - + if not keys: logger.warning(" * No migrated_views keys found in Redis") return @@ -93,7 +93,7 @@ class ViewedStorage: # Фильтруем только ключи timestamp формата (исключаем migrated_views_slugs) timestamp_keys = [k for k in keys if k != "migrated_views_slugs"] logger.info(f" * Timestamp keys after filtering: {timestamp_keys}") - + if not timestamp_keys: logger.warning(" * No migrated_views timestamp keys found in Redis") return @@ -243,20 +243,12 @@ class ViewedStorage: # Обновление тем и авторов с использованием вспомогательной функции for [_st, topic] in ( - session.query(ShoutTopic, Topic) - .join(Topic) - .join(Shout) - .where(Shout.slug == shout_slug) - .all() + session.query(ShoutTopic, Topic).join(Topic).join(Shout).where(Shout.slug == shout_slug).all() ): update_groups(self.shouts_by_topic, topic.slug, shout_slug) for [_st, author] in ( - session.query(ShoutAuthor, Author) - .join(Author) - .join(Shout) - .where(Shout.slug == shout_slug) - .all() + session.query(ShoutAuthor, Author).join(Author).join(Shout).where(Shout.slug == shout_slug).all() ): update_groups(self.shouts_by_author, author.slug, shout_slug) @@ -289,9 +281,7 @@ class ViewedStorage: if failed == 0: when = datetime.now(timezone.utc) + timedelta(seconds=self.period) t = format(when.astimezone().isoformat()) - logger.info( - " ⎩ next update: %s" % (t.split("T")[0] + " " + t.split("T")[1].split(".")[0]) - ) + logger.info(" ⎩ next update: %s" % (t.split("T")[0] + " " + t.split("T")[1].split(".")[0])) await asyncio.sleep(self.period) else: await asyncio.sleep(10) diff --git a/settings.py b/settings.py index 67a9643c..3bf874c8 100644 --- a/settings.py +++ b/settings.py @@ -72,4 +72,4 @@ MAILGUN_API_KEY = os.getenv("MAILGUN_API_KEY", "") MAILGUN_DOMAIN = os.getenv("MAILGUN_DOMAIN", "discours.io") -TXTAI_SERVICE_URL = os.environ.get("TXTAI_SERVICE_URL", "none") \ No newline at end of file +TXTAI_SERVICE_URL = os.environ.get("TXTAI_SERVICE_URL", "none") diff --git a/tests/auth/conftest.py b/tests/auth/conftest.py index ef795214..1b2cdee9 100644 --- a/tests/auth/conftest.py +++ b/tests/auth/conftest.py @@ -1,6 +1,7 @@ -import pytest from typing import Dict +import pytest + @pytest.fixture def oauth_settings() -> Dict[str, Dict[str, str]]: diff --git a/tests/auth/test_oauth.py b/tests/auth/test_oauth.py index 01adf8ba..c9cecde5 100644 --- a/tests/auth/test_oauth.py +++ b/tests/auth/test_oauth.py @@ -1,8 +1,9 @@ -import pytest from unittest.mock import AsyncMock, MagicMock, patch + +import pytest from starlette.responses import JSONResponse, RedirectResponse -from auth.oauth import get_user_profile, oauth_login, oauth_callback +from auth.oauth import get_user_profile, oauth_callback, oauth_login # Подменяем настройки для тестов with ( diff --git a/tests/conftest.py b/tests/conftest.py index cd97d56f..85c3eb1b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,7 @@ import asyncio + import pytest + from services.redis import redis from tests.test_config import get_test_client