linted+fmt
All checks were successful
Deploy on push / deploy (push) Successful in 6s

This commit is contained in:
Untone 2025-05-29 12:37:39 +03:00
parent d4c16658bd
commit 4070f4fcde
49 changed files with 835 additions and 983 deletions

View File

@ -2,25 +2,25 @@ from starlette.requests import Request
from starlette.responses import JSONResponse, RedirectResponse from starlette.responses import JSONResponse, RedirectResponse
from starlette.routing import Route from starlette.routing import Route
from auth.sessions import SessionManager
from auth.internal import verify_internal_auth from auth.internal import verify_internal_auth
from auth.orm import Author from auth.orm import Author
from auth.sessions import SessionManager
from services.db import local_session from services.db import local_session
from utils.logger import root_logger as logger
from settings import ( from settings import (
SESSION_COOKIE_NAME,
SESSION_COOKIE_HTTPONLY, SESSION_COOKIE_HTTPONLY,
SESSION_COOKIE_SECURE,
SESSION_COOKIE_SAMESITE,
SESSION_COOKIE_MAX_AGE, SESSION_COOKIE_MAX_AGE,
SESSION_COOKIE_NAME,
SESSION_COOKIE_SAMESITE,
SESSION_COOKIE_SECURE,
SESSION_TOKEN_HEADER, SESSION_TOKEN_HEADER,
) )
from utils.logger import root_logger as logger
async def logout(request: Request): async def logout(request: Request):
""" """
Выход из системы с удалением сессии и cookie. Выход из системы с удалением сессии и cookie.
Поддерживает получение токена из: Поддерживает получение токена из:
1. HTTP-only cookie 1. HTTP-only cookie
2. Заголовка Authorization 2. Заголовка Authorization
@ -30,7 +30,7 @@ async def logout(request: Request):
if SESSION_COOKIE_NAME in request.cookies: if SESSION_COOKIE_NAME in request.cookies:
token = request.cookies.get(SESSION_COOKIE_NAME) token = request.cookies.get(SESSION_COOKIE_NAME)
logger.debug(f"[auth] logout: Получен токен из cookie {SESSION_COOKIE_NAME}") logger.debug(f"[auth] logout: Получен токен из cookie {SESSION_COOKIE_NAME}")
# Если токен не найден в cookie, проверяем заголовок # Если токен не найден в cookie, проверяем заголовок
if not token: if not token:
# Сначала проверяем основной заголовок авторизации # Сначала проверяем основной заголовок авторизации
@ -42,7 +42,7 @@ async def logout(request: Request):
else: else:
token = auth_header.strip() token = auth_header.strip()
logger.debug(f"[auth] logout: Получен прямой токен из заголовка {SESSION_TOKEN_HEADER}") logger.debug(f"[auth] logout: Получен прямой токен из заголовка {SESSION_TOKEN_HEADER}")
# Если токен не найден в основном заголовке, проверяем стандартный Authorization # Если токен не найден в основном заголовке, проверяем стандартный Authorization
if not token and "Authorization" in request.headers: if not token and "Authorization" in request.headers:
auth_header = request.headers.get("Authorization") auth_header = request.headers.get("Authorization")
@ -74,7 +74,7 @@ async def logout(request: Request):
key=SESSION_COOKIE_NAME, key=SESSION_COOKIE_NAME,
secure=SESSION_COOKIE_SECURE, secure=SESSION_COOKIE_SECURE,
httponly=SESSION_COOKIE_HTTPONLY, httponly=SESSION_COOKIE_HTTPONLY,
samesite=SESSION_COOKIE_SAMESITE samesite=SESSION_COOKIE_SAMESITE,
) )
logger.info("[auth] logout: Cookie успешно удалена") logger.info("[auth] logout: Cookie успешно удалена")
@ -84,22 +84,22 @@ async def logout(request: Request):
async def refresh_token(request: Request): async def refresh_token(request: Request):
""" """
Обновление токена аутентификации. Обновление токена аутентификации.
Поддерживает получение токена из: Поддерживает получение токена из:
1. HTTP-only cookie 1. HTTP-only cookie
2. Заголовка Authorization 2. Заголовка Authorization
Возвращает новый токен как в HTTP-only cookie, так и в теле ответа. Возвращает новый токен как в HTTP-only cookie, так и в теле ответа.
""" """
token = None token = None
source = None source = None
# Получаем текущий токен из cookie # Получаем текущий токен из cookie
if SESSION_COOKIE_NAME in request.cookies: if SESSION_COOKIE_NAME in request.cookies:
token = request.cookies.get(SESSION_COOKIE_NAME) token = request.cookies.get(SESSION_COOKIE_NAME)
source = "cookie" source = "cookie"
logger.debug(f"[auth] refresh_token: Токен получен из cookie {SESSION_COOKIE_NAME}") logger.debug(f"[auth] refresh_token: Токен получен из cookie {SESSION_COOKIE_NAME}")
# Если токен не найден в cookie, проверяем заголовок авторизации # Если токен не найден в cookie, проверяем заголовок авторизации
if not token: if not token:
# Проверяем основной заголовок авторизации # Проверяем основной заголовок авторизации
@ -113,7 +113,7 @@ async def refresh_token(request: Request):
token = auth_header.strip() token = auth_header.strip()
source = "header" source = "header"
logger.debug(f"[auth] refresh_token: Токен получен из заголовка {SESSION_TOKEN_HEADER} (прямой)") logger.debug(f"[auth] refresh_token: Токен получен из заголовка {SESSION_TOKEN_HEADER} (прямой)")
# Если токен не найден в основном заголовке, проверяем стандартный Authorization # Если токен не найден в основном заголовке, проверяем стандартный Authorization
if not token and "Authorization" in request.headers: if not token and "Authorization" in request.headers:
auth_header = request.headers.get("Authorization") auth_header = request.headers.get("Authorization")
@ -147,9 +147,7 @@ async def refresh_token(request: Request):
if not new_token: if not new_token:
logger.error(f"[auth] refresh_token: Не удалось обновить токен для пользователя {user_id}") logger.error(f"[auth] refresh_token: Не удалось обновить токен для пользователя {user_id}")
return JSONResponse( return JSONResponse({"success": False, "error": "Не удалось обновить токен"}, status_code=500)
{"success": False, "error": "Не удалось обновить токен"}, status_code=500
)
# Создаем ответ # Создаем ответ
response = JSONResponse( response = JSONResponse(

View File

@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Set, Any from typing import Any, Dict, List, Optional, Set
from pydantic import BaseModel, Field from pydantic import BaseModel, Field

View File

@ -1,19 +1,21 @@
from functools import wraps from functools import wraps
from typing import Callable, Any, Dict, Optional from typing import Any, Callable, Dict, Optional
from graphql import GraphQLError, GraphQLResolveInfo from graphql import GraphQLError, GraphQLResolveInfo
from sqlalchemy import exc from sqlalchemy import exc
from auth.credentials import AuthCredentials from auth.credentials import AuthCredentials
from services.db import local_session
from auth.orm import Author
from auth.exceptions import OperationNotAllowed from auth.exceptions import OperationNotAllowed
from utils.logger import root_logger as logger
from settings import ADMIN_EMAILS as ADMIN_EMAILS_LIST, SESSION_TOKEN_HEADER, SESSION_COOKIE_NAME
from auth.sessions import SessionManager
from auth.jwtcodec import JWTCodec, InvalidToken, ExpiredToken
from auth.tokenstorage import TokenStorage
from services.redis import redis
from auth.internal import authenticate from auth.internal import authenticate
from auth.jwtcodec import ExpiredToken, InvalidToken, JWTCodec
from auth.orm import Author
from auth.sessions import SessionManager
from auth.tokenstorage import TokenStorage
from services.db import local_session
from services.redis import redis
from settings import ADMIN_EMAILS as ADMIN_EMAILS_LIST
from settings import SESSION_COOKIE_NAME, SESSION_TOKEN_HEADER
from utils.logger import root_logger as logger
ADMIN_EMAILS = ADMIN_EMAILS_LIST.split(",") ADMIN_EMAILS = ADMIN_EMAILS_LIST.split(",")
@ -21,10 +23,10 @@ ADMIN_EMAILS = ADMIN_EMAILS_LIST.split(",")
def get_safe_headers(request: Any) -> Dict[str, str]: def get_safe_headers(request: Any) -> Dict[str, str]:
""" """
Безопасно получает заголовки запроса. Безопасно получает заголовки запроса.
Args: Args:
request: Объект запроса request: Объект запроса
Returns: Returns:
Dict[str, str]: Словарь заголовков Dict[str, str]: Словарь заголовков
""" """
@ -34,12 +36,9 @@ def get_safe_headers(request: Any) -> Dict[str, str]:
if hasattr(request, "scope") and isinstance(request.scope, dict): if hasattr(request, "scope") and isinstance(request.scope, dict):
scope_headers = request.scope.get("headers", []) scope_headers = request.scope.get("headers", [])
if scope_headers: if scope_headers:
headers.update({ headers.update({k.decode("utf-8").lower(): v.decode("utf-8") for k, v in scope_headers})
k.decode("utf-8").lower(): v.decode("utf-8")
for k, v in scope_headers
})
logger.debug(f"[decorators] Получены заголовки из request.scope: {len(headers)}") logger.debug(f"[decorators] Получены заголовки из request.scope: {len(headers)}")
# Второй приоритет: метод headers() или атрибут headers # Второй приоритет: метод headers() или атрибут headers
if hasattr(request, "headers"): if hasattr(request, "headers"):
if callable(request.headers): if callable(request.headers):
@ -55,15 +54,15 @@ def get_safe_headers(request: Any) -> Dict[str, str]:
elif isinstance(h, dict): elif isinstance(h, dict):
headers.update({k.lower(): v for k, v in h.items()}) headers.update({k.lower(): v for k, v in h.items()})
logger.debug(f"[decorators] Получены заголовки из request.headers словаря: {len(headers)}") logger.debug(f"[decorators] Получены заголовки из request.headers словаря: {len(headers)}")
# Третий приоритет: атрибут _headers # Третий приоритет: атрибут _headers
if hasattr(request, "_headers") and request._headers: if hasattr(request, "_headers") and request._headers:
headers.update({k.lower(): v for k, v in request._headers.items()}) headers.update({k.lower(): v for k, v in request._headers.items()})
logger.debug(f"[decorators] Получены заголовки из request._headers: {len(headers)}") logger.debug(f"[decorators] Получены заголовки из request._headers: {len(headers)}")
except Exception as e: except Exception as e:
logger.warning(f"[decorators] Ошибка при доступе к заголовкам: {e}") logger.warning(f"[decorators] Ошибка при доступе к заголовкам: {e}")
return headers return headers
@ -72,13 +71,13 @@ def get_auth_token(request: Any) -> Optional[str]:
Извлекает токен авторизации из запроса. Извлекает токен авторизации из запроса.
Порядок проверки: Порядок проверки:
1. Проверяет auth из middleware 1. Проверяет auth из middleware
2. Проверяет auth из scope 2. Проверяет auth из scope
3. Проверяет заголовок Authorization 3. Проверяет заголовок Authorization
4. Проверяет cookie с именем auth_token 4. Проверяет cookie с именем auth_token
Args: Args:
request: Объект запроса request: Объект запроса
Returns: Returns:
Optional[str]: Токен авторизации или None Optional[str]: Токен авторизации или None
""" """
@ -100,7 +99,7 @@ def get_auth_token(request: Any) -> Optional[str]:
# 3. Проверяем заголовок Authorization # 3. Проверяем заголовок Authorization
headers = get_safe_headers(request) headers = get_safe_headers(request)
# Сначала проверяем основной заголовок авторизации # Сначала проверяем основной заголовок авторизации
auth_header = headers.get(SESSION_TOKEN_HEADER.lower(), "") auth_header = headers.get(SESSION_TOKEN_HEADER.lower(), "")
if auth_header: if auth_header:
@ -112,7 +111,7 @@ def get_auth_token(request: Any) -> Optional[str]:
token = auth_header.strip() token = auth_header.strip()
logger.debug(f"[decorators] Прямой токен получен из заголовка {SESSION_TOKEN_HEADER}: {len(token)}") logger.debug(f"[decorators] Прямой токен получен из заголовка {SESSION_TOKEN_HEADER}: {len(token)}")
return token return token
# Затем проверяем стандартный заголовок Authorization, если основной не определен # Затем проверяем стандартный заголовок Authorization, если основной не определен
if SESSION_TOKEN_HEADER.lower() != "authorization": if SESSION_TOKEN_HEADER.lower() != "authorization":
auth_header = headers.get("authorization", "") auth_header = headers.get("authorization", "")
@ -139,10 +138,10 @@ def get_auth_token(request: Any) -> Optional[str]:
async def validate_graphql_context(info: Any) -> None: async def validate_graphql_context(info: Any) -> None:
""" """
Проверяет валидность GraphQL контекста и проверяет авторизацию. Проверяет валидность GraphQL контекста и проверяет авторизацию.
Args: Args:
info: GraphQL информация о контексте info: GraphQL информация о контексте
Raises: Raises:
GraphQLError: если контекст невалиден или пользователь не авторизован GraphQLError: если контекст невалиден или пользователь не авторизован
""" """
@ -161,7 +160,7 @@ async def validate_graphql_context(info: Any) -> None:
if auth and auth.logged_in: if auth and auth.logged_in:
logger.debug(f"[decorators] Пользователь уже авторизован: {auth.author_id}") logger.debug(f"[decorators] Пользователь уже авторизован: {auth.author_id}")
return return
# Если аутентификации нет в request.auth, пробуем получить ее из scope # Если аутентификации нет в request.auth, пробуем получить ее из scope
if hasattr(request, "scope") and "auth" in request.scope: if hasattr(request, "scope") and "auth" in request.scope:
auth_cred = request.scope.get("auth") auth_cred = request.scope.get("auth")
@ -170,49 +169,45 @@ async def validate_graphql_context(info: Any) -> None:
# Устанавливаем auth в request для дальнейшего использования # Устанавливаем auth в request для дальнейшего использования
request.auth = auth_cred request.auth = auth_cred
return return
# Если авторизации нет ни в auth, ни в scope, пробуем получить и проверить токен # Если авторизации нет ни в auth, ни в scope, пробуем получить и проверить токен
token = get_auth_token(request) token = get_auth_token(request)
if not token: if not token:
# Если токен не найден, возвращаем ошибку авторизации # Если токен не найден, возвращаем ошибку авторизации
client_info = { client_info = {
"ip": getattr(request.client, "host", "unknown") if hasattr(request, "client") else "unknown", "ip": getattr(request.client, "host", "unknown") if hasattr(request, "client") else "unknown",
"headers": get_safe_headers(request) "headers": get_safe_headers(request),
} }
logger.warning(f"[decorators] Токен авторизации не найден: {client_info}") logger.warning(f"[decorators] Токен авторизации не найден: {client_info}")
raise GraphQLError("Unauthorized - please login") raise GraphQLError("Unauthorized - please login")
# Используем единый механизм проверки токена из auth.internal # Используем единый механизм проверки токена из auth.internal
auth_state = await authenticate(request) auth_state = await authenticate(request)
if not auth_state.logged_in: if not auth_state.logged_in:
error_msg = auth_state.error or "Invalid or expired token" error_msg = auth_state.error or "Invalid or expired token"
logger.warning(f"[decorators] Недействительный токен: {error_msg}") logger.warning(f"[decorators] Недействительный токен: {error_msg}")
raise GraphQLError(f"Unauthorized - {error_msg}") raise GraphQLError(f"Unauthorized - {error_msg}")
# Если все проверки пройдены, создаем AuthCredentials и устанавливаем в request.auth # Если все проверки пройдены, создаем AuthCredentials и устанавливаем в request.auth
with local_session() as session: with local_session() as session:
try: try:
author = session.query(Author).filter(Author.id == auth_state.author_id).one() author = session.query(Author).filter(Author.id == auth_state.author_id).one()
# Получаем разрешения из ролей # Получаем разрешения из ролей
scopes = author.get_permissions() scopes = author.get_permissions()
# Создаем объект авторизации # Создаем объект авторизации
auth_cred = AuthCredentials( auth_cred = AuthCredentials(
author_id=author.id, author_id=author.id, scopes=scopes, logged_in=True, email=author.email, token=auth_state.token
scopes=scopes,
logged_in=True,
email=author.email,
token=auth_state.token
) )
# Устанавливаем auth в request # Устанавливаем auth в request
request.auth = auth_cred request.auth = auth_cred
logger.debug(f"[decorators] Токен успешно проверен и установлен для пользователя {auth_state.author_id}") logger.debug(f"[decorators] Токен успешно проверен и установлен для пользователя {auth_state.author_id}")
except exc.NoResultFound: except exc.NoResultFound:
logger.error(f"[decorators] Пользователь с ID {auth_state.author_id} не найден в базе данных") logger.error(f"[decorators] Пользователь с ID {auth_state.author_id} не найден в базе данных")
raise GraphQLError("Unauthorized - user not found") raise GraphQLError("Unauthorized - user not found")
return return
@ -229,18 +224,19 @@ def admin_auth_required(resolver: Callable) -> Callable:
Raises: Raises:
GraphQLError: если пользователь не авторизован или не имеет доступа администратора GraphQLError: если пользователь не авторизован или не имеет доступа администратора
Example: Example:
>>> @admin_auth_required >>> @admin_auth_required
... async def admin_resolver(root, info, **kwargs): ... async def admin_resolver(root, info, **kwargs):
... return "Admin data" ... return "Admin data"
""" """
@wraps(resolver) @wraps(resolver)
async def wrapper(root: Any = None, info: Any = None, **kwargs): async def wrapper(root: Any = None, info: Any = None, **kwargs):
try: try:
# Проверяем авторизацию пользователя # Проверяем авторизацию пользователя
await validate_graphql_context(info) await validate_graphql_context(info)
# Получаем объект авторизации # Получаем объект авторизации
auth = info.context["request"].auth auth = info.context["request"].auth
if not auth or not auth.logged_in: if not auth or not auth.logged_in:
@ -255,22 +251,24 @@ def admin_auth_required(resolver: Callable) -> Callable:
if not author_id: if not author_id:
logger.error(f"[admin_auth_required] ID автора не определен: {auth}") logger.error(f"[admin_auth_required] ID автора не определен: {auth}")
raise GraphQLError("Unauthorized - invalid user ID") raise GraphQLError("Unauthorized - invalid user ID")
author = session.query(Author).filter(Author.id == author_id).one() author = session.query(Author).filter(Author.id == author_id).one()
# Проверяем, является ли пользователь администратором # Проверяем, является ли пользователь администратором
if author.email in ADMIN_EMAILS: if author.email in ADMIN_EMAILS:
logger.info(f"Admin access granted for {author.email} (ID: {author.id})") logger.info(f"Admin access granted for {author.email} (ID: {author.id})")
return await resolver(root, info, **kwargs) return await resolver(root, info, **kwargs)
# Проверяем роли пользователя # Проверяем роли пользователя
admin_roles = ['admin', 'super'] admin_roles = ["admin", "super"]
user_roles = [role.id for role in author.roles] if author.roles else [] user_roles = [role.id for role in author.roles] if author.roles else []
if any(role in admin_roles for role in user_roles): if any(role in admin_roles for role in user_roles):
logger.info(f"Admin access granted for {author.email} (ID: {author.id}) with role: {user_roles}") logger.info(
f"Admin access granted for {author.email} (ID: {author.id}) with role: {user_roles}"
)
return await resolver(root, info, **kwargs) return await resolver(root, info, **kwargs)
logger.warning(f"Admin access denied for {author.email} (ID: {author.id}). Roles: {user_roles}") logger.warning(f"Admin access denied for {author.email} (ID: {author.id}). Roles: {user_roles}")
raise GraphQLError("Unauthorized - not an admin") raise GraphQLError("Unauthorized - not an admin")
except exc.NoResultFound: except exc.NoResultFound:
@ -301,7 +299,7 @@ def permission_required(resource: str, operation: str, func):
async def wrap(parent, info: GraphQLResolveInfo, *args, **kwargs): async def wrap(parent, info: GraphQLResolveInfo, *args, **kwargs):
# Сначала проверяем авторизацию # Сначала проверяем авторизацию
await validate_graphql_context(info) await validate_graphql_context(info)
# Получаем объект авторизации # Получаем объект авторизации
logger.debug(f"[permission_required] Контекст: {info.context}") logger.debug(f"[permission_required] Контекст: {info.context}")
auth = info.context["request"].auth auth = info.context["request"].auth
@ -324,21 +322,27 @@ def permission_required(resource: str, operation: str, func):
if author.email in ADMIN_EMAILS: if author.email in ADMIN_EMAILS:
logger.debug(f"[permission_required] Администратор {author.email} имеет все разрешения") logger.debug(f"[permission_required] Администратор {author.email} имеет все разрешения")
return await func(parent, info, *args, **kwargs) return await func(parent, info, *args, **kwargs)
# Проверяем роли пользователя # Проверяем роли пользователя
admin_roles = ['admin', 'super'] admin_roles = ["admin", "super"]
user_roles = [role.id for role in author.roles] if author.roles else [] user_roles = [role.id for role in author.roles] if author.roles else []
if any(role in admin_roles for role in user_roles): if any(role in admin_roles for role in user_roles):
logger.debug(f"[permission_required] Пользователь с ролью администратора {author.email} имеет все разрешения") logger.debug(
f"[permission_required] Пользователь с ролью администратора {author.email} имеет все разрешения"
)
return await func(parent, info, *args, **kwargs) return await func(parent, info, *args, **kwargs)
# Проверяем разрешение # Проверяем разрешение
if not author.has_permission(resource, operation): if not author.has_permission(resource, operation):
logger.warning(f"[permission_required] У пользователя {author.email} нет разрешения {operation} на {resource}") logger.warning(
f"[permission_required] У пользователя {author.email} нет разрешения {operation} на {resource}"
)
raise OperationNotAllowed(f"No permission for {operation} on {resource}") raise OperationNotAllowed(f"No permission for {operation} on {resource}")
logger.debug(f"[permission_required] Пользователь {author.email} имеет разрешение {operation} на {resource}") logger.debug(
f"[permission_required] Пользователь {author.email} имеет разрешение {operation} на {resource}"
)
return await func(parent, info, *args, **kwargs) return await func(parent, info, *args, **kwargs)
except exc.NoResultFound: except exc.NoResultFound:
logger.error(f"[permission_required] Пользователь с ID {auth.author_id} не найден в базе данных") logger.error(f"[permission_required] Пользователь с ID {auth.author_id} не найден в базе данных")
@ -349,14 +353,15 @@ def permission_required(resource: str, operation: str, func):
def login_accepted(func): def login_accepted(func):
""" """
Декоратор для резолверов, которые могут работать как с авторизованными, Декоратор для резолверов, которые могут работать как с авторизованными,
так и с неавторизованными пользователями. так и с неавторизованными пользователями.
Добавляет информацию о пользователе в контекст, если пользователь авторизован. Добавляет информацию о пользователе в контекст, если пользователь авторизован.
Args: Args:
func: Декорируемая функция func: Декорируемая функция
""" """
@wraps(func) @wraps(func)
async def wrap(parent, info: GraphQLResolveInfo, *args, **kwargs): async def wrap(parent, info: GraphQLResolveInfo, *args, **kwargs):
try: try:
@ -366,10 +371,10 @@ def login_accepted(func):
except GraphQLError: except GraphQLError:
# Игнорируем ошибку авторизации # Игнорируем ошибку авторизации
pass pass
# Получаем объект авторизации # Получаем объект авторизации
auth = getattr(info.context["request"], "auth", None) auth = getattr(info.context["request"], "auth", None)
if auth and auth.logged_in: if auth and auth.logged_in:
# Если пользователь авторизован, добавляем информацию о нем в контекст # Если пользователь авторизован, добавляем информацию о нем в контекст
with local_session() as session: with local_session() as session:

View File

@ -1,46 +1,48 @@
from ariadne.asgi.handlers import GraphQLHTTPHandler from ariadne.asgi.handlers import GraphQLHTTPHandler
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import Response, JSONResponse from starlette.responses import JSONResponse, Response
from auth.middleware import auth_middleware from auth.middleware import auth_middleware
from utils.logger import root_logger as logger from utils.logger import root_logger as logger
class EnhancedGraphQLHTTPHandler(GraphQLHTTPHandler): class EnhancedGraphQLHTTPHandler(GraphQLHTTPHandler):
""" """
Улучшенный GraphQL HTTP обработчик с поддержкой cookie и авторизации. Улучшенный GraphQL HTTP обработчик с поддержкой cookie и авторизации.
Расширяет стандартный GraphQLHTTPHandler для: Расширяет стандартный GraphQLHTTPHandler для:
1. Создания расширенного контекста запроса с авторизационными данными 1. Создания расширенного контекста запроса с авторизационными данными
2. Корректной обработки ответов с cookie и headers 2. Корректной обработки ответов с cookie и headers
3. Интеграции с AuthMiddleware 3. Интеграции с AuthMiddleware
""" """
async def get_context_for_request(self, request: Request, data: dict) -> dict: async def get_context_for_request(self, request: Request, data: dict) -> dict:
""" """
Расширяем контекст для GraphQL запросов. Расширяем контекст для GraphQL запросов.
Добавляет к стандартному контексту: Добавляет к стандартному контексту:
- Объект response для установки cookie - Объект response для установки cookie
- Интеграцию с AuthMiddleware - Интеграцию с AuthMiddleware
- Расширения для управления авторизацией - Расширения для управления авторизацией
Args: Args:
request: Starlette Request объект request: Starlette Request объект
data: данные запроса data: данные запроса
Returns: Returns:
dict: контекст с дополнительными данными для авторизации и cookie dict: контекст с дополнительными данными для авторизации и cookie
""" """
# Получаем стандартный контекст от базового класса # Получаем стандартный контекст от базового класса
context = await super().get_context_for_request(request, data) context = await super().get_context_for_request(request, data)
# Создаем объект ответа для установки cookie # Создаем объект ответа для установки cookie
response = JSONResponse({}) response = JSONResponse({})
context["response"] = response context["response"] = response
# Интегрируем с AuthMiddleware # Интегрируем с AuthMiddleware
auth_middleware.set_context(context) auth_middleware.set_context(context)
context["extensions"] = auth_middleware context["extensions"] = auth_middleware
# Добавляем данные авторизации только если они доступны # Добавляем данные авторизации только если они доступны
# Без проверки hasattr, так как это вызывает ошибку до обработки AuthenticationMiddleware # Без проверки hasattr, так как это вызывает ошибку до обработки AuthenticationMiddleware
if hasattr(request, "auth") and request.auth: if hasattr(request, "auth") and request.auth:
@ -48,7 +50,7 @@ class EnhancedGraphQLHTTPHandler(GraphQLHTTPHandler):
context["auth"] = request.auth context["auth"] = request.auth
# Безопасно логируем информацию о типе объекта auth # Безопасно логируем информацию о типе объекта auth
logger.debug(f"[graphql] Добавлены данные авторизации в контекст: {type(request.auth).__name__}") logger.debug(f"[graphql] Добавлены данные авторизации в контекст: {type(request.auth).__name__}")
logger.debug(f"[graphql] Подготовлен расширенный контекст для запроса") logger.debug(f"[graphql] Подготовлен расширенный контекст для запроса")
return context return context

View File

@ -1,13 +1,12 @@
from binascii import hexlify from binascii import hexlify
from hashlib import sha256 from hashlib import sha256
from typing import Any, Dict, TypeVar, TYPE_CHECKING from typing import TYPE_CHECKING, Any, Dict, TypeVar
from passlib.hash import bcrypt from passlib.hash import bcrypt
from auth.exceptions import ExpiredToken, InvalidToken, InvalidPassword from auth.exceptions import ExpiredToken, InvalidPassword, InvalidToken
from auth.jwtcodec import JWTCodec from auth.jwtcodec import JWTCodec
from auth.tokenstorage import TokenStorage from auth.tokenstorage import TokenStorage
from services.db import local_session from services.db import local_session
# Для типизации # Для типизации
@ -86,9 +85,7 @@ class Identity:
# Проверим исходный пароль в orm_author # Проверим исходный пароль в orm_author
if not orm_author.password: if not orm_author.password:
logger.warning( logger.warning(f"[auth.identity] Пароль в исходном объекте автора пуст: email={orm_author.email}")
f"[auth.identity] Пароль в исходном объекте автора пуст: email={orm_author.email}"
)
raise InvalidPassword("Пароль не установлен для данного пользователя") raise InvalidPassword("Пароль не установлен для данного пользователя")
# Проверяем пароль напрямую, не используя dict() # Проверяем пароль напрямую, не используя dict()

View File

@ -1,22 +1,22 @@
from typing import Optional, Tuple
import time import time
from typing import Any from typing import Any, Optional, Tuple
from sqlalchemy.orm import exc from sqlalchemy.orm import exc
from starlette.authentication import AuthenticationBackend, BaseUser, UnauthenticatedUser from starlette.authentication import AuthenticationBackend, BaseUser, UnauthenticatedUser
from starlette.requests import HTTPConnection from starlette.requests import HTTPConnection
from auth.credentials import AuthCredentials from auth.credentials import AuthCredentials
from auth.exceptions import ExpiredToken, InvalidToken
from auth.jwtcodec import JWTCodec
from auth.orm import Author from auth.orm import Author
from auth.sessions import SessionManager from auth.sessions import SessionManager
from services.db import local_session
from settings import SESSION_TOKEN_HEADER, SESSION_COOKIE_NAME, ADMIN_EMAILS as ADMIN_EMAILS_LIST
from utils.logger import root_logger as logger
from auth.jwtcodec import JWTCodec
from auth.exceptions import ExpiredToken, InvalidToken
from auth.state import AuthState from auth.state import AuthState
from auth.tokenstorage import TokenStorage from auth.tokenstorage import TokenStorage
from services.db import local_session
from services.redis import redis from services.redis import redis
from settings import ADMIN_EMAILS as ADMIN_EMAILS_LIST
from settings import SESSION_COOKIE_NAME, SESSION_TOKEN_HEADER
from utils.logger import root_logger as logger
ADMIN_EMAILS = ADMIN_EMAILS_LIST.split(",") ADMIN_EMAILS = ADMIN_EMAILS_LIST.split(",")
@ -24,13 +24,9 @@ ADMIN_EMAILS = ADMIN_EMAILS_LIST.split(",")
class AuthenticatedUser(BaseUser): class AuthenticatedUser(BaseUser):
"""Аутентифицированный пользователь для Starlette""" """Аутентифицированный пользователь для Starlette"""
def __init__(self, def __init__(
user_id: str, self, user_id: str, username: str = "", roles: list = None, permissions: dict = None, token: str = None
username: str = "", ):
roles: list = None,
permissions: dict = None,
token: str = None
):
self.user_id = user_id self.user_id = user_id
self.username = username self.username = username
self.roles = roles or [] self.roles = roles or []
@ -56,17 +52,17 @@ class InternalAuthentication(AuthenticationBackend):
async def authenticate(self, request: HTTPConnection): async def authenticate(self, request: HTTPConnection):
""" """
Аутентифицирует пользователя по токену из заголовка или cookie. Аутентифицирует пользователя по токену из заголовка или cookie.
Порядок поиска токена: Порядок поиска токена:
1. Проверяем заголовок SESSION_TOKEN_HEADER (может быть установлен middleware) 1. Проверяем заголовок SESSION_TOKEN_HEADER (может быть установлен middleware)
2. Проверяем scope/auth в request, куда middleware мог сохранить токен 2. Проверяем scope/auth в request, куда middleware мог сохранить токен
3. Проверяем cookie 3. Проверяем cookie
Возвращает: Возвращает:
tuple: (AuthCredentials, BaseUser) tuple: (AuthCredentials, BaseUser)
""" """
token = None token = None
# 1. Проверяем заголовок # 1. Проверяем заголовок
if SESSION_TOKEN_HEADER in request.headers: if SESSION_TOKEN_HEADER in request.headers:
token_header = request.headers.get(SESSION_TOKEN_HEADER) token_header = request.headers.get(SESSION_TOKEN_HEADER)
@ -77,19 +73,19 @@ class InternalAuthentication(AuthenticationBackend):
else: else:
token = token_header.strip() token = token_header.strip()
logger.debug(f"[auth.authenticate] Извлечен прямой токен из заголовка {SESSION_TOKEN_HEADER}") logger.debug(f"[auth.authenticate] Извлечен прямой токен из заголовка {SESSION_TOKEN_HEADER}")
# 2. Проверяем scope/auth, который мог быть установлен middleware # 2. Проверяем scope/auth, который мог быть установлен middleware
if not token and hasattr(request, "scope") and "auth" in request.scope: if not token and hasattr(request, "scope") and "auth" in request.scope:
auth_data = request.scope.get("auth", {}) auth_data = request.scope.get("auth", {})
if isinstance(auth_data, dict) and "token" in auth_data: if isinstance(auth_data, dict) and "token" in auth_data:
token = auth_data["token"] token = auth_data["token"]
logger.debug(f"[auth.authenticate] Извлечен токен из request.scope['auth']") logger.debug(f"[auth.authenticate] Извлечен токен из request.scope['auth']")
# 3. Проверяем cookie # 3. Проверяем cookie
if not token and hasattr(request, "cookies") and SESSION_COOKIE_NAME in request.cookies: if not token and hasattr(request, "cookies") and SESSION_COOKIE_NAME in request.cookies:
token = request.cookies.get(SESSION_COOKIE_NAME) token = request.cookies.get(SESSION_COOKIE_NAME)
logger.debug(f"[auth.authenticate] Извлечен токен из cookie {SESSION_COOKIE_NAME}") logger.debug(f"[auth.authenticate] Извлечен токен из cookie {SESSION_COOKIE_NAME}")
# Если токен не найден, возвращаем неаутентифицированного пользователя # Если токен не найден, возвращаем неаутентифицированного пользователя
if not token: if not token:
logger.debug("[auth.authenticate] Токен не найден") logger.debug("[auth.authenticate] Токен не найден")
@ -112,9 +108,7 @@ class InternalAuthentication(AuthenticationBackend):
if author.is_locked(): if author.is_locked():
logger.debug(f"[auth.authenticate] Аккаунт заблокирован: {author.id}") logger.debug(f"[auth.authenticate] Аккаунт заблокирован: {author.id}")
return AuthCredentials( return AuthCredentials(scopes={}, error_message="Account is locked"), UnauthenticatedUser()
scopes={}, error_message="Account is locked"
), UnauthenticatedUser()
# Получаем разрешения из ролей # Получаем разрешения из ролей
scopes = author.get_permissions() scopes = author.get_permissions()
@ -128,11 +122,7 @@ class InternalAuthentication(AuthenticationBackend):
# Создаем объекты авторизации с сохранением токена # Создаем объекты авторизации с сохранением токена
credentials = AuthCredentials( credentials = AuthCredentials(
author_id=author.id, author_id=author.id, scopes=scopes, logged_in=True, email=author.email, token=token
scopes=scopes,
logged_in=True,
email=author.email,
token=token
) )
user = AuthenticatedUser( user = AuthenticatedUser(
@ -140,7 +130,7 @@ class InternalAuthentication(AuthenticationBackend):
username=author.slug or author.email or "", username=author.slug or author.email or "",
roles=roles, roles=roles,
permissions=scopes, permissions=scopes,
token=token token=token,
) )
logger.debug(f"[auth.authenticate] Успешная аутентификация: {author.email}") logger.debug(f"[auth.authenticate] Успешная аутентификация: {author.email}")
@ -163,7 +153,7 @@ async def verify_internal_auth(token: str) -> Tuple[str, list, bool]:
tuple: (user_id, roles, is_admin) tuple: (user_id, roles, is_admin)
""" """
logger.debug(f"[verify_internal_auth] Проверка токена: {token[:10]}...") logger.debug(f"[verify_internal_auth] Проверка токена: {token[:10]}...")
# Обработка формата "Bearer <token>" (если токен не был обработан ранее) # Обработка формата "Bearer <token>" (если токен не был обработан ранее)
if token and token.startswith("Bearer "): if token and token.startswith("Bearer "):
token = token.replace("Bearer ", "", 1).strip() token = token.replace("Bearer ", "", 1).strip()
@ -188,11 +178,13 @@ async def verify_internal_auth(token: str) -> Tuple[str, list, bool]:
# Получаем роли # Получаем роли
roles = [role.id for role in author.roles] roles = [role.id for role in author.roles]
logger.debug(f"[verify_internal_auth] Роли пользователя: {roles}") logger.debug(f"[verify_internal_auth] Роли пользователя: {roles}")
# Определяем, является ли пользователь администратором # Определяем, является ли пользователь администратором
is_admin = any(role in ['admin', 'super'] for role in roles) or author.email in ADMIN_EMAILS is_admin = any(role in ["admin", "super"] for role in roles) or author.email in ADMIN_EMAILS
logger.debug(f"[verify_internal_auth] Пользователь {author.id} {'является' if is_admin else 'не является'} администратором") logger.debug(
f"[verify_internal_auth] Пользователь {author.id} {'является' if is_admin else 'не является'} администратором"
)
return str(author.id), roles, is_admin return str(author.id), roles, is_admin
except exc.NoResultFound: except exc.NoResultFound:
logger.warning(f"[verify_internal_auth] Пользователь с ID {payload.user_id} не найден в БД или не активен") logger.warning(f"[verify_internal_auth] Пользователь с ID {payload.user_id} не найден в БД или не активен")
@ -257,7 +249,7 @@ async def authenticate(request: Any) -> AuthState:
headers = dict(request.headers()) headers = dict(request.headers())
else: else:
headers = dict(request.headers) headers = dict(request.headers)
auth_header = headers.get(SESSION_TOKEN_HEADER, "") auth_header = headers.get(SESSION_TOKEN_HEADER, "")
if auth_header and auth_header.startswith("Bearer "): if auth_header and auth_header.startswith("Bearer "):
token = auth_header[7:].strip() token = auth_header[7:].strip()
@ -285,13 +277,13 @@ async def authenticate(request: Any) -> AuthState:
logger.warning(f"[auth.authenticate] Токен не валиден: не найдена сессия") logger.warning(f"[auth.authenticate] Токен не валиден: не найдена сессия")
state.error = "Invalid or expired token" state.error = "Invalid or expired token"
return state return state
# Создаем успешное состояние авторизации # Создаем успешное состояние авторизации
state.logged_in = True state.logged_in = True
state.author_id = payload.user_id state.author_id = payload.user_id
state.token = token state.token = token
state.username = payload.username state.username = payload.username
# Если запрос имеет атрибут auth, устанавливаем в него авторизационные данные # Если запрос имеет атрибут auth, устанавливаем в него авторизационные данные
if hasattr(request, "auth") or hasattr(request, "__setattr__"): if hasattr(request, "auth") or hasattr(request, "__setattr__"):
try: try:
@ -301,22 +293,20 @@ async def authenticate(request: Any) -> AuthState:
if author: if author:
# Получаем разрешения из ролей # Получаем разрешения из ролей
scopes = author.get_permissions() scopes = author.get_permissions()
# Создаем объект авторизации # Создаем объект авторизации
auth_cred = AuthCredentials( auth_cred = AuthCredentials(
author_id=author.id, author_id=author.id, scopes=scopes, logged_in=True, email=author.email, token=token
scopes=scopes,
logged_in=True,
email=author.email,
token=token
) )
# Устанавливаем auth в request # Устанавливаем auth в request
setattr(request, "auth", auth_cred) setattr(request, "auth", auth_cred)
logger.debug(f"[auth.authenticate] Авторизационные данные установлены в request.auth для {payload.user_id}") logger.debug(
f"[auth.authenticate] Авторизационные данные установлены в request.auth для {payload.user_id}"
)
except Exception as e: except Exception as e:
logger.error(f"[auth.authenticate] Ошибка при установке auth в request: {e}") logger.error(f"[auth.authenticate] Ошибка при установке auth в request: {e}")
logger.info(f"[auth.authenticate] Успешная аутентификация пользователя {state.author_id}") logger.info(f"[auth.authenticate] Успешная аутентификация пользователя {state.author_id}")
return state return state

View File

@ -1,12 +1,13 @@
from datetime import datetime, timezone, timedelta from datetime import datetime, timedelta, timezone
from typing import Optional
import jwt import jwt
from pydantic import BaseModel from pydantic import BaseModel
from typing import Optional
from utils.logger import root_logger as logger
from auth.exceptions import ExpiredToken, InvalidToken from auth.exceptions import ExpiredToken, InvalidToken
from settings import JWT_ALGORITHM, JWT_SECRET_KEY from settings import JWT_ALGORITHM, JWT_SECRET_KEY
from utils.logger import root_logger as logger
class TokenPayload(BaseModel): class TokenPayload(BaseModel):
user_id: str user_id: str
@ -28,14 +29,14 @@ class JWTCodec:
# Для объектов с атрибутами # Для объектов с атрибутами
user_id = str(getattr(user, "id", "")) user_id = str(getattr(user, "id", ""))
username = getattr(user, "slug", "") or getattr(user, "email", "") or getattr(user, "phone", "") or "" username = getattr(user, "slug", "") or getattr(user, "email", "") or getattr(user, "phone", "") or ""
logger.debug(f"[JWTCodec.encode] Кодирование токена для user_id={user_id}, username={username}") logger.debug(f"[JWTCodec.encode] Кодирование токена для user_id={user_id}, username={username}")
# Если время истечения не указано, установим срок годности на 30 дней # Если время истечения не указано, установим срок годности на 30 дней
if exp is None: if exp is None:
exp = datetime.now(tz=timezone.utc) + timedelta(days=30) exp = datetime.now(tz=timezone.utc) + timedelta(days=30)
logger.debug(f"[JWTCodec.encode] Время истечения не указано, устанавливаем срок: {exp}") logger.debug(f"[JWTCodec.encode] Время истечения не указано, устанавливаем срок: {exp}")
# Важно: убедимся, что exp всегда является либо datetime, либо целым числом от timestamp # Важно: убедимся, что exp всегда является либо datetime, либо целым числом от timestamp
if isinstance(exp, datetime): if isinstance(exp, datetime):
# Преобразуем datetime в timestamp чтобы гарантировать правильный формат # Преобразуем datetime в timestamp чтобы гарантировать правильный формат
@ -44,7 +45,7 @@ class JWTCodec:
# Если передано что-то другое, установим значение по умолчанию # Если передано что-то другое, установим значение по умолчанию
logger.warning(f"[JWTCodec.encode] Некорректный формат exp: {exp}, используем значение по умолчанию") logger.warning(f"[JWTCodec.encode] Некорректный формат exp: {exp}, используем значение по умолчанию")
exp_timestamp = int((datetime.now(tz=timezone.utc) + timedelta(days=30)).timestamp()) exp_timestamp = int((datetime.now(tz=timezone.utc) + timedelta(days=30)).timestamp())
payload = { payload = {
"user_id": user_id, "user_id": user_id,
"username": username, "username": username,
@ -52,9 +53,9 @@ class JWTCodec:
"iat": datetime.now(tz=timezone.utc), "iat": datetime.now(tz=timezone.utc),
"iss": "discours", "iss": "discours",
} }
logger.debug(f"[JWTCodec.encode] Сформирован payload: {payload}") logger.debug(f"[JWTCodec.encode] Сформирован payload: {payload}")
try: try:
token = jwt.encode(payload, JWT_SECRET_KEY, JWT_ALGORITHM) token = jwt.encode(payload, JWT_SECRET_KEY, JWT_ALGORITHM)
logger.debug(f"[JWTCodec.encode] Токен успешно создан, длина: {len(token) if token else 0}") logger.debug(f"[JWTCodec.encode] Токен успешно создан, длина: {len(token) if token else 0}")
@ -66,11 +67,11 @@ class JWTCodec:
@staticmethod @staticmethod
def decode(token: str, verify_exp: bool = True): def decode(token: str, verify_exp: bool = True):
logger.debug(f"[JWTCodec.decode] Начало декодирования токена длиной {len(token) if token else 0}") logger.debug(f"[JWTCodec.decode] Начало декодирования токена длиной {len(token) if token else 0}")
if not token: if not token:
logger.error("[JWTCodec.decode] Пустой токен") logger.error("[JWTCodec.decode] Пустой токен")
return None return None
try: try:
payload = jwt.decode( payload = jwt.decode(
token, token,
@ -83,21 +84,23 @@ class JWTCodec:
issuer="discours", issuer="discours",
) )
logger.debug(f"[JWTCodec.decode] Декодирован payload: {payload}") logger.debug(f"[JWTCodec.decode] Декодирован payload: {payload}")
# Убедимся, что exp существует (добавим обработку если exp отсутствует) # Убедимся, что exp существует (добавим обработку если exp отсутствует)
if "exp" not in payload: if "exp" not in payload:
logger.warning(f"[JWTCodec.decode] В токене отсутствует поле exp") logger.warning(f"[JWTCodec.decode] В токене отсутствует поле exp")
# Добавим exp по умолчанию, чтобы избежать ошибки при создании TokenPayload # Добавим exp по умолчанию, чтобы избежать ошибки при создании TokenPayload
payload["exp"] = int((datetime.now(tz=timezone.utc) + timedelta(days=30)).timestamp()) payload["exp"] = int((datetime.now(tz=timezone.utc) + timedelta(days=30)).timestamp())
try: try:
r = TokenPayload(**payload) r = TokenPayload(**payload)
logger.debug(f"[JWTCodec.decode] Создан объект TokenPayload: user_id={r.user_id}, username={r.username}") logger.debug(
f"[JWTCodec.decode] Создан объект TokenPayload: user_id={r.user_id}, username={r.username}"
)
return r return r
except Exception as e: except Exception as e:
logger.error(f"[JWTCodec.decode] Ошибка при создании TokenPayload: {e}") logger.error(f"[JWTCodec.decode] Ошибка при создании TokenPayload: {e}")
return None return None
except jwt.InvalidIssuedAtError: except jwt.InvalidIssuedAtError:
logger.error("[JWTCodec.decode] Недействительное время выпуска токена") logger.error("[JWTCodec.decode] Недействительное время выпуска токена")
return None return None

View File

@ -1,19 +1,29 @@
""" """
Middleware для обработки авторизации в GraphQL запросах Middleware для обработки авторизации в GraphQL запросах
""" """
from typing import Any, Dict from typing import Any, Dict
from starlette.datastructures import Headers
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import JSONResponse, Response from starlette.responses import JSONResponse, Response
from starlette.datastructures import Headers from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.types import ASGIApp, Scope, Receive, Send
from settings import (
SESSION_COOKIE_HTTPONLY,
SESSION_COOKIE_MAX_AGE,
SESSION_COOKIE_NAME,
SESSION_COOKIE_SAMESITE,
SESSION_COOKIE_SECURE,
SESSION_TOKEN_HEADER,
)
from utils.logger import root_logger as logger from utils.logger import root_logger as logger
from settings import SESSION_COOKIE_HTTPONLY, SESSION_COOKIE_MAX_AGE, SESSION_COOKIE_SAMESITE, SESSION_COOKIE_SECURE, SESSION_TOKEN_HEADER, SESSION_COOKIE_NAME
class AuthMiddleware: class AuthMiddleware:
""" """
Универсальный middleware для обработки авторизации и управления cookies. Универсальный middleware для обработки авторизации и управления cookies.
Основные функции: Основные функции:
1. Извлечение Bearer токена из заголовка Authorization или cookie 1. Извлечение Bearer токена из заголовка Authorization или cookie
2. Добавление токена в заголовки запроса для обработки AuthenticationMiddleware 2. Добавление токена в заголовки запроса для обработки AuthenticationMiddleware
@ -23,7 +33,7 @@ class AuthMiddleware:
def __init__(self, app: ASGIApp): def __init__(self, app: ASGIApp):
self.app = app self.app = app
self._context = None self._context = None
async def __call__(self, scope: Scope, receive: Receive, send: Send): async def __call__(self, scope: Scope, receive: Receive, send: Send):
"""Обработка ASGI запроса""" """Обработка ASGI запроса"""
if scope["type"] != "http": if scope["type"] != "http":
@ -93,33 +103,29 @@ class AuthMiddleware:
scope["headers"] = new_headers scope["headers"] = new_headers
# Также добавляем информацию о типе аутентификации для дальнейшего использования # Также добавляем информацию о типе аутентификации для дальнейшего использования
scope["auth"] = { scope["auth"] = {"type": "bearer", "token": token, "source": token_source}
"type": "bearer",
"token": token,
"source": token_source
}
logger.debug(f"[middleware] Токен добавлен в scope для аутентификации из источника: {token_source}") logger.debug(f"[middleware] Токен добавлен в scope для аутентификации из источника: {token_source}")
else: else:
logger.debug(f"[middleware] Токен не найден ни в заголовке, ни в cookie") logger.debug(f"[middleware] Токен не найден ни в заголовке, ни в cookie")
await self.app(scope, receive, send) await self.app(scope, receive, send)
def set_context(self, context): def set_context(self, context):
"""Сохраняет ссылку на контекст GraphQL запроса""" """Сохраняет ссылку на контекст GraphQL запроса"""
self._context = context self._context = context
logger.debug(f"[middleware] Установлен контекст GraphQL: {bool(context)}") logger.debug(f"[middleware] Установлен контекст GraphQL: {bool(context)}")
def set_cookie(self, key, value, **options): def set_cookie(self, key, value, **options):
""" """
Устанавливает cookie в ответе Устанавливает cookie в ответе
Args: Args:
key: Имя cookie key: Имя cookie
value: Значение cookie value: Значение cookie
**options: Дополнительные параметры (httponly, secure, max_age, etc.) **options: Дополнительные параметры (httponly, secure, max_age, etc.)
""" """
success = False success = False
# Способ 1: Через response # Способ 1: Через response
if self._context and "response" in self._context and hasattr(self._context["response"], "set_cookie"): if self._context and "response" in self._context and hasattr(self._context["response"], "set_cookie"):
try: try:
@ -128,7 +134,7 @@ class AuthMiddleware:
success = True success = True
except Exception as e: except Exception as e:
logger.error(f"[middleware] Ошибка при установке cookie {key} через response: {str(e)}") logger.error(f"[middleware] Ошибка при установке cookie {key} через response: {str(e)}")
# Способ 2: Через собственный response в контексте # Способ 2: Через собственный response в контексте
if not success and hasattr(self, "_response") and self._response and hasattr(self._response, "set_cookie"): if not success and hasattr(self, "_response") and self._response and hasattr(self._response, "set_cookie"):
try: try:
@ -137,20 +143,20 @@ class AuthMiddleware:
success = True success = True
except Exception as e: except Exception as e:
logger.error(f"[middleware] Ошибка при установке cookie {key} через _response: {str(e)}") logger.error(f"[middleware] Ошибка при установке cookie {key} через _response: {str(e)}")
if not success: if not success:
logger.error(f"[middleware] Не удалось установить cookie {key}: объекты response недоступны") logger.error(f"[middleware] Не удалось установить cookie {key}: объекты response недоступны")
def delete_cookie(self, key, **options): def delete_cookie(self, key, **options):
""" """
Удаляет cookie из ответа Удаляет cookie из ответа
Args: Args:
key: Имя cookie для удаления key: Имя cookie для удаления
**options: Дополнительные параметры **options: Дополнительные параметры
""" """
success = False success = False
# Способ 1: Через response # Способ 1: Через response
if self._context and "response" in self._context and hasattr(self._context["response"], "delete_cookie"): if self._context and "response" in self._context and hasattr(self._context["response"], "delete_cookie"):
try: try:
@ -159,7 +165,7 @@ class AuthMiddleware:
success = True success = True
except Exception as e: except Exception as e:
logger.error(f"[middleware] Ошибка при удалении cookie {key} через response: {str(e)}") logger.error(f"[middleware] Ошибка при удалении cookie {key} через response: {str(e)}")
# Способ 2: Через собственный response в контексте # Способ 2: Через собственный response в контексте
if not success and hasattr(self, "_response") and self._response and hasattr(self._response, "delete_cookie"): if not success and hasattr(self, "_response") and self._response and hasattr(self._response, "delete_cookie"):
try: try:
@ -168,7 +174,7 @@ class AuthMiddleware:
success = True success = True
except Exception as e: except Exception as e:
logger.error(f"[middleware] Ошибка при удалении cookie {key} через _response: {str(e)}") logger.error(f"[middleware] Ошибка при удалении cookie {key} через _response: {str(e)}")
if not success: if not success:
logger.error(f"[middleware] Не удалось удалить cookie {key}: объекты response недоступны") logger.error(f"[middleware] Не удалось удалить cookie {key}: объекты response недоступны")
@ -180,38 +186,41 @@ class AuthMiddleware:
try: try:
# Получаем доступ к контексту запроса # Получаем доступ к контексту запроса
context = info.context context = info.context
# Сохраняем ссылку на контекст # Сохраняем ссылку на контекст
self.set_context(context) self.set_context(context)
# Добавляем себя как объект, содержащий утилитные методы # Добавляем себя как объект, содержащий утилитные методы
context["extensions"] = self context["extensions"] = self
# Проверяем наличие response в контексте # Проверяем наличие response в контексте
if "response" not in context or not context["response"]: if "response" not in context or not context["response"]:
from starlette.responses import JSONResponse from starlette.responses import JSONResponse
context["response"] = JSONResponse({}) context["response"] = JSONResponse({})
logger.debug("[middleware] Создан новый response объект в контексте GraphQL") logger.debug("[middleware] Создан новый response объект в контексте GraphQL")
logger.debug(f"[middleware] GraphQL resolve: контекст подготовлен, добавлены расширения для работы с cookie") logger.debug(
f"[middleware] GraphQL resolve: контекст подготовлен, добавлены расширения для работы с cookie"
)
return await next(root, info, *args, **kwargs) return await next(root, info, *args, **kwargs)
except Exception as e: except Exception as e:
logger.error(f"[AuthMiddleware] Ошибка в GraphQL resolve: {str(e)}") logger.error(f"[AuthMiddleware] Ошибка в GraphQL resolve: {str(e)}")
raise raise
async def process_result(self, request: Request, result: Any) -> Response: async def process_result(self, request: Request, result: Any) -> Response:
""" """
Обрабатывает результат GraphQL запроса, поддерживая установку cookie Обрабатывает результат GraphQL запроса, поддерживая установку cookie
Args: Args:
request: Starlette Request объект request: Starlette Request объект
result: результат GraphQL запроса (dict или Response) result: результат GraphQL запроса (dict или Response)
Returns: Returns:
Response: HTTP-ответ с результатом и cookie (если необходимо) Response: HTTP-ответ с результатом и cookie (если необходимо)
""" """
# Проверяем, является ли result уже объектом Response # Проверяем, является ли result уже объектом Response
if isinstance(result, Response): if isinstance(result, Response):
response = result response = result
@ -220,19 +229,20 @@ class AuthMiddleware:
if isinstance(result, JSONResponse): if isinstance(result, JSONResponse):
try: try:
import json import json
result_data = json.loads(result.body.decode('utf-8'))
result_data = json.loads(result.body.decode("utf-8"))
except Exception as e: except Exception as e:
logger.error(f"[process_result] Не удалось извлечь данные из JSONResponse: {str(e)}") logger.error(f"[process_result] Не удалось извлечь данные из JSONResponse: {str(e)}")
else: else:
response = JSONResponse(result) response = JSONResponse(result)
result_data = result result_data = result
# Проверяем, был ли токен в запросе или ответе # Проверяем, был ли токен в запросе или ответе
if request.method == "POST": if request.method == "POST":
try: try:
data = await request.json() data = await request.json()
op_name = data.get("operationName", "").lower() op_name = data.get("operationName", "").lower()
# Если это операция логина или обновления токена, и в ответе есть токен # Если это операция логина или обновления токена, и в ответе есть токен
if op_name in ["login", "refreshtoken"]: if op_name in ["login", "refreshtoken"]:
token = None token = None
@ -243,32 +253,35 @@ class AuthMiddleware:
op_result = data_obj.get(op_name, {}) op_result = data_obj.get(op_name, {})
if isinstance(op_result, dict) and "token" in op_result: if isinstance(op_result, dict) and "token" in op_result:
token = op_result.get("token") token = op_result.get("token")
if token: if token:
# Устанавливаем cookie с токеном # Устанавливаем cookie с токеном
response.set_cookie( response.set_cookie(
key=SESSION_COOKIE_NAME, key=SESSION_COOKIE_NAME,
value=token, value=token,
httponly=SESSION_COOKIE_HTTPONLY, httponly=SESSION_COOKIE_HTTPONLY,
secure=SESSION_COOKIE_SECURE, secure=SESSION_COOKIE_SECURE,
samesite=SESSION_COOKIE_SAMESITE, samesite=SESSION_COOKIE_SAMESITE,
max_age=SESSION_COOKIE_MAX_AGE, max_age=SESSION_COOKIE_MAX_AGE,
) )
logger.debug(f"[graphql_handler] Установлена cookie {SESSION_COOKIE_NAME} для операции {op_name}") logger.debug(
f"[graphql_handler] Установлена cookie {SESSION_COOKIE_NAME} для операции {op_name}"
)
# Если это операция logout, удаляем cookie # Если это операция logout, удаляем cookie
elif op_name == "logout": elif op_name == "logout":
response.delete_cookie( response.delete_cookie(
key=SESSION_COOKIE_NAME, key=SESSION_COOKIE_NAME,
secure=SESSION_COOKIE_SECURE, secure=SESSION_COOKIE_SECURE,
httponly=SESSION_COOKIE_HTTPONLY, httponly=SESSION_COOKIE_HTTPONLY,
samesite=SESSION_COOKIE_SAMESITE samesite=SESSION_COOKIE_SAMESITE,
) )
logger.debug(f"[graphql_handler] Удалена cookie {SESSION_COOKIE_NAME} для операции {op_name}") logger.debug(f"[graphql_handler] Удалена cookie {SESSION_COOKIE_NAME} для операции {op_name}")
except Exception as e: except Exception as e:
logger.error(f"[process_result] Ошибка при обработке POST запроса: {str(e)}") logger.error(f"[process_result] Ошибка при обработке POST запроса: {str(e)}")
return response return response
# Создаем единый экземпляр AuthMiddleware для использования с GraphQL # Создаем единый экземпляр AuthMiddleware для использования с GraphQL
auth_middleware = AuthMiddleware(lambda scope, receive, send: None) auth_middleware = AuthMiddleware(lambda scope, receive, send: None)

View File

@ -1,11 +1,12 @@
import time
from secrets import token_urlsafe
from authlib.integrations.starlette_client import OAuth from authlib.integrations.starlette_client import OAuth
from authlib.oauth2.rfc7636 import create_s256_code_challenge from authlib.oauth2.rfc7636 import create_s256_code_challenge
from starlette.responses import RedirectResponse, JSONResponse from starlette.responses import JSONResponse, RedirectResponse
from secrets import token_urlsafe
import time
from auth.tokenstorage import TokenStorage
from auth.orm import Author from auth.orm import Author
from auth.tokenstorage import TokenStorage
from services.db import local_session from services.db import local_session
from settings import FRONTEND_URL, OAUTH_CLIENTS from settings import FRONTEND_URL, OAUTH_CLIENTS
@ -129,9 +130,7 @@ async def oauth_callback(request):
return JSONResponse({"error": "Provider not configured"}, status_code=400) return JSONResponse({"error": "Provider not configured"}, status_code=400)
# Получаем токен с PKCE verifier # Получаем токен с PKCE verifier
token = await client.authorize_access_token( token = await client.authorize_access_token(request, code_verifier=request.session.get("code_verifier"))
request, code_verifier=request.session.get("code_verifier")
)
# Получаем профиль пользователя # Получаем профиль пользователя
profile = await get_user_profile(provider, client, token) profile = await get_user_profile(provider, client, token)

View File

@ -1,5 +1,6 @@
import time import time
from typing import Dict, Set from typing import Dict, Set
from sqlalchemy import JSON, Boolean, Column, ForeignKey, Index, Integer, String from sqlalchemy import JSON, Boolean, Column, ForeignKey, Index, Integer, String
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
@ -180,7 +181,7 @@ class Author(Base):
# ) # )
# Список защищенных полей, которые видны только владельцу и администраторам # Список защищенных полей, которые видны только владельцу и администраторам
_protected_fields = ['email', 'password', 'provider_access_token', 'provider_refresh_token'] _protected_fields = ["email", "password", "provider_access_token", "provider_refresh_token"]
@property @property
def is_authenticated(self) -> bool: def is_authenticated(self) -> bool:
@ -241,27 +242,27 @@ class Author(Base):
def dict(self, access=False) -> Dict: def dict(self, access=False) -> Dict:
""" """
Сериализует объект Author в словарь с учетом прав доступа. Сериализует объект Author в словарь с учетом прав доступа.
Args: Args:
access (bool, optional): Флаг, указывающий, доступны ли защищенные поля access (bool, optional): Флаг, указывающий, доступны ли защищенные поля
Returns: Returns:
dict: Словарь с атрибутами Author, отфильтрованный по правам доступа dict: Словарь с атрибутами Author, отфильтрованный по правам доступа
""" """
# Получаем все атрибуты объекта # Получаем все атрибуты объекта
result = {c.name: getattr(self, c.name) for c in self.__table__.columns} result = {c.name: getattr(self, c.name) for c in self.__table__.columns}
# Добавляем роли как список идентификаторов и названий # Добавляем роли как список идентификаторов и названий
if hasattr(self, 'roles'): if hasattr(self, "roles"):
result['roles'] = [] result["roles"] = []
for role in self.roles: for role in self.roles:
if isinstance(role, dict): if isinstance(role, dict):
result['roles'].append(role.get('id')) result["roles"].append(role.get("id"))
# скрываем защищенные поля # скрываем защищенные поля
if not access: if not access:
for field in self._protected_fields: for field in self._protected_fields:
if field in result: if field in result:
result[field] = None result[field] = None
return result return result

View File

@ -9,9 +9,9 @@ from typing import List, Union
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from auth.orm import Author, Role, RolePermission, Permission from auth.orm import Author, Permission, Role, RolePermission
from settings import ADMIN_EMAILS as ADMIN_EMAILS_LIST
from orm.community import Community, CommunityFollower, CommunityRole from orm.community import Community, CommunityFollower, CommunityRole
from settings import ADMIN_EMAILS as ADMIN_EMAILS_LIST
ADMIN_EMAILS = ADMIN_EMAILS_LIST.split(",") ADMIN_EMAILS = ADMIN_EMAILS_LIST.split(",")
@ -110,9 +110,7 @@ class ContextualPermissionCheck:
return has_permission return has_permission
@staticmethod @staticmethod
def get_user_community_roles( def get_user_community_roles(session: Session, author_id: int, community_slug: str) -> List[CommunityRole]:
session: Session, author_id: int, community_slug: str
) -> List[CommunityRole]:
""" """
Получает список ролей пользователя в сообществе. Получает список ролей пользователя в сообществе.

View File

@ -1,9 +1,10 @@
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, Any, List from typing import Any, Dict, List, Optional
from pydantic import BaseModel from pydantic import BaseModel
from services.redis import redis
from auth.jwtcodec import JWTCodec, TokenPayload from auth.jwtcodec import JWTCodec, TokenPayload
from services.redis import redis
from settings import SESSION_TOKEN_LIFE_SPAN from settings import SESSION_TOKEN_LIFE_SPAN
from utils.logger import root_logger as logger from utils.logger import root_logger as logger
@ -28,11 +29,11 @@ class SessionManager:
def _make_session_key(user_id: str, token: str) -> str: def _make_session_key(user_id: str, token: str) -> str:
""" """
Создаёт ключ для сессии в Redis. Создаёт ключ для сессии в Redis.
Args: Args:
user_id: ID пользователя user_id: ID пользователя
token: JWT токен сессии token: JWT токен сессии
Returns: Returns:
str: Ключ сессии str: Ключ сессии
""" """
@ -44,10 +45,10 @@ class SessionManager:
def _make_user_sessions_key(user_id: str) -> str: def _make_user_sessions_key(user_id: str) -> str:
""" """
Создаёт ключ для списка активных сессий пользователя. Создаёт ключ для списка активных сессий пользователя.
Args: Args:
user_id: ID пользователя user_id: ID пользователя
Returns: Returns:
str: Ключ списка сессий str: Ключ списка сессий
""" """
@ -57,12 +58,12 @@ class SessionManager:
async def create_session(cls, user_id: str, username: str, device_info: Optional[dict] = None) -> str: async def create_session(cls, user_id: str, username: str, device_info: Optional[dict] = None) -> str:
""" """
Создаёт новую сессию. Создаёт новую сессию.
Args: Args:
user_id: ID пользователя user_id: ID пользователя
username: Имя пользователя username: Имя пользователя
device_info: Информация об устройстве (опционально) device_info: Информация об устройстве (опционально)
Returns: Returns:
str: JWT токен сессии str: JWT токен сессии
""" """
@ -96,37 +97,37 @@ class SessionManager:
# Устанавливаем время жизни ключей (30 дней) # Устанавливаем время жизни ключей (30 дней)
pipeline.expire(session_key, 30 * 24 * 60 * 60) pipeline.expire(session_key, 30 * 24 * 60 * 60)
pipeline.expire(user_sessions_key, 30 * 24 * 60 * 60) pipeline.expire(user_sessions_key, 30 * 24 * 60 * 60)
# Также создаем ключ в формате, совместимом с TokenStorage для обратной совместимости # Также создаем ключ в формате, совместимом с TokenStorage для обратной совместимости
token_key = f"{user_id}-{username}-{token}" token_key = f"{user_id}-{username}-{token}"
pipeline.hset(token_key, mapping={"user_id": user_id, "username": username}) pipeline.hset(token_key, mapping={"user_id": user_id, "username": username})
pipeline.expire(token_key, 30 * 24 * 60 * 60) pipeline.expire(token_key, 30 * 24 * 60 * 60)
result = await pipeline.execute() result = await pipeline.execute()
logger.info(f"[SessionManager.create_session] Сессия успешно создана для пользователя {user_id}") logger.info(f"[SessionManager.create_session] Сессия успешно создана для пользователя {user_id}")
return token return token
@classmethod @classmethod
async def verify_session(cls, token: str) -> Optional[TokenPayload]: async def verify_session(cls, token: str) -> Optional[TokenPayload]:
""" """
Проверяет сессию по токену. Проверяет сессию по токену.
Args: Args:
token: JWT токен token: JWT токен
Returns: Returns:
Optional[TokenPayload]: Данные токена или None, если сессия недействительна Optional[TokenPayload]: Данные токена или None, если сессия недействительна
""" """
logger.debug(f"[SessionManager.verify_session] Проверка сессии для токена: {token[:20]}...") logger.debug(f"[SessionManager.verify_session] Проверка сессии для токена: {token[:20]}...")
# Декодируем токен для получения payload # Декодируем токен для получения payload
try: try:
payload = JWTCodec.decode(token) payload = JWTCodec.decode(token)
if not payload: if not payload:
logger.error("[SessionManager.verify_session] Не удалось декодировать токен") logger.error("[SessionManager.verify_session] Не удалось декодировать токен")
return None return None
logger.debug(f"[SessionManager.verify_session] Успешно декодирован токен, user_id={payload.user_id}") logger.debug(f"[SessionManager.verify_session] Успешно декодирован токен, user_id={payload.user_id}")
except Exception as e: except Exception as e:
logger.error(f"[SessionManager.verify_session] Ошибка при декодировании токена: {str(e)}") logger.error(f"[SessionManager.verify_session] Ошибка при декодировании токена: {str(e)}")
@ -134,69 +135,71 @@ class SessionManager:
# Получаем данные из payload # Получаем данные из payload
user_id = payload.user_id user_id = payload.user_id
# Формируем ключ сессии # Формируем ключ сессии
session_key = cls._make_session_key(user_id, token) session_key = cls._make_session_key(user_id, token)
logger.debug(f"[SessionManager.verify_session] Сформирован ключ сессии: {session_key}") logger.debug(f"[SessionManager.verify_session] Сформирован ключ сессии: {session_key}")
# Проверяем существование сессии в Redis # Проверяем существование сессии в Redis
exists = await redis.exists(session_key) exists = await redis.exists(session_key)
if not exists: if not exists:
logger.warning(f"[SessionManager.verify_session] Сессия не найдена: {user_id}. Ключ: {session_key}") logger.warning(f"[SessionManager.verify_session] Сессия не найдена: {user_id}. Ключ: {session_key}")
# Проверяем также ключ в старом формате TokenStorage для обратной совместимости # Проверяем также ключ в старом формате TokenStorage для обратной совместимости
token_key = f"{user_id}-{payload.username}-{token}" token_key = f"{user_id}-{payload.username}-{token}"
old_format_exists = await redis.exists(token_key) old_format_exists = await redis.exists(token_key)
if old_format_exists: if old_format_exists:
logger.info(f"[SessionManager.verify_session] Найдена сессия в старом формате: {token_key}") logger.info(f"[SessionManager.verify_session] Найдена сессия в старом формате: {token_key}")
# Миграция: создаем запись в новом формате # Миграция: создаем запись в новом формате
session_data = { session_data = {
"user_id": user_id, "user_id": user_id,
"username": payload.username, "username": payload.username,
} }
# Копируем сессию в новый формат # Копируем сессию в новый формат
pipeline = redis.pipeline() pipeline = redis.pipeline()
pipeline.hset(session_key, mapping=session_data) pipeline.hset(session_key, mapping=session_data)
pipeline.expire(session_key, 30 * 24 * 60 * 60) pipeline.expire(session_key, 30 * 24 * 60 * 60)
pipeline.sadd(cls._make_user_sessions_key(user_id), token) pipeline.sadd(cls._make_user_sessions_key(user_id), token)
await pipeline.execute() await pipeline.execute()
logger.info(f"[SessionManager.verify_session] Сессия мигрирована в новый формат: {session_key}") logger.info(f"[SessionManager.verify_session] Сессия мигрирована в новый формат: {session_key}")
return payload return payload
# Если сессия не найдена ни в новом, ни в старом формате, проверяем все ключи в Redis # Если сессия не найдена ни в новом, ни в старом формате, проверяем все ключи в Redis
keys = await redis.keys("session:*") keys = await redis.keys("session:*")
logger.debug(f"[SessionManager.verify_session] Все ключи сессий в Redis: {keys}") logger.debug(f"[SessionManager.verify_session] Все ключи сессий в Redis: {keys}")
# Проверяем, можно ли доверять токену напрямую # Проверяем, можно ли доверять токену напрямую
# Если токен валидный и не истек, мы можем доверять ему даже без записи в Redis # Если токен валидный и не истек, мы можем доверять ему даже без записи в Redis
if payload and payload.exp and payload.exp > datetime.now(tz=timezone.utc): if payload and payload.exp and payload.exp > datetime.now(tz=timezone.utc):
logger.info(f"[SessionManager.verify_session] Токен валиден по JWT, создаем сессию для {user_id}") logger.info(f"[SessionManager.verify_session] Токен валиден по JWT, создаем сессию для {user_id}")
# Создаем сессию на основе валидного токена # Создаем сессию на основе валидного токена
session_data = { session_data = {
"user_id": user_id, "user_id": user_id,
"username": payload.username, "username": payload.username,
"created_at": datetime.now(tz=timezone.utc).isoformat(), "created_at": datetime.now(tz=timezone.utc).isoformat(),
"expires_at": payload.exp.isoformat() if isinstance(payload.exp, datetime) else datetime.fromtimestamp(payload.exp, tz=timezone.utc).isoformat(), "expires_at": payload.exp.isoformat()
if isinstance(payload.exp, datetime)
else datetime.fromtimestamp(payload.exp, tz=timezone.utc).isoformat(),
} }
# Сохраняем сессию в Redis # Сохраняем сессию в Redis
pipeline = redis.pipeline() pipeline = redis.pipeline()
pipeline.hset(session_key, mapping=session_data) pipeline.hset(session_key, mapping=session_data)
pipeline.expire(session_key, 30 * 24 * 60 * 60) pipeline.expire(session_key, 30 * 24 * 60 * 60)
pipeline.sadd(cls._make_user_sessions_key(user_id), token) pipeline.sadd(cls._make_user_sessions_key(user_id), token)
await pipeline.execute() await pipeline.execute()
logger.info(f"[SessionManager.verify_session] Создана новая сессия для валидного токена: {session_key}") logger.info(f"[SessionManager.verify_session] Создана новая сессия для валидного токена: {session_key}")
return payload return payload
# Если сессии нет, возвращаем None # Если сессии нет, возвращаем None
return None return None
# Если сессия найдена, возвращаем payload # Если сессия найдена, возвращаем payload
logger.debug(f"[SessionManager.verify_session] Сессия найдена для пользователя {user_id}") logger.debug(f"[SessionManager.verify_session] Сессия найдена для пользователя {user_id}")
return payload return payload
@ -205,89 +208,89 @@ class SessionManager:
async def get_user_sessions(cls, user_id: str) -> List[Dict[str, Any]]: async def get_user_sessions(cls, user_id: str) -> List[Dict[str, Any]]:
""" """
Получает список активных сессий пользователя. Получает список активных сессий пользователя.
Args: Args:
user_id: ID пользователя user_id: ID пользователя
Returns: Returns:
List[Dict[str, Any]]: Список сессий List[Dict[str, Any]]: Список сессий
""" """
user_sessions_key = cls._make_user_sessions_key(user_id) user_sessions_key = cls._make_user_sessions_key(user_id)
tokens = await redis.smembers(user_sessions_key) tokens = await redis.smembers(user_sessions_key)
sessions = [] sessions = []
for token in tokens: for token in tokens:
session_key = cls._make_session_key(user_id, token) session_key = cls._make_session_key(user_id, token)
session_data = await redis.hgetall(session_key) session_data = await redis.hgetall(session_key)
if session_data: if session_data:
session = dict(session_data) session = dict(session_data)
session["token"] = token session["token"] = token
sessions.append(session) sessions.append(session)
return sessions return sessions
@classmethod @classmethod
async def delete_session(cls, user_id: str, token: str) -> bool: async def delete_session(cls, user_id: str, token: str) -> bool:
""" """
Удаляет сессию. Удаляет сессию.
Args: Args:
user_id: ID пользователя user_id: ID пользователя
token: JWT токен token: JWT токен
Returns: Returns:
bool: True, если сессия успешно удалена bool: True, если сессия успешно удалена
""" """
session_key = cls._make_session_key(user_id, token) session_key = cls._make_session_key(user_id, token)
user_sessions_key = cls._make_user_sessions_key(user_id) user_sessions_key = cls._make_user_sessions_key(user_id)
# Удаляем данные сессии и токен из списка сессий пользователя # Удаляем данные сессии и токен из списка сессий пользователя
pipeline = redis.pipeline() pipeline = redis.pipeline()
pipeline.delete(session_key) pipeline.delete(session_key)
pipeline.srem(user_sessions_key, token) pipeline.srem(user_sessions_key, token)
# Также удаляем ключ в формате TokenStorage для полной очистки # Также удаляем ключ в формате TokenStorage для полной очистки
token_payload = JWTCodec.decode(token) token_payload = JWTCodec.decode(token)
if token_payload: if token_payload:
token_key = f"{user_id}-{token_payload.username}-{token}" token_key = f"{user_id}-{token_payload.username}-{token}"
pipeline.delete(token_key) pipeline.delete(token_key)
results = await pipeline.execute() results = await pipeline.execute()
return bool(results[0]) or bool(results[1]) return bool(results[0]) or bool(results[1])
@classmethod @classmethod
async def delete_all_sessions(cls, user_id: str) -> int: async def delete_all_sessions(cls, user_id: str) -> int:
""" """
Удаляет все сессии пользователя. Удаляет все сессии пользователя.
Args: Args:
user_id: ID пользователя user_id: ID пользователя
Returns: Returns:
int: Количество удаленных сессий int: Количество удаленных сессий
""" """
user_sessions_key = cls._make_user_sessions_key(user_id) user_sessions_key = cls._make_user_sessions_key(user_id)
tokens = await redis.smembers(user_sessions_key) tokens = await redis.smembers(user_sessions_key)
count = 0 count = 0
for token in tokens: for token in tokens:
session_key = cls._make_session_key(user_id, token) session_key = cls._make_session_key(user_id, token)
# Удаляем данные сессии # Удаляем данные сессии
deleted = await redis.delete(session_key) deleted = await redis.delete(session_key)
count += deleted count += deleted
# Также удаляем ключ в формате TokenStorage # Также удаляем ключ в формате TokenStorage
token_payload = JWTCodec.decode(token) token_payload = JWTCodec.decode(token)
if token_payload: if token_payload:
token_key = f"{user_id}-{token_payload.username}-{token}" token_key = f"{user_id}-{token_payload.username}-{token}"
await redis.delete(token_key) await redis.delete(token_key)
# Очищаем список токенов # Очищаем список токенов
await redis.delete(user_sessions_key) await redis.delete(user_sessions_key)
return count return count
@classmethod @classmethod

View File

@ -2,12 +2,13 @@
Классы состояния авторизации Классы состояния авторизации
""" """
class AuthState: class AuthState:
""" """
Класс для хранения информации о состоянии авторизации пользователя. Класс для хранения информации о состоянии авторизации пользователя.
Используется в аутентификационных middleware и функциях. Используется в аутентификационных middleware и функциях.
""" """
def __init__(self): def __init__(self):
self.logged_in = False self.logged_in = False
self.author_id = None self.author_id = None
@ -16,7 +17,7 @@ class AuthState:
self.is_admin = False self.is_admin = False
self.is_editor = False self.is_editor = False
self.error = None self.error = None
def __bool__(self): def __bool__(self):
"""Возвращает True если пользователь авторизован""" """Возвращает True если пользователь авторизован"""
return self.logged_in return self.logged_in

View File

@ -1,7 +1,7 @@
from datetime import datetime, timedelta, timezone
import json import json
import time import time
from typing import Dict, Any, Optional, Tuple, List from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Tuple
from auth.jwtcodec import JWTCodec from auth.jwtcodec import JWTCodec
from auth.validations import AuthInput from auth.validations import AuthInput
@ -81,7 +81,7 @@ class TokenStorage:
# Формируем ключи для Redis # Формируем ключи для Redis
token_key = cls._make_token_key(user_id, username, token) token_key = cls._make_token_key(user_id, username, token)
logger.debug(f"[TokenStorage.create_session] Сформированы ключи: token_key={token_key}") logger.debug(f"[TokenStorage.create_session] Сформированы ключи: token_key={token_key}")
# Формируем ключи в новом формате SessionManager для совместимости # Формируем ключи в новом формате SessionManager для совместимости
session_key = cls._make_session_key(user_id, token) session_key = cls._make_session_key(user_id, token)
user_sessions_key = cls._make_user_sessions_key(user_id) user_sessions_key = cls._make_user_sessions_key(user_id)
@ -91,25 +91,25 @@ class TokenStorage:
"user_id": user_id, "user_id": user_id,
"username": username, "username": username,
"created_at": time.time(), "created_at": time.time(),
"expires_at": time.time() + 30 * 24 * 60 * 60 # 30 дней "expires_at": time.time() + 30 * 24 * 60 * 60, # 30 дней
} }
if device_info: if device_info:
token_data.update(device_info) token_data.update(device_info)
logger.debug(f"[TokenStorage.create_session] Сформированы данные сессии: {token_data}") logger.debug(f"[TokenStorage.create_session] Сформированы данные сессии: {token_data}")
# Сохраняем в Redis старый формат # Сохраняем в Redis старый формат
pipeline = redis.pipeline() pipeline = redis.pipeline()
pipeline.hset(token_key, mapping=token_data) pipeline.hset(token_key, mapping=token_data)
pipeline.expire(token_key, 30 * 24 * 60 * 60) # 30 дней pipeline.expire(token_key, 30 * 24 * 60 * 60) # 30 дней
# Также сохраняем в новом формате SessionManager для обеспечения совместимости # Также сохраняем в новом формате SessionManager для обеспечения совместимости
pipeline.hset(session_key, mapping=token_data) pipeline.hset(session_key, mapping=token_data)
pipeline.expire(session_key, 30 * 24 * 60 * 60) # 30 дней pipeline.expire(session_key, 30 * 24 * 60 * 60) # 30 дней
pipeline.sadd(user_sessions_key, token) pipeline.sadd(user_sessions_key, token)
pipeline.expire(user_sessions_key, 30 * 24 * 60 * 60) # 30 дней pipeline.expire(user_sessions_key, 30 * 24 * 60 * 60) # 30 дней
results = await pipeline.execute() results = await pipeline.execute()
logger.info(f"[TokenStorage.create_session] Сессия успешно создана для пользователя {user_id}") logger.info(f"[TokenStorage.create_session] Сессия успешно создана для пользователя {user_id}")
@ -146,39 +146,39 @@ class TokenStorage:
if not payload: if not payload:
logger.warning(f"[TokenStorage.validate_token] Токен не валиден (не удалось декодировать)") logger.warning(f"[TokenStorage.validate_token] Токен не валиден (не удалось декодировать)")
return False, None return False, None
user_id = payload.user_id user_id = payload.user_id
username = payload.username username = payload.username
# Формируем ключи для Redis в обоих форматах # Формируем ключи для Redis в обоих форматах
token_key = cls._make_token_key(user_id, username, token) token_key = cls._make_token_key(user_id, username, token)
session_key = cls._make_session_key(user_id, token) session_key = cls._make_session_key(user_id, token)
# Проверяем в обоих форматах для совместимости # Проверяем в обоих форматах для совместимости
old_exists = await redis.exists(token_key) old_exists = await redis.exists(token_key)
new_exists = await redis.exists(session_key) new_exists = await redis.exists(session_key)
if old_exists or new_exists: if old_exists or new_exists:
logger.info(f"[TokenStorage.validate_token] Токен валиден для пользователя {user_id}") logger.info(f"[TokenStorage.validate_token] Токен валиден для пользователя {user_id}")
# Получаем данные токена из актуального хранилища # Получаем данные токена из актуального хранилища
if new_exists: if new_exists:
token_data = await redis.hgetall(session_key) token_data = await redis.hgetall(session_key)
else: else:
token_data = await redis.hgetall(token_key) token_data = await redis.hgetall(token_key)
# Если найден только в старом формате, создаем запись в новом формате # Если найден только в старом формате, создаем запись в новом формате
if not new_exists: if not new_exists:
logger.info(f"[TokenStorage.validate_token] Миграция токена в новый формат: {session_key}") logger.info(f"[TokenStorage.validate_token] Миграция токена в новый формат: {session_key}")
await redis.hset(session_key, mapping=token_data) await redis.hset(session_key, mapping=token_data)
await redis.expire(session_key, 30 * 24 * 60 * 60) await redis.expire(session_key, 30 * 24 * 60 * 60)
await redis.sadd(cls._make_user_sessions_key(user_id), token) await redis.sadd(cls._make_user_sessions_key(user_id), token)
return True, token_data return True, token_data
else: else:
logger.warning(f"[TokenStorage.validate_token] Токен не найден в Redis: {token_key}") logger.warning(f"[TokenStorage.validate_token] Токен не найден в Redis: {token_key}")
return False, None return False, None
except Exception as e: except Exception as e:
logger.error(f"[TokenStorage.validate_token] Ошибка при проверке токена: {e}") logger.error(f"[TokenStorage.validate_token] Ошибка при проверке токена: {e}")
return False, None return False, None
@ -200,30 +200,30 @@ class TokenStorage:
if not payload: if not payload:
logger.warning(f"[TokenStorage.invalidate_token] Токен не валиден (не удалось декодировать)") logger.warning(f"[TokenStorage.invalidate_token] Токен не валиден (не удалось декодировать)")
return False return False
user_id = payload.user_id user_id = payload.user_id
username = payload.username username = payload.username
# Формируем ключи для Redis в обоих форматах # Формируем ключи для Redis в обоих форматах
token_key = cls._make_token_key(user_id, username, token) token_key = cls._make_token_key(user_id, username, token)
session_key = cls._make_session_key(user_id, token) session_key = cls._make_session_key(user_id, token)
user_sessions_key = cls._make_user_sessions_key(user_id) user_sessions_key = cls._make_user_sessions_key(user_id)
# Удаляем токен из Redis в обоих форматах # Удаляем токен из Redis в обоих форматах
pipeline = redis.pipeline() pipeline = redis.pipeline()
pipeline.delete(token_key) pipeline.delete(token_key)
pipeline.delete(session_key) pipeline.delete(session_key)
pipeline.srem(user_sessions_key, token) pipeline.srem(user_sessions_key, token)
results = await pipeline.execute() results = await pipeline.execute()
success = any(results) success = any(results)
if success: if success:
logger.info(f"[TokenStorage.invalidate_token] Токен успешно инвалидирован для пользователя {user_id}") logger.info(f"[TokenStorage.invalidate_token] Токен успешно инвалидирован для пользователя {user_id}")
else: else:
logger.warning(f"[TokenStorage.invalidate_token] Токен не найден: {token_key}") logger.warning(f"[TokenStorage.invalidate_token] Токен не найден: {token_key}")
return success return success
except Exception as e: except Exception as e:
logger.error(f"[TokenStorage.invalidate_token] Ошибка при инвалидации токена: {e}") logger.error(f"[TokenStorage.invalidate_token] Ошибка при инвалидации токена: {e}")
return False return False
@ -243,11 +243,11 @@ class TokenStorage:
# Получаем список сессий пользователя # Получаем список сессий пользователя
user_sessions_key = cls._make_user_sessions_key(user_id) user_sessions_key = cls._make_user_sessions_key(user_id)
tokens = await redis.smembers(user_sessions_key) tokens = await redis.smembers(user_sessions_key)
if not tokens: if not tokens:
logger.warning(f"[TokenStorage.invalidate_all_tokens] Нет активных сессий пользователя {user_id}") logger.warning(f"[TokenStorage.invalidate_all_tokens] Нет активных сессий пользователя {user_id}")
return 0 return 0
count = 0 count = 0
for token in tokens: for token in tokens:
# Декодируем JWT токен # Декодируем JWT токен
@ -255,28 +255,28 @@ class TokenStorage:
payload = JWTCodec.decode(token) payload = JWTCodec.decode(token)
if payload: if payload:
username = payload.username username = payload.username
# Формируем ключи для Redis # Формируем ключи для Redis
token_key = cls._make_token_key(user_id, username, token) token_key = cls._make_token_key(user_id, username, token)
session_key = cls._make_session_key(user_id, token) session_key = cls._make_session_key(user_id, token)
# Удаляем токен из Redis # Удаляем токен из Redis
pipeline = redis.pipeline() pipeline = redis.pipeline()
pipeline.delete(token_key) pipeline.delete(token_key)
pipeline.delete(session_key) pipeline.delete(session_key)
results = await pipeline.execute() results = await pipeline.execute()
count += 1 count += 1
except Exception as e: except Exception as e:
logger.error(f"[TokenStorage.invalidate_all_tokens] Ошибка при обработке токена: {e}") logger.error(f"[TokenStorage.invalidate_all_tokens] Ошибка при обработке токена: {e}")
continue continue
# Удаляем список сессий пользователя # Удаляем список сессий пользователя
await redis.delete(user_sessions_key) await redis.delete(user_sessions_key)
logger.info(f"[TokenStorage.invalidate_all_tokens] Инвалидировано {count} токенов пользователя {user_id}") logger.info(f"[TokenStorage.invalidate_all_tokens] Инвалидировано {count} токенов пользователя {user_id}")
return count return count
except Exception as e: except Exception as e:
logger.error(f"[TokenStorage.invalidate_all_tokens] Ошибка при инвалидации всех токенов: {e}") logger.error(f"[TokenStorage.invalidate_all_tokens] Ошибка при инвалидации всех токенов: {e}")
return 0 return 0

6
cache/precache.py vendored
View File

@ -3,8 +3,8 @@ import json
from sqlalchemy import and_, join, select from sqlalchemy import and_, join, select
from cache.cache import cache_author, cache_topic
from auth.orm import Author, AuthorFollower from auth.orm import Author, AuthorFollower
from cache.cache import cache_author, cache_topic
from orm.shout import Shout, ShoutAuthor, ShoutReactionsFollower, ShoutTopic from orm.shout import Shout, ShoutAuthor, ShoutReactionsFollower, ShoutTopic
from orm.topic import Topic, TopicFollower from orm.topic import Topic, TopicFollower
from resolvers.stat import get_with_stat from resolvers.stat import get_with_stat
@ -29,9 +29,7 @@ async def precache_authors_followers(author_id, session):
async def precache_authors_follows(author_id, session): async def precache_authors_follows(author_id, session):
follows_topics_query = select(TopicFollower.topic).where(TopicFollower.follower == author_id) follows_topics_query = select(TopicFollower.topic).where(TopicFollower.follower == author_id)
follows_authors_query = select(AuthorFollower.author).where(AuthorFollower.follower == author_id) follows_authors_query = select(AuthorFollower.author).where(AuthorFollower.follower == author_id)
follows_shouts_query = select(ShoutReactionsFollower.shout).where( follows_shouts_query = select(ShoutReactionsFollower.shout).where(ShoutReactionsFollower.follower == author_id)
ShoutReactionsFollower.follower == author_id
)
follows_topics = {row[0] for row in session.execute(follows_topics_query) if row[0]} follows_topics = {row[0] for row in session.execute(follows_topics_query) if row[0]}
follows_authors = {row[0] for row in session.execute(follows_authors_query) if row[0]} follows_authors = {row[0] for row in session.execute(follows_authors_query) if row[0]}

2
cache/triggers.py vendored
View File

@ -1,7 +1,7 @@
from sqlalchemy import event from sqlalchemy import event
from cache.revalidator import revalidation_manager
from auth.orm import Author, AuthorFollower from auth.orm import Author, AuthorFollower
from cache.revalidator import revalidation_manager
from orm.reaction import Reaction, ReactionKind from orm.reaction import Reaction, ReactionKind
from orm.shout import Shout, ShoutAuthor, ShoutReactionsFollower from orm.shout import Shout, ShoutAuthor, ShoutReactionsFollower
from orm.topic import Topic, TopicFollower from orm.topic import Topic, TopicFollower

43
dev.py
View File

@ -1,17 +1,19 @@
import os import os
import subprocess import subprocess
from pathlib import Path from pathlib import Path
from utils.logger import root_logger as logger
from granian import Granian from granian import Granian
from utils.logger import root_logger as logger
def check_mkcert_installed(): def check_mkcert_installed():
""" """
Проверяет, установлен ли инструмент mkcert в системе Проверяет, установлен ли инструмент mkcert в системе
Returns: Returns:
bool: True если mkcert установлен, иначе False bool: True если mkcert установлен, иначе False
>>> check_mkcert_installed() # doctest: +SKIP >>> check_mkcert_installed() # doctest: +SKIP
True True
""" """
@ -21,18 +23,19 @@ def check_mkcert_installed():
except FileNotFoundError: except FileNotFoundError:
return False return False
def generate_certificates(domain="localhost", cert_file="localhost.pem", key_file="localhost-key.pem"): def generate_certificates(domain="localhost", cert_file="localhost.pem", key_file="localhost-key.pem"):
""" """
Генерирует сертификаты с использованием mkcert Генерирует сертификаты с использованием mkcert
Args: Args:
domain: Домен для сертификата domain: Домен для сертификата
cert_file: Имя файла сертификата cert_file: Имя файла сертификата
key_file: Имя файла ключа key_file: Имя файла ключа
Returns: Returns:
tuple: (cert_file, key_file) пути к созданным файлам tuple: (cert_file, key_file) пути к созданным файлам
>>> generate_certificates() # doctest: +SKIP >>> generate_certificates() # doctest: +SKIP
('localhost.pem', 'localhost-key.pem') ('localhost.pem', 'localhost-key.pem')
""" """
@ -40,7 +43,7 @@ def generate_certificates(domain="localhost", cert_file="localhost.pem", key_fil
if os.path.exists(cert_file) and os.path.exists(key_file): if os.path.exists(cert_file) and os.path.exists(key_file):
logger.info(f"Сертификаты уже существуют: {cert_file}, {key_file}") logger.info(f"Сертификаты уже существуют: {cert_file}, {key_file}")
return cert_file, key_file return cert_file, key_file
# Проверяем, установлен ли mkcert # Проверяем, установлен ли mkcert
if not check_mkcert_installed(): if not check_mkcert_installed():
logger.error("mkcert не установлен. Установите mkcert с помощью команды:") logger.error("mkcert не установлен. Установите mkcert с помощью команды:")
@ -49,37 +52,38 @@ def generate_certificates(domain="localhost", cert_file="localhost.pem", key_fil
logger.error(" Windows: choco install mkcert") logger.error(" Windows: choco install mkcert")
logger.error("После установки выполните: mkcert -install") logger.error("После установки выполните: mkcert -install")
return None, None return None, None
try: try:
# Запускаем mkcert для создания сертификата # Запускаем mkcert для создания сертификата
logger.info(f"Создание сертификатов для {domain} с помощью mkcert...") logger.info(f"Создание сертификатов для {domain} с помощью mkcert...")
result = subprocess.run( result = subprocess.run(
["mkcert", "-cert-file", cert_file, "-key-file", key_file, domain], ["mkcert", "-cert-file", cert_file, "-key-file", key_file, domain],
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, stderr=subprocess.PIPE,
text=True text=True,
) )
if result.returncode != 0: if result.returncode != 0:
logger.error(f"Ошибка при создании сертификатов: {result.stderr}") logger.error(f"Ошибка при создании сертификатов: {result.stderr}")
return None, None return None, None
logger.info(f"Сертификаты созданы: {cert_file}, {key_file}") logger.info(f"Сертификаты созданы: {cert_file}, {key_file}")
return cert_file, key_file return cert_file, key_file
except Exception as e: except Exception as e:
logger.error(f"Не удалось создать сертификаты: {str(e)}") logger.error(f"Не удалось создать сертификаты: {str(e)}")
return None, None return None, None
def run_server(host="0.0.0.0", port=8000, workers=1): def run_server(host="0.0.0.0", port=8000, workers=1):
""" """
Запускает сервер Granian с поддержкой HTTPS при необходимости Запускает сервер Granian с поддержкой HTTPS при необходимости
Args: Args:
host: Хост для запуска сервера host: Хост для запуска сервера
port: Порт для запуска сервера port: Порт для запуска сервера
use_https: Флаг использования HTTPS use_https: Флаг использования HTTPS
workers: Количество рабочих процессов workers: Количество рабочих процессов
>>> run_server(use_https=True) # doctest: +SKIP >>> run_server(use_https=True) # doctest: +SKIP
""" """
# Проблема с многопроцессорным режимом - не поддерживает локальные объекты приложений # Проблема с многопроцессорным режимом - не поддерживает локальные объекты приложений
@ -87,16 +91,16 @@ def run_server(host="0.0.0.0", port=8000, workers=1):
if workers > 1: if workers > 1:
logger.warning("Многопроцессорный режим может вызвать проблемы сериализации приложения. Использую 1 процесс.") logger.warning("Многопроцессорный режим может вызвать проблемы сериализации приложения. Использую 1 процесс.")
workers = 1 workers = 1
# При проблемах с ASGI можно попробовать использовать Uvicorn как запасной вариант # При проблемах с ASGI можно попробовать использовать Uvicorn как запасной вариант
try: try:
# Генерируем сертификаты с помощью mkcert # Генерируем сертификаты с помощью mkcert
cert_file, key_file = generate_certificates() cert_file, key_file = generate_certificates()
if not cert_file or not key_file: if not cert_file or not key_file:
logger.error("Не удалось сгенерировать сертификаты для HTTPS") logger.error("Не удалось сгенерировать сертификаты для HTTPS")
return return
logger.info(f"Запуск HTTPS сервера на https://{host}:{port} с использованием Granian") logger.info(f"Запуск HTTPS сервера на https://{host}:{port} с использованием Granian")
# Запускаем Granian сервер с явным указанием ASGI # Запускаем Granian сервер с явным указанием ASGI
server = Granian( server = Granian(
@ -104,7 +108,7 @@ def run_server(host="0.0.0.0", port=8000, workers=1):
port=port, port=port,
workers=workers, workers=workers,
interface="asgi", interface="asgi",
target="main:app", target="main:app",
ssl_cert=Path(cert_file), ssl_cert=Path(cert_file),
ssl_key=Path(key_file), ssl_key=Path(key_file),
) )
@ -113,5 +117,6 @@ def run_server(host="0.0.0.0", port=8000, workers=1):
# В случае проблем с Granian, пробуем запустить через Uvicorn # В случае проблем с Granian, пробуем запустить через Uvicorn
logger.error(f"Ошибка при запуске Granian: {str(e)}") logger.error(f"Ошибка при запуске Granian: {str(e)}")
if __name__ == "__main__": if __name__ == "__main__":
run_server() run_server()

67
main.py
View File

@ -5,19 +5,18 @@ from os.path import exists, join
from ariadne import load_schema_from_path, make_executable_schema from ariadne import load_schema_from_path, make_executable_schema
from ariadne.asgi import GraphQL from ariadne.asgi import GraphQL
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.middleware.cors import CORSMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.routing import Mount, Route
from starlette.staticfiles import StaticFiles
from auth.handler import EnhancedGraphQLHTTPHandler from auth.handler import EnhancedGraphQLHTTPHandler
from auth.internal import InternalAuthentication from auth.internal import InternalAuthentication
from auth.middleware import auth_middleware, AuthMiddleware from auth.middleware import AuthMiddleware, auth_middleware
from starlette.applications import Starlette
from starlette.middleware.cors import CORSMiddleware
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.routing import Route, Mount
from starlette.staticfiles import StaticFiles
from cache.precache import precache_data from cache.precache import precache_data
from cache.revalidator import revalidation_manager from cache.revalidator import revalidation_manager
from services.exception import ExceptionHandlerMiddleware from services.exception import ExceptionHandlerMiddleware
@ -25,8 +24,8 @@ from services.redis import redis
from services.schema import create_all_tables, resolvers from services.schema import create_all_tables, resolvers
from services.search import check_search_service, initialize_search_index_background, search_service from services.search import check_search_service, initialize_search_index_background, search_service
from services.viewed import ViewedStorage from services.viewed import ViewedStorage
from utils.logger import root_logger as logger
from settings import DEV_SERVER_PID_FILE_NAME from settings import DEV_SERVER_PID_FILE_NAME
from utils.logger import root_logger as logger
DEVMODE = os.getenv("DOKKU_APP_TYPE", "false").lower() == "false" DEVMODE = os.getenv("DOKKU_APP_TYPE", "false").lower() == "false"
DIST_DIR = join(os.path.dirname(__file__), "dist") # Директория для собранных файлов DIST_DIR = join(os.path.dirname(__file__), "dist") # Директория для собранных файлов
@ -46,14 +45,14 @@ middleware = [
Middleware( Middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=[ allow_origins=[
"https://localhost:3000", "https://localhost:3000",
"https://testing.discours.io", "https://testing.discours.io",
"https://discours.io", "https://discours.io",
"https://new.discours.io", "https://new.discours.io",
"https://discours.ru", "https://discours.ru",
"https://new.discours.ru" "https://new.discours.ru",
], ],
allow_methods=["GET", "POST", "OPTIONS"], # Явно указываем OPTIONS allow_methods=["GET", "POST", "OPTIONS"], # Явно указываем OPTIONS
allow_headers=["*"], allow_headers=["*"],
allow_credentials=True, allow_credentials=True,
), ),
@ -65,33 +64,29 @@ middleware = [
# Создаем экземпляр GraphQL с улучшенным обработчиком # Создаем экземпляр GraphQL с улучшенным обработчиком
graphql_app = GraphQL( graphql_app = GraphQL(schema, debug=DEVMODE, http_handler=EnhancedGraphQLHTTPHandler())
schema,
debug=DEVMODE,
http_handler=EnhancedGraphQLHTTPHandler()
)
# Оборачиваем GraphQL-обработчик для лучшей обработки ошибок # Оборачиваем GraphQL-обработчик для лучшей обработки ошибок
async def graphql_handler(request: Request): async def graphql_handler(request: Request):
""" """
Обработчик GraphQL запросов с поддержкой middleware и обработкой ошибок. Обработчик GraphQL запросов с поддержкой middleware и обработкой ошибок.
Выполняет: Выполняет:
1. Проверку метода запроса (GET, POST, OPTIONS) 1. Проверку метода запроса (GET, POST, OPTIONS)
2. Обработку GraphQL запроса через ariadne 2. Обработку GraphQL запроса через ariadne
3. Применение middleware для корректной обработки cookie и авторизации 3. Применение middleware для корректной обработки cookie и авторизации
4. Обработку исключений и формирование ответа 4. Обработку исключений и формирование ответа
Args: Args:
request: Starlette Request объект request: Starlette Request объект
Returns: Returns:
Response: объект ответа (обычно JSONResponse) Response: объект ответа (обычно JSONResponse)
""" """
if request.method not in ["GET", "POST", "OPTIONS"]: if request.method not in ["GET", "POST", "OPTIONS"]:
return JSONResponse({"error": "Method Not Allowed by main.py"}, status_code=405) return JSONResponse({"error": "Method Not Allowed by main.py"}, status_code=405)
# Проверяем, что все необходимые middleware корректно отработали # Проверяем, что все необходимые middleware корректно отработали
if not hasattr(request, "scope") or "auth" not in request.scope: if not hasattr(request, "scope") or "auth" not in request.scope:
logger.warning("[graphql] AuthMiddleware не обработал запрос перед GraphQL обработчиком") logger.warning("[graphql] AuthMiddleware не обработал запрос перед GraphQL обработчиком")
@ -99,7 +94,7 @@ async def graphql_handler(request: Request):
try: try:
# Обрабатываем запрос через GraphQL приложение # Обрабатываем запрос через GraphQL приложение
result = await graphql_app.handle_request(request) result = await graphql_app.handle_request(request)
# Применяем middleware для установки cookie # Применяем middleware для установки cookie
# Используем метод process_result из auth_middleware для корректной обработки # Используем метод process_result из auth_middleware для корректной обработки
# cookie на основе результатов операций login/logout # cookie на основе результатов операций login/logout
@ -111,6 +106,7 @@ async def graphql_handler(request: Request):
logger.error(f"GraphQL error: {str(e)}") logger.error(f"GraphQL error: {str(e)}")
# Логируем более подробную информацию для отладки # Логируем более подробную информацию для отладки
import traceback import traceback
logger.debug(f"GraphQL error traceback: {traceback.format_exc()}") logger.debug(f"GraphQL error traceback: {traceback.format_exc()}")
return JSONResponse({"error": str(e)}, status_code=500) return JSONResponse({"error": str(e)}, status_code=500)
@ -127,6 +123,7 @@ async def shutdown():
# Удаляем PID-файл, если он существует # Удаляем PID-файл, если он существует
from settings import DEV_SERVER_PID_FILE_NAME from settings import DEV_SERVER_PID_FILE_NAME
if exists(DEV_SERVER_PID_FILE_NAME): if exists(DEV_SERVER_PID_FILE_NAME):
os.unlink(DEV_SERVER_PID_FILE_NAME) os.unlink(DEV_SERVER_PID_FILE_NAME)
@ -134,12 +131,12 @@ async def shutdown():
async def dev_start(): async def dev_start():
""" """
Инициализация сервера в DEV режиме. Инициализация сервера в DEV режиме.
Функция: Функция:
1. Проверяет наличие DEV режима 1. Проверяет наличие DEV режима
2. Создает PID-файл для отслеживания процесса 2. Создает PID-файл для отслеживания процесса
3. Логирует информацию о старте сервера 3. Логирует информацию о старте сервера
Используется только при запуске сервера с флагом "dev". Используется только при запуске сервера с флагом "dev".
""" """
try: try:
@ -151,6 +148,7 @@ async def dev_start():
old_pid = int(f.read().strip()) old_pid = int(f.read().strip())
# Проверяем, существует ли процесс с таким PID # Проверяем, существует ли процесс с таким PID
import signal import signal
try: try:
os.kill(old_pid, 0) # Сигнал 0 только проверяет существование процесса os.kill(old_pid, 0) # Сигнал 0 только проверяет существование процесса
print(f"[warning] DEV server already running with PID {old_pid}") print(f"[warning] DEV server already running with PID {old_pid}")
@ -158,7 +156,7 @@ async def dev_start():
print(f"[info] Stale PID file found, previous process {old_pid} not running") print(f"[info] Stale PID file found, previous process {old_pid} not running")
except (ValueError, FileNotFoundError): except (ValueError, FileNotFoundError):
print(f"[warning] Invalid PID file found, recreating") print(f"[warning] Invalid PID file found, recreating")
# Создаем или перезаписываем PID-файл # Создаем или перезаписываем PID-файл
with open(pid_path, "w", encoding="utf-8") as f: with open(pid_path, "w", encoding="utf-8") as f:
f.write(str(os.getpid())) f.write(str(os.getpid()))
@ -172,16 +170,16 @@ async def dev_start():
async def lifespan(_app): async def lifespan(_app):
""" """
Функция жизненного цикла приложения. Функция жизненного цикла приложения.
Обеспечивает: Обеспечивает:
1. Инициализацию всех необходимых сервисов и компонентов 1. Инициализацию всех необходимых сервисов и компонентов
2. Предзагрузку кеша данных 2. Предзагрузку кеша данных
3. Подключение к Redis и поисковому сервису 3. Подключение к Redis и поисковому сервису
4. Корректное завершение работы при остановке сервера 4. Корректное завершение работы при остановке сервера
Args: Args:
_app: экземпляр Starlette приложения _app: экземпляр Starlette приложения
Yields: Yields:
None: генератор для управления жизненным циклом None: генератор для управления жизненным циклом
""" """
@ -213,11 +211,12 @@ async def lifespan(_app):
await asyncio.gather(*tasks, return_exceptions=True) await asyncio.gather(*tasks, return_exceptions=True)
print("[lifespan] Shutdown complete") print("[lifespan] Shutdown complete")
# Обновляем маршрут в Starlette # Обновляем маршрут в Starlette
app = Starlette( app = Starlette(
routes=[ routes=[
Route("/graphql", graphql_handler, methods=["GET", "POST", "OPTIONS"]), Route("/graphql", graphql_handler, methods=["GET", "POST", "OPTIONS"]),
Mount("/", app=StaticFiles(directory=DIST_DIR, html=True)) Mount("/", app=StaticFiles(directory=DIST_DIR, html=True)),
], ],
lifespan=lifespan, lifespan=lifespan,
middleware=middleware, # Явно указываем список middleware middleware=middleware, # Явно указываем список middleware

View File

@ -66,11 +66,7 @@ class CommunityStats:
def shouts(self): def shouts(self):
from orm.shout import Shout from orm.shout import Shout
return ( return self.community.session.query(func.count(Shout.id)).filter(Shout.community == self.community.id).scalar()
self.community.session.query(func.count(Shout.id))
.filter(Shout.community == self.community.id)
.scalar()
)
@property @property
def followers(self): def followers(self):

View File

@ -77,7 +77,7 @@ class Shout(Base):
slug (str) slug (str)
cover (str) : "Cover image url" cover (str) : "Cover image url"
cover_caption (str) : "Cover image alt caption" cover_caption (str) : "Cover image alt caption"
lead (str) lead (str)
title (str) title (str)
subtitle (str) subtitle (str)
layout (str) layout (str)

View File

@ -1,4 +1,15 @@
from cache.triggers import events_register from cache.triggers import events_register
from resolvers.admin import (
admin_get_roles,
admin_get_users,
)
from resolvers.auth import (
confirm_email,
get_current_user,
login,
register_by_email,
send_link,
)
from resolvers.author import ( # search_authors, from resolvers.author import ( # search_authors,
get_author, get_author,
get_author_followers, get_author_followers,
@ -16,8 +27,8 @@ from resolvers.draft import (
delete_draft, delete_draft,
load_drafts, load_drafts,
publish_draft, publish_draft,
update_draft,
unpublish_draft, unpublish_draft,
update_draft,
) )
from resolvers.editor import ( from resolvers.editor import (
unpublish_shout, unpublish_shout,
@ -62,19 +73,6 @@ from resolvers.topic import (
get_topics_by_community, get_topics_by_community,
) )
from resolvers.auth import (
get_current_user,
confirm_email,
register_by_email,
send_link,
login,
)
from resolvers.admin import (
admin_get_users,
admin_get_roles,
)
events_register() events_register()
__all__ = [ __all__ = [
@ -84,11 +82,9 @@ __all__ = [
"register_by_email", "register_by_email",
"send_link", "send_link",
"login", "login",
# admin # admin
"admin_get_users", "admin_get_users",
"admin_get_roles", "admin_get_roles",
# author # author
"get_author", "get_author",
"get_author_followers", "get_author_followers",
@ -100,11 +96,9 @@ __all__ = [
"load_authors_search", "load_authors_search",
"update_author", "update_author",
# "search_authors", # "search_authors",
# community # community
"get_community", "get_community",
"get_communities_all", "get_communities_all",
# topic # topic
"get_topic", "get_topic",
"get_topics_all", "get_topics_all",
@ -112,14 +106,12 @@ __all__ = [
"get_topics_by_author", "get_topics_by_author",
"get_topic_followers", "get_topic_followers",
"get_topic_authors", "get_topic_authors",
# reader # reader
"get_shout", "get_shout",
"load_shouts_by", "load_shouts_by",
"load_shouts_random_top", "load_shouts_random_top",
"load_shouts_search", "load_shouts_search",
"load_shouts_unrated", "load_shouts_unrated",
# feed # feed
"load_shouts_feed", "load_shouts_feed",
"load_shouts_coauthored", "load_shouts_coauthored",
@ -127,12 +119,10 @@ __all__ = [
"load_shouts_with_topic", "load_shouts_with_topic",
"load_shouts_followed_by", "load_shouts_followed_by",
"load_shouts_authored_by", "load_shouts_authored_by",
# follower # follower
"follow", "follow",
"unfollow", "unfollow",
"get_shout_followers", "get_shout_followers",
# reaction # reaction
"create_reaction", "create_reaction",
"update_reaction", "update_reaction",
@ -142,18 +132,15 @@ __all__ = [
"load_shout_ratings", "load_shout_ratings",
"load_comment_ratings", "load_comment_ratings",
"load_comments_branch", "load_comments_branch",
# notifier # notifier
"load_notifications", "load_notifications",
"notifications_seen_thread", "notifications_seen_thread",
"notifications_seen_after", "notifications_seen_after",
"notification_mark_seen", "notification_mark_seen",
# rating # rating
"rate_author", "rate_author",
"get_my_rates_comments", "get_my_rates_comments",
"get_my_rates_shouts", "get_my_rates_shouts",
# draft # draft
"load_drafts", "load_drafts",
"create_draft", "create_draft",

View File

@ -1,12 +1,13 @@
from math import ceil from math import ceil
from sqlalchemy import or_, cast, String
from graphql.error import GraphQLError from graphql.error import GraphQLError
from sqlalchemy import String, cast, or_
from auth.decorators import admin_auth_required from auth.decorators import admin_auth_required
from auth.orm import Author, AuthorRole, Role
from services.db import local_session from services.db import local_session
from services.schema import query, mutation
from auth.orm import Author, Role, AuthorRole
from services.env import EnvManager, EnvVariable from services.env import EnvManager, EnvVariable
from services.schema import mutation, query
from utils.logger import root_logger as logger from utils.logger import root_logger as logger
@ -64,11 +65,9 @@ async def admin_get_users(_, info, limit=10, offset=0, search=None):
"email": user.email, "email": user.email,
"name": user.name, "name": user.name,
"slug": user.slug, "slug": user.slug,
"roles": [role.id for role in user.roles] "roles": [role.id for role in user.roles] if hasattr(user, "roles") and user.roles else [],
if hasattr(user, "roles") and user.roles
else [],
"created_at": user.created_at, "created_at": user.created_at,
"last_seen": user.last_seen "last_seen": user.last_seen,
} }
for user in users for user in users
], ],
@ -81,6 +80,7 @@ async def admin_get_users(_, info, limit=10, offset=0, search=None):
return result return result
except Exception as e: except Exception as e:
import traceback import traceback
logger.error(f"Ошибка при получении списка пользователей: {str(e)}") logger.error(f"Ошибка при получении списка пользователей: {str(e)}")
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
raise GraphQLError(f"Не удалось получить список пользователей: {str(e)}") raise GraphQLError(f"Не удалось получить список пользователей: {str(e)}")
@ -126,20 +126,20 @@ async def admin_get_roles(_, info):
async def get_env_variables(_, info): async def get_env_variables(_, info):
""" """
Получает список переменных окружения, сгруппированных по секциям Получает список переменных окружения, сгруппированных по секциям
Args: Args:
info: Контекст GraphQL запроса info: Контекст GraphQL запроса
Returns: Returns:
Список секций с переменными окружения Список секций с переменными окружения
""" """
try: try:
# Создаем экземпляр менеджера переменных окружения # Создаем экземпляр менеджера переменных окружения
env_manager = EnvManager() env_manager = EnvManager()
# Получаем все переменные # Получаем все переменные
sections = env_manager.get_all_variables() sections = env_manager.get_all_variables()
# Преобразуем к формату GraphQL API # Преобразуем к формату GraphQL API
result = [ result = [
{ {
@ -154,11 +154,11 @@ async def get_env_variables(_, info):
"isSecret": var.is_secret, "isSecret": var.is_secret,
} }
for var in section.variables for var in section.variables
] ],
} }
for section in sections for section in sections
] ]
return result return result
except Exception as e: except Exception as e:
logger.error(f"Ошибка при получении переменных окружения: {str(e)}") logger.error(f"Ошибка при получении переменных окружения: {str(e)}")
@ -170,27 +170,27 @@ async def get_env_variables(_, info):
async def update_env_variable(_, info, key, value): async def update_env_variable(_, info, key, value):
""" """
Обновляет значение переменной окружения Обновляет значение переменной окружения
Args: Args:
info: Контекст GraphQL запроса info: Контекст GraphQL запроса
key: Ключ переменной key: Ключ переменной
value: Новое значение value: Новое значение
Returns: Returns:
Boolean: результат операции Boolean: результат операции
""" """
try: try:
# Создаем экземпляр менеджера переменных окружения # Создаем экземпляр менеджера переменных окружения
env_manager = EnvManager() env_manager = EnvManager()
# Обновляем переменную # Обновляем переменную
result = env_manager.update_variable(key, value) result = env_manager.update_variable(key, value)
if result: if result:
logger.info(f"Переменная окружения '{key}' успешно обновлена") logger.info(f"Переменная окружения '{key}' успешно обновлена")
else: else:
logger.error(f"Не удалось обновить переменную окружения '{key}'") logger.error(f"Не удалось обновить переменную окружения '{key}'")
return result return result
except Exception as e: except Exception as e:
logger.error(f"Ошибка при обновлении переменной окружения: {str(e)}") logger.error(f"Ошибка при обновлении переменной окружения: {str(e)}")
@ -202,36 +202,32 @@ async def update_env_variable(_, info, key, value):
async def update_env_variables(_, info, variables): async def update_env_variables(_, info, variables):
""" """
Массовое обновление переменных окружения Массовое обновление переменных окружения
Args: Args:
info: Контекст GraphQL запроса info: Контекст GraphQL запроса
variables: Список переменных для обновления variables: Список переменных для обновления
Returns: Returns:
Boolean: результат операции Boolean: результат операции
""" """
try: try:
# Создаем экземпляр менеджера переменных окружения # Создаем экземпляр менеджера переменных окружения
env_manager = EnvManager() env_manager = EnvManager()
# Преобразуем входные данные в формат для менеджера # Преобразуем входные данные в формат для менеджера
env_variables = [ env_variables = [
EnvVariable( EnvVariable(key=var.get("key", ""), value=var.get("value", ""), type=var.get("type", "string"))
key=var.get("key", ""),
value=var.get("value", ""),
type=var.get("type", "string")
)
for var in variables for var in variables
] ]
# Обновляем переменные # Обновляем переменные
result = env_manager.update_variables(env_variables) result = env_manager.update_variables(env_variables)
if result: if result:
logger.info(f"Переменные окружения успешно обновлены ({len(variables)} шт.)") logger.info(f"Переменные окружения успешно обновлены ({len(variables)} шт.)")
else: else:
logger.error(f"Не удалось обновить переменные окружения") logger.error(f"Не удалось обновить переменные окружения")
return result return result
except Exception as e: except Exception as e:
logger.error(f"Ошибка при массовом обновлении переменных окружения: {str(e)}") logger.error(f"Ошибка при массовом обновлении переменных окружения: {str(e)}")
@ -243,90 +239,78 @@ async def update_env_variables(_, info, variables):
async def admin_update_user(_, info, user): async def admin_update_user(_, info, user):
""" """
Обновляет роли пользователя Обновляет роли пользователя
Args: Args:
info: Контекст GraphQL запроса info: Контекст GraphQL запроса
user: Данные для обновления пользователя (содержит id и roles) user: Данные для обновления пользователя (содержит id и roles)
Returns: Returns:
Boolean: результат операции или объект с ошибкой Boolean: результат операции или объект с ошибкой
""" """
try: try:
user_id = user.get("id") user_id = user.get("id")
roles = user.get("roles", []) roles = user.get("roles", [])
if not roles: if not roles:
logger.warning(f"Пользователю {user_id} не назначено ни одной роли. Доступ в систему будет заблокирован.") logger.warning(f"Пользователю {user_id} не назначено ни одной роли. Доступ в систему будет заблокирован.")
with local_session() as session: with local_session() as session:
# Получаем пользователя из базы данных # Получаем пользователя из базы данных
author = session.query(Author).filter(Author.id == user_id).first() author = session.query(Author).filter(Author.id == user_id).first()
if not author: if not author:
error_msg = f"Пользователь с ID {user_id} не найден" error_msg = f"Пользователь с ID {user_id} не найден"
logger.error(error_msg) logger.error(error_msg)
return { return {"success": False, "error": error_msg}
"success": False,
"error": error_msg
}
# Получаем ID сообщества по умолчанию # Получаем ID сообщества по умолчанию
default_community_id = 1 # Используем значение по умолчанию из модели AuthorRole default_community_id = 1 # Используем значение по умолчанию из модели AuthorRole
try: try:
# Очищаем текущие роли пользователя через ORM # Очищаем текущие роли пользователя через ORM
session.query(AuthorRole).filter(AuthorRole.author == user_id).delete() session.query(AuthorRole).filter(AuthorRole.author == user_id).delete()
session.flush() session.flush()
# Получаем все существующие роли, которые указаны для обновления # Получаем все существующие роли, которые указаны для обновления
role_objects = session.query(Role).filter(Role.id.in_(roles)).all() role_objects = session.query(Role).filter(Role.id.in_(roles)).all()
# Проверяем, все ли запрошенные роли найдены # Проверяем, все ли запрошенные роли найдены
found_role_ids = [role.id for role in role_objects] found_role_ids = [role.id for role in role_objects]
missing_roles = set(roles) - set(found_role_ids) missing_roles = set(roles) - set(found_role_ids)
if missing_roles: if missing_roles:
warning_msg = f"Некоторые роли не найдены в базе: {', '.join(missing_roles)}" warning_msg = f"Некоторые роли не найдены в базе: {', '.join(missing_roles)}"
logger.warning(warning_msg) logger.warning(warning_msg)
# Создаем новые записи в таблице author_role с указанием community # Создаем новые записи в таблице author_role с указанием community
for role in role_objects: for role in role_objects:
# Используем ORM для создания новых записей # Используем ORM для создания новых записей
author_role = AuthorRole( author_role = AuthorRole(community=default_community_id, author=user_id, role=role.id)
community=default_community_id,
author=user_id,
role=role.id
)
session.add(author_role) session.add(author_role)
# Сохраняем изменения в базе данных # Сохраняем изменения в базе данных
session.commit() session.commit()
# Проверяем, добавлена ли пользователю роль reader # Проверяем, добавлена ли пользователю роль reader
has_reader = 'reader' in [role.id for role in role_objects] has_reader = "reader" in [role.id for role in role_objects]
if not has_reader: if not has_reader:
logger.warning(f"Пользователю {author.email or author.id} не назначена роль 'reader'. Доступ в систему будет ограничен.") logger.warning(
f"Пользователю {author.email or author.id} не назначена роль 'reader'. Доступ в систему будет ограничен."
)
logger.info(f"Роли пользователя {author.email or author.id} обновлены: {', '.join(found_role_ids)}") logger.info(f"Роли пользователя {author.email or author.id} обновлены: {', '.join(found_role_ids)}")
return { return {"success": True}
"success": True
}
except Exception as e: except Exception as e:
# Обработка вложенных исключений # Обработка вложенных исключений
session.rollback() session.rollback()
error_msg = f"Ошибка при изменении ролей: {str(e)}" error_msg = f"Ошибка при изменении ролей: {str(e)}"
logger.error(error_msg) logger.error(error_msg)
return { return {"success": False, "error": error_msg}
"success": False,
"error": error_msg
}
except Exception as e: except Exception as e:
import traceback import traceback
error_msg = f"Ошибка при обновлении ролей пользователя: {str(e)}" error_msg = f"Ошибка при обновлении ролей пользователя: {str(e)}"
logger.error(error_msg) logger.error(error_msg)
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return { return {"success": False, "error": error_msg}
"success": False,
"error": error_msg
}

View File

@ -1,46 +1,48 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import time import time
import traceback import traceback
from utils.logger import root_logger as logger
from graphql.type import GraphQLResolveInfo from graphql.type import GraphQLResolveInfo
# import asyncio # Убираем, так как резолвер будет синхронным
from services.auth import login_required
from auth.credentials import AuthCredentials from auth.credentials import AuthCredentials
from auth.email import send_auth_email from auth.email import send_auth_email
from auth.exceptions import InvalidToken, ObjectNotExist from auth.exceptions import InvalidToken, ObjectNotExist
from auth.identity import Identity, Password from auth.identity import Identity, Password
from auth.internal import verify_internal_auth
from auth.jwtcodec import JWTCodec from auth.jwtcodec import JWTCodec
from auth.tokenstorage import TokenStorage
from auth.orm import Author, Role from auth.orm import Author, Role
from auth.sessions import SessionManager
from auth.tokenstorage import TokenStorage
# import asyncio # Убираем, так как резолвер будет синхронным
from services.auth import login_required
from services.db import local_session from services.db import local_session
from services.schema import mutation, query from services.schema import mutation, query
from settings import ( from settings import (
ADMIN_EMAILS, ADMIN_EMAILS,
SESSION_TOKEN_HEADER,
SESSION_COOKIE_NAME,
SESSION_COOKIE_SECURE,
SESSION_COOKIE_SAMESITE,
SESSION_COOKIE_MAX_AGE,
SESSION_COOKIE_HTTPONLY, SESSION_COOKIE_HTTPONLY,
SESSION_COOKIE_MAX_AGE,
SESSION_COOKIE_NAME,
SESSION_COOKIE_SAMESITE,
SESSION_COOKIE_SECURE,
SESSION_TOKEN_HEADER,
) )
from utils.generate_slug import generate_unique_slug from utils.generate_slug import generate_unique_slug
from auth.sessions import SessionManager from utils.logger import root_logger as logger
from auth.internal import verify_internal_auth
@mutation.field("getSession") @mutation.field("getSession")
@login_required @login_required
async def get_current_user(_, info): async def get_current_user(_, info):
""" """
Получает информацию о текущем пользователе. Получает информацию о текущем пользователе.
Требует авторизации через декоратор login_required. Требует авторизации через декоратор login_required.
Args: Args:
_: Родительский объект (не используется) _: Родительский объект (не используется)
info: Контекст GraphQL запроса info: Контекст GraphQL запроса
Returns: Returns:
dict: Объект с токеном и данными автора с добавленной статистикой dict: Объект с токеном и данными автора с добавленной статистикой
""" """
@ -49,68 +51,73 @@ async def get_current_user(_, info):
if not author_id: if not author_id:
logger.error("[getSession] Пользователь не авторизован") logger.error("[getSession] Пользователь не авторизован")
from graphql.error import GraphQLError from graphql.error import GraphQLError
raise GraphQLError("Требуется авторизация") raise GraphQLError("Требуется авторизация")
# Получаем токен из заголовка # Получаем токен из заголовка
req = info.context.get("request") req = info.context.get("request")
token = req.headers.get(SESSION_TOKEN_HEADER) token = req.headers.get(SESSION_TOKEN_HEADER)
if token and token.startswith("Bearer "): if token and token.startswith("Bearer "):
token = token.split("Bearer ")[-1].strip() token = token.split("Bearer ")[-1].strip()
# Получаем данные автора # Получаем данные автора
author = info.context.get("author") author = info.context.get("author")
# Если автор не найден в контексте, пробуем получить из БД с добавлением статистики # Если автор не найден в контексте, пробуем получить из БД с добавлением статистики
if not author: if not author:
logger.debug(f"[getSession] Автор не найден в контексте для пользователя {user_id}, получаем из БД") logger.debug(f"[getSession] Автор не найден в контексте для пользователя {author_id}, получаем из БД")
try: try:
# Используем функцию get_with_stat для получения автора со статистикой # Используем функцию get_with_stat для получения автора со статистикой
from sqlalchemy import select from sqlalchemy import select
from resolvers.stat import get_with_stat from resolvers.stat import get_with_stat
q = select(Author).where(Author.id == user_id) q = select(Author).where(Author.id == author_id)
authors_with_stat = get_with_stat(q) authors_with_stat = get_with_stat(q)
if authors_with_stat and len(authors_with_stat) > 0: if authors_with_stat and len(authors_with_stat) > 0:
author = authors_with_stat[0] author = authors_with_stat[0]
# Обновляем last_seen отдельной транзакцией # Обновляем last_seen отдельной транзакцией
with local_session() as session: with local_session() as session:
author_db = session.query(Author).filter(Author.id == user_id).first() author_db = session.query(Author).filter(Author.id == author_id).first()
if author_db: if author_db:
author_db.last_seen = int(time.time()) author_db.last_seen = int(time.time())
session.commit() session.commit()
else: else:
logger.error(f"[getSession] Автор с ID {user_id} не найден в БД") logger.error(f"[getSession] Автор с ID {author_id} не найден в БД")
from graphql.error import GraphQLError from graphql.error import GraphQLError
raise GraphQLError("Пользователь не найден") raise GraphQLError("Пользователь не найден")
except Exception as e: except Exception as e:
logger.error(f"[getSession] Ошибка при получении автора из БД: {e}", exc_info=True) logger.error(f"[getSession] Ошибка при получении автора из БД: {e}", exc_info=True)
from graphql.error import GraphQLError from graphql.error import GraphQLError
raise GraphQLError("Ошибка при получении данных пользователя") raise GraphQLError("Ошибка при получении данных пользователя")
else: else:
# Если автор уже есть в контексте, добавляем статистику # Если автор уже есть в контексте, добавляем статистику
try: try:
from sqlalchemy import select from sqlalchemy import select
from resolvers.stat import get_with_stat from resolvers.stat import get_with_stat
q = select(Author).where(Author.id == user_id) q = select(Author).where(Author.id == author_id)
authors_with_stat = get_with_stat(q) authors_with_stat = get_with_stat(q)
if authors_with_stat and len(authors_with_stat) > 0: if authors_with_stat and len(authors_with_stat) > 0:
# Обновляем только статистику # Обновляем только статистику
author.stat = authors_with_stat[0].stat author.stat = authors_with_stat[0].stat
except Exception as e: except Exception as e:
logger.warning(f"[getSession] Не удалось добавить статистику к автору: {e}") logger.warning(f"[getSession] Не удалось добавить статистику к автору: {e}")
# Возвращаем данные сессии # Возвращаем данные сессии
logger.info(f"[getSession] Успешно получена сессия для пользователя {user_id}") logger.info(f"[getSession] Успешно получена сессия для пользователя {author_id}")
return {"token": token or '', "author": author} return {"token": token or "", "author": author}
@mutation.field("confirmEmail") @mutation.field("confirmEmail")
async def confirm_email(_, info, token): async def confirm_email(_, info, token):
"""confirm owning email address""" """confirm owning email address"""
try: try:
@ -118,26 +125,26 @@ async def confirm_email(_, info, token):
payload = JWTCodec.decode(token) payload = JWTCodec.decode(token)
user_id = payload.user_id user_id = payload.user_id
username = payload.username username = payload.username
# Если TokenStorage.get асинхронный, это нужно будет переделать или вызывать синхронно # Если TokenStorage.get асинхронный, это нужно будет переделать или вызывать синхронно
# Для теста пока оставим, но это потенциальная точка отказа в синхронном резолвере # Для теста пока оставим, но это потенциальная точка отказа в синхронном резолвере
token_key = f"{user_id}-{username}-{token}" token_key = f"{user_id}-{username}-{token}"
await TokenStorage.get(token_key) await TokenStorage.get(token_key)
with local_session() as session: with local_session() as session:
user = session.query(Author).where(Author.id == user_id).first() user = session.query(Author).where(Author.id == user_id).first()
if not user: if not user:
logger.warning(f"[auth] confirmEmail: Пользователь с ID {user_id} не найден.") logger.warning(f"[auth] confirmEmail: Пользователь с ID {user_id} не найден.")
return {"success": False, "token": None, "author": None, "error": "Пользователь не найден"} return {"success": False, "token": None, "author": None, "error": "Пользователь не найден"}
# Создаем сессионный токен с новым форматом вызова и явным временем истечения # Создаем сессионный токен с новым форматом вызова и явным временем истечения
device_info = {"email": user.email} if hasattr(user, "email") else None device_info = {"email": user.email} if hasattr(user, "email") else None
session_token = await TokenStorage.create_session( session_token = await TokenStorage.create_session(
user_id=str(user_id), user_id=str(user_id),
username=user.username or user.email or user.slug or username, username=user.username or user.email or user.slug or username,
device_info=device_info device_info=device_info,
) )
user.email_verified = True user.email_verified = True
user.last_seen = int(time.time()) user.last_seen = int(time.time())
session.add(user) session.add(user)
@ -155,7 +162,7 @@ async def confirm_email(_, info, token):
"token": None, "token": None,
"author": None, "author": None,
"error": f"Ошибка подтверждения email: {str(e)}", "error": f"Ошибка подтверждения email: {str(e)}",
} }
def create_user(user_dict): def create_user(user_dict):
@ -231,9 +238,7 @@ async def register_by_email(_, _info, email: str, password: str = "", name: str
try: try:
# Если auth_send_link асинхронный... # Если auth_send_link асинхронный...
await send_link(_, _info, email) await send_link(_, _info, email)
logger.info( logger.info(f"[auth] registerUser: Пользователь {email} зарегистрирован, ссылка для подтверждения отправлена.")
f"[auth] registerUser: Пользователь {email} зарегистрирован, ссылка для подтверждения отправлена."
)
# При регистрации возвращаем данные самому пользователю, поэтому не фильтруем # При регистрации возвращаем данные самому пользователю, поэтому не фильтруем
return { return {
"success": True, "success": True,
@ -306,7 +311,7 @@ async def login(_, info, email: str, password: str):
logger.info( logger.info(
f"[auth] login: Найден автор {email}, id={author.id}, имя={author.name}, пароль есть: {bool(author.password)}" f"[auth] login: Найден автор {email}, id={author.id}, имя={author.name}, пароль есть: {bool(author.password)}"
) )
# Проверяем наличие роли reader # Проверяем наличие роли reader
has_reader_role = False has_reader_role = False
if hasattr(author, "roles") and author.roles: if hasattr(author, "roles") and author.roles:
@ -314,12 +319,12 @@ async def login(_, info, email: str, password: str):
if role.id == "reader": if role.id == "reader":
has_reader_role = True has_reader_role = True
break break
# Если у пользователя нет роли reader и он не админ, запрещаем вход # Если у пользователя нет роли reader и он не админ, запрещаем вход
if not has_reader_role: if not has_reader_role:
# Проверяем, есть ли роль admin или super # Проверяем, есть ли роль admin или super
is_admin = author.email in ADMIN_EMAILS.split(",") is_admin = author.email in ADMIN_EMAILS.split(",")
if not is_admin: if not is_admin:
logger.warning(f"[auth] login: У пользователя {email} нет роли 'reader', в доступе отказано") logger.warning(f"[auth] login: У пользователя {email} нет роли 'reader', в доступе отказано")
return { return {
@ -365,9 +370,7 @@ async def login(_, info, email: str, password: str):
or not hasattr(valid_author, "username") or not hasattr(valid_author, "username")
and not hasattr(valid_author, "email") and not hasattr(valid_author, "email")
): ):
logger.error( logger.error(f"[auth] login: Объект автора не содержит необходимых атрибутов: {valid_author}")
f"[auth] login: Объект автора не содержит необходимых атрибутов: {valid_author}"
)
return { return {
"success": False, "success": False,
"token": None, "token": None,
@ -380,7 +383,7 @@ async def login(_, info, email: str, password: str):
token = await TokenStorage.create_session( token = await TokenStorage.create_session(
user_id=str(valid_author.id), user_id=str(valid_author.id),
username=valid_author.username or valid_author.email or valid_author.slug or "", username=valid_author.username or valid_author.email or valid_author.slug or "",
device_info={"email": valid_author.email} if hasattr(valid_author, "email") else None device_info={"email": valid_author.email} if hasattr(valid_author, "email") else None,
) )
logger.info(f"[auth] login: токен успешно создан, длина: {len(token) if token else 0}") logger.info(f"[auth] login: токен успешно создан, длина: {len(token) if token else 0}")
@ -390,7 +393,7 @@ async def login(_, info, email: str, password: str):
# Устанавливаем httponly cookie различными способами для надежности # Устанавливаем httponly cookie различными способами для надежности
cookie_set = False cookie_set = False
# Метод 1: GraphQL контекст через extensions # Метод 1: GraphQL контекст через extensions
try: try:
if hasattr(info.context, "extensions") and hasattr(info.context.extensions, "set_cookie"): if hasattr(info.context, "extensions") and hasattr(info.context.extensions, "set_cookie"):
@ -406,7 +409,7 @@ async def login(_, info, email: str, password: str):
cookie_set = True cookie_set = True
except Exception as e: except Exception as e:
logger.error(f"[auth] login: Ошибка при установке cookie через extensions: {str(e)}") logger.error(f"[auth] login: Ошибка при установке cookie через extensions: {str(e)}")
# Метод 2: GraphQL контекст через response # Метод 2: GraphQL контекст через response
if not cookie_set: if not cookie_set:
try: try:
@ -423,11 +426,12 @@ async def login(_, info, email: str, password: str):
cookie_set = True cookie_set = True
except Exception as e: except Exception as e:
logger.error(f"[auth] login: Ошибка при установке cookie через response: {str(e)}") logger.error(f"[auth] login: Ошибка при установке cookie через response: {str(e)}")
# Если ни один способ не сработал, создаем response в контексте # Если ни один способ не сработал, создаем response в контексте
if not cookie_set and hasattr(info.context, "request") and not hasattr(info.context, "response"): if not cookie_set and hasattr(info.context, "request") and not hasattr(info.context, "response"):
try: try:
from starlette.responses import JSONResponse from starlette.responses import JSONResponse
response = JSONResponse({}) response = JSONResponse({})
response.set_cookie( response.set_cookie(
key=SESSION_COOKIE_NAME, key=SESSION_COOKIE_NAME,
@ -442,12 +446,12 @@ async def login(_, info, email: str, password: str):
cookie_set = True cookie_set = True
except Exception as e: except Exception as e:
logger.error(f"[auth] login: Ошибка при создании response и установке cookie: {str(e)}") logger.error(f"[auth] login: Ошибка при создании response и установке cookie: {str(e)}")
if not cookie_set: if not cookie_set:
logger.warning(f"[auth] login: Не удалось установить cookie никаким способом") logger.warning(f"[auth] login: Не удалось установить cookie никаким способом")
# Возвращаем успешный результат с данными для клиента # Возвращаем успешный результат с данными для клиента
# Для ответа клиенту используем dict() с параметром access=True, # Для ответа клиенту используем dict() с параметром access=True,
# чтобы получить полный доступ к данным для самого пользователя # чтобы получить полный доступ к данным для самого пользователя
logger.info(f"[auth] login: Успешный вход для {email}") logger.info(f"[auth] login: Успешный вход для {email}")
author_dict = valid_author.dict(access=True) author_dict = valid_author.dict(access=True)
@ -485,7 +489,7 @@ async def is_email_used(_, _info, email):
async def logout_resolver(_, info: GraphQLResolveInfo): async def logout_resolver(_, info: GraphQLResolveInfo):
""" """
Выход из системы через GraphQL с удалением сессии и cookie. Выход из системы через GraphQL с удалением сессии и cookie.
Returns: Returns:
dict: Результат операции выхода dict: Результат операции выхода
""" """
@ -500,7 +504,7 @@ async def logout_resolver(_, info: GraphQLResolveInfo):
success = False success = False
message = "" message = ""
# Если токен найден, отзываем его # Если токен найден, отзываем его
if token: if token:
try: try:
@ -544,12 +548,12 @@ async def logout_resolver(_, info: GraphQLResolveInfo):
async def refresh_token_resolver(_, info: GraphQLResolveInfo): async def refresh_token_resolver(_, info: GraphQLResolveInfo):
""" """
Обновление токена аутентификации через GraphQL. Обновление токена аутентификации через GraphQL.
Returns: Returns:
AuthResult с данными пользователя и обновленным токеном или сообщением об ошибке AuthResult с данными пользователя и обновленным токеном или сообщением об ошибке
""" """
request = info.context["request"] request = info.context["request"]
# Получаем текущий токен из cookie или заголовка # Получаем текущий токен из cookie или заголовка
token = request.cookies.get(SESSION_COOKIE_NAME) token = request.cookies.get(SESSION_COOKIE_NAME)
if not token: if not token:
@ -617,12 +621,7 @@ async def refresh_token_resolver(_, info: GraphQLResolveInfo):
logger.debug(traceback.format_exc()) logger.debug(traceback.format_exc())
logger.info(f"[auth] refresh_token_resolver: Токен успешно обновлен для пользователя {user_id}") logger.info(f"[auth] refresh_token_resolver: Токен успешно обновлен для пользователя {user_id}")
return { return {"success": True, "token": new_token, "author": author, "error": None}
"success": True,
"token": new_token,
"author": author,
"error": None
}
except Exception as e: except Exception as e:
logger.error(f"[auth] refresh_token_resolver: Ошибка при обновлении токена: {e}") logger.error(f"[auth] refresh_token_resolver: Ошибка при обновлении токена: {e}")

View File

@ -1,9 +1,10 @@
import asyncio import asyncio
import time import time
from typing import Optional, List, Dict, Any from typing import Any, Dict, List, Optional
from sqlalchemy import select, text from sqlalchemy import select, text
from auth.orm import Author
from cache.cache import ( from cache.cache import (
cache_author, cache_author,
cached_query, cached_query,
@ -13,7 +14,6 @@ from cache.cache import (
get_cached_follower_topics, get_cached_follower_topics,
invalidate_cache_by_prefix, invalidate_cache_by_prefix,
) )
from auth.orm import Author
from resolvers.stat import get_with_stat from resolvers.stat import get_with_stat
from services.auth import login_required from services.auth import login_required
from services.db import local_session from services.db import local_session
@ -74,27 +74,26 @@ async def get_authors_with_stats(limit=50, offset=0, by: Optional[str] = None, c
# Функция для получения авторов из БД # Функция для получения авторов из БД
async def fetch_authors_with_stats(): async def fetch_authors_with_stats():
logger.debug( logger.debug(f"Выполняем запрос на получение авторов со статистикой: limit={limit}, offset={offset}, by={by}")
f"Выполняем запрос на получение авторов со статистикой: limit={limit}, offset={offset}, by={by}"
)
with local_session() as session: with local_session() as session:
# Базовый запрос для получения авторов # Базовый запрос для получения авторов
base_query = select(Author).where(Author.deleted_at.is_(None)) base_query = select(Author).where(Author.deleted_at.is_(None))
# Применяем сортировку # Применяем сортировку
# vars for statistics sorting # vars for statistics sorting
stats_sort_field = None stats_sort_field = None
stats_sort_direction = "desc" stats_sort_direction = "desc"
if by: if by:
if isinstance(by, dict): if isinstance(by, dict):
logger.debug(f"Processing dict-based sorting: {by}") logger.debug(f"Processing dict-based sorting: {by}")
# Обработка словаря параметров сортировки # Обработка словаря параметров сортировки
from sqlalchemy import asc, desc, func from sqlalchemy import asc, desc, func
from orm.shout import ShoutAuthor
from auth.orm import AuthorFollower from auth.orm import AuthorFollower
from orm.shout import ShoutAuthor
# Checking for order field in the dictionary # Checking for order field in the dictionary
if "order" in by: if "order" in by:
@ -135,50 +134,40 @@ async def get_authors_with_stats(limit=50, offset=0, by: Optional[str] = None, c
# If sorting by statistics, modify the query # If sorting by statistics, modify the query
if stats_sort_field == "shouts": if stats_sort_field == "shouts":
# Sorting by the number of shouts # Sorting by the number of shouts
from sqlalchemy import func, and_ from sqlalchemy import and_, func
from orm.shout import Shout, ShoutAuthor from orm.shout import Shout, ShoutAuthor
subquery = ( subquery = (
select( select(ShoutAuthor.author, func.count(func.distinct(Shout.id)).label("shouts_count"))
ShoutAuthor.author,
func.count(func.distinct(Shout.id)).label("shouts_count")
)
.select_from(ShoutAuthor) .select_from(ShoutAuthor)
.join(Shout, ShoutAuthor.shout == Shout.id) .join(Shout, ShoutAuthor.shout == Shout.id)
.where( .where(and_(Shout.deleted_at.is_(None), Shout.published_at.is_not(None)))
and_(
Shout.deleted_at.is_(None),
Shout.published_at.is_not(None)
)
)
.group_by(ShoutAuthor.author) .group_by(ShoutAuthor.author)
.subquery() .subquery()
) )
base_query = ( base_query = base_query.outerjoin(subquery, Author.id == subquery.c.author).order_by(
base_query desc(func.coalesce(subquery.c.shouts_count, 0))
.outerjoin(subquery, Author.id == subquery.c.author)
.order_by(desc(func.coalesce(subquery.c.shouts_count, 0)))
) )
elif stats_sort_field == "followers": elif stats_sort_field == "followers":
# Sorting by the number of followers # Sorting by the number of followers
from sqlalchemy import func from sqlalchemy import func
from auth.orm import AuthorFollower from auth.orm import AuthorFollower
subquery = ( subquery = (
select( select(
AuthorFollower.author, AuthorFollower.author,
func.count(func.distinct(AuthorFollower.follower)).label("followers_count") func.count(func.distinct(AuthorFollower.follower)).label("followers_count"),
) )
.select_from(AuthorFollower) .select_from(AuthorFollower)
.group_by(AuthorFollower.author) .group_by(AuthorFollower.author)
.subquery() .subquery()
) )
base_query = ( base_query = base_query.outerjoin(subquery, Author.id == subquery.c.author).order_by(
base_query desc(func.coalesce(subquery.c.followers_count, 0))
.outerjoin(subquery, Author.id == subquery.c.author)
.order_by(desc(func.coalesce(subquery.c.followers_count, 0)))
) )
# Применяем лимит и смещение # Применяем лимит и смещение
@ -219,7 +208,7 @@ async def get_authors_with_stats(limit=50, offset=0, by: Optional[str] = None, c
"shouts": shouts_stats.get(author.id, 0), "shouts": shouts_stats.get(author.id, 0),
"followers": followers_stats.get(author.id, 0), "followers": followers_stats.get(author.id, 0),
} }
result.append(author_dict) result.append(author_dict)
# Кешируем каждого автора отдельно для использования в других функциях # Кешируем каждого автора отдельно для использования в других функциях
@ -299,7 +288,7 @@ async def update_author(_, info, profile):
# Кэшируем полную версию для админов # Кэшируем полную версию для админов
author_dict = author_with_stat.dict(access=is_admin) author_dict = author_with_stat.dict(access=is_admin)
asyncio.create_task(cache_author(author_dict)) asyncio.create_task(cache_author(author_dict))
# Возвращаем обычную полную версию, т.к. это владелец # Возвращаем обычную полную версию, т.к. это владелец
return {"error": None, "author": author} return {"error": None, "author": author}
except Exception as exc: except Exception as exc:
@ -328,16 +317,16 @@ async def get_authors_all(_, info):
async def get_author(_, info, slug="", author_id=0): async def get_author(_, info, slug="", author_id=0):
# Получаем ID текущего пользователя и флаг админа из контекста # Получаем ID текущего пользователя и флаг админа из контекста
is_admin = info.context.get("is_admin", False) is_admin = info.context.get("is_admin", False)
author_dict = None author_dict = None
try: try:
author_id = get_author_id_from(slug=slug, user="", author_id=author_id) author_id = get_author_id_from(slug=slug, user="", author_id=author_id)
if not author_id: if not author_id:
raise ValueError("cant find") raise ValueError("cant find")
# Получаем данные автора из кэша (полные данные) # Получаем данные автора из кэша (полные данные)
cached_author = await get_cached_author(int(author_id), get_with_stat) cached_author = await get_cached_author(int(author_id), get_with_stat)
# Применяем фильтрацию на стороне клиента, так как в кэше хранится полная версия # Применяем фильтрацию на стороне клиента, так как в кэше хранится полная версия
if cached_author: if cached_author:
# Создаем объект автора для использования метода dict # Создаем объект автора для использования метода dict
@ -361,7 +350,7 @@ async def get_author(_, info, slug="", author_id=0):
# Кэшируем полные данные для админов # Кэшируем полные данные для админов
original_dict = author_with_stat.dict(access=True) original_dict = author_with_stat.dict(access=True)
asyncio.create_task(cache_author(original_dict)) asyncio.create_task(cache_author(original_dict))
# Возвращаем отфильтрованную версию # Возвращаем отфильтрованную версию
author_dict = author_with_stat.dict(access=is_admin) author_dict = author_with_stat.dict(access=is_admin)
# Добавляем статистику # Добавляем статистику
@ -393,11 +382,12 @@ async def load_authors_by(_, info, by, limit, offset):
# Получаем ID текущего пользователя и флаг админа из контекста # Получаем ID текущего пользователя и флаг админа из контекста
viewer_id = info.context.get("author", {}).get("id") viewer_id = info.context.get("author", {}).get("id")
is_admin = info.context.get("is_admin", False) is_admin = info.context.get("is_admin", False)
# Используем оптимизированную функцию для получения авторов # Используем оптимизированную функцию для получения авторов
return await get_authors_with_stats(limit, offset, by, viewer_id, is_admin) return await get_authors_with_stats(limit, offset, by, viewer_id, is_admin)
except Exception as exc: except Exception as exc:
import traceback import traceback
logger.error(f"{exc}:\n{traceback.format_exc()}") logger.error(f"{exc}:\n{traceback.format_exc()}")
return [] return []
@ -413,7 +403,7 @@ async def load_authors_search(_, info, text: str, limit: int = 10, offset: int =
Returns: Returns:
list: List of authors matching the search criteria list: List of authors matching the search criteria
""" """
# Get author IDs from search engine (already sorted by relevance) # Get author IDs from search engine (already sorted by relevance)
search_results = await search_service.search_authors(text, limit, offset) search_results = await search_service.search_authors(text, limit, offset)
@ -429,13 +419,13 @@ async def load_authors_search(_, info, text: str, limit: int = 10, offset: int =
# Simple query to get authors by IDs - no need for stats here # Simple query to get authors by IDs - no need for stats here
authors_query = select(Author).filter(Author.id.in_(author_ids)) authors_query = select(Author).filter(Author.id.in_(author_ids))
db_authors = session.execute(authors_query).scalars().all() db_authors = session.execute(authors_query).scalars().all()
if not db_authors: if not db_authors:
return [] return []
# Create a dictionary for quick lookup # Create a dictionary for quick lookup
authors_dict = {str(author.id): author for author in db_authors} authors_dict = {str(author.id): author for author in db_authors}
# Keep the order from search results (maintains the relevance sorting) # Keep the order from search results (maintains the relevance sorting)
ordered_authors = [authors_dict[author_id] for author_id in author_ids if author_id in authors_dict] ordered_authors = [authors_dict[author_id] for author_id in author_ids if author_id in authors_dict]
@ -468,7 +458,7 @@ async def get_author_follows(_, info, slug="", user=None, author_id=0):
# Получаем ID текущего пользователя и флаг админа из контекста # Получаем ID текущего пользователя и флаг админа из контекста
viewer_id = info.context.get("author", {}).get("id") viewer_id = info.context.get("author", {}).get("id")
is_admin = info.context.get("is_admin", False) is_admin = info.context.get("is_admin", False)
logger.debug(f"getting follows for @{slug}") logger.debug(f"getting follows for @{slug}")
author_id = get_author_id_from(slug=slug, user=user, author_id=author_id) author_id = get_author_id_from(slug=slug, user=user, author_id=author_id)
if not author_id: if not author_id:
@ -477,7 +467,7 @@ async def get_author_follows(_, info, slug="", user=None, author_id=0):
# Получаем данные из кэша # Получаем данные из кэша
followed_authors_raw = await get_cached_follower_authors(author_id) followed_authors_raw = await get_cached_follower_authors(author_id)
followed_topics = await get_cached_follower_topics(author_id) followed_topics = await get_cached_follower_topics(author_id)
# Фильтруем чувствительные данные авторов # Фильтруем чувствительные данные авторов
followed_authors = [] followed_authors = []
for author_data in followed_authors_raw: for author_data in followed_authors_raw:
@ -517,15 +507,14 @@ async def get_author_follows_authors(_, info, slug="", user=None, author_id=None
# Получаем ID текущего пользователя и флаг админа из контекста # Получаем ID текущего пользователя и флаг админа из контекста
viewer_id = info.context.get("author", {}).get("id") viewer_id = info.context.get("author", {}).get("id")
is_admin = info.context.get("is_admin", False) is_admin = info.context.get("is_admin", False)
logger.debug(f"getting followed authors for @{slug}") logger.debug(f"getting followed authors for @{slug}")
if not author_id: if not author_id:
return [] return []
# Получаем данные из кэша # Получаем данные из кэша
followed_authors_raw = await get_cached_follower_authors(author_id) followed_authors_raw = await get_cached_follower_authors(author_id)
# Фильтруем чувствительные данные авторов # Фильтруем чувствительные данные авторов
followed_authors = [] followed_authors = []
for author_data in followed_authors_raw: for author_data in followed_authors_raw:
@ -540,7 +529,7 @@ async def get_author_follows_authors(_, info, slug="", user=None, author_id=None
# is_admin - булево значение, является ли текущий пользователь админом # is_admin - булево значение, является ли текущий пользователь админом
has_access = is_admin or (viewer_id is not None and str(viewer_id) == str(temp_author.id)) has_access = is_admin or (viewer_id is not None and str(viewer_id) == str(temp_author.id))
followed_authors.append(temp_author.dict(access=has_access)) followed_authors.append(temp_author.dict(access=has_access))
return followed_authors return followed_authors
@ -562,15 +551,15 @@ async def get_author_followers(_, info, slug: str = "", user: str = "", author_i
# Получаем ID текущего пользователя и флаг админа из контекста # Получаем ID текущего пользователя и флаг админа из контекста
viewer_id = info.context.get("author", {}).get("id") viewer_id = info.context.get("author", {}).get("id")
is_admin = info.context.get("is_admin", False) is_admin = info.context.get("is_admin", False)
logger.debug(f"getting followers for author @{slug} or ID:{author_id}") logger.debug(f"getting followers for author @{slug} or ID:{author_id}")
author_id = get_author_id_from(slug=slug, user=user, author_id=author_id) author_id = get_author_id_from(slug=slug, user=user, author_id=author_id)
if not author_id: if not author_id:
return [] return []
# Получаем данные из кэша # Получаем данные из кэша
followers_raw = await get_cached_author_followers(author_id) followers_raw = await get_cached_author_followers(author_id)
# Фильтруем чувствительные данные авторов # Фильтруем чувствительные данные авторов
followers = [] followers = []
for follower_data in followers_raw: for follower_data in followers_raw:
@ -585,5 +574,5 @@ async def get_author_followers(_, info, slug: str = "", user: str = "", author_i
# is_admin - булево значение, является ли текущий пользователь админом # is_admin - булево значение, является ли текущий пользователь админом
has_access = is_admin or (viewer_id is not None and str(viewer_id) == str(temp_author.id)) has_access = is_admin or (viewer_id is not None and str(viewer_id) == str(temp_author.id))
followers.append(temp_author.dict(access=has_access)) followers.append(temp_author.dict(access=has_access))
return followers return followers

View File

@ -72,9 +72,7 @@ def toggle_bookmark_shout(_, info, slug: str) -> CommonResult:
if existing_bookmark: if existing_bookmark:
db.execute( db.execute(
delete(AuthorBookmark).where( delete(AuthorBookmark).where(AuthorBookmark.author == author_id, AuthorBookmark.shout == shout.id)
AuthorBookmark.author == author_id, AuthorBookmark.shout == shout.id
)
) )
result = False result = False
else: else:

View File

@ -74,9 +74,9 @@ async def update_community(_, info, community_data):
if slug: if slug:
with local_session() as session: with local_session() as session:
try: try:
session.query(Community).where( session.query(Community).where(Community.created_by == author_id, Community.slug == slug).update(
Community.created_by == author_id, Community.slug == slug community_data
).update(community_data) )
session.commit() session.commit()
except Exception as e: except Exception as e:
return {"ok": False, "error": str(e)} return {"ok": False, "error": str(e)}
@ -90,9 +90,7 @@ async def delete_community(_, info, slug: str):
author_id = author_dict.get("id") author_id = author_dict.get("id")
with local_session() as session: with local_session() as session:
try: try:
session.query(Community).where( session.query(Community).where(Community.slug == slug, Community.created_by == author_id).delete()
Community.slug == slug, Community.created_by == author_id
).delete()
session.commit() session.commit()
return {"ok": True} return {"ok": True}
except Exception as e: except Exception as e:

View File

@ -1,11 +1,12 @@
import time import time
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from auth.orm import Author
from cache.cache import ( from cache.cache import (
invalidate_shout_related_cache, invalidate_shout_related_cache,
invalidate_shouts_cache, invalidate_shouts_cache,
) )
from auth.orm import Author
from orm.draft import Draft, DraftAuthor, DraftTopic from orm.draft import Draft, DraftAuthor, DraftTopic
from orm.shout import Shout, ShoutAuthor, ShoutTopic from orm.shout import Shout, ShoutAuthor, ShoutTopic
from services.auth import login_required from services.auth import login_required
@ -449,15 +450,15 @@ async def publish_draft(_, info, draft_id: int):
# Добавляем темы # Добавляем темы
for topic in draft.topics or []: for topic in draft.topics or []:
st = ShoutTopic( st = ShoutTopic(topic=topic.id, shout=shout.id, main=topic.main if hasattr(topic, "main") else False)
topic=topic.id, shout=shout.id, main=topic.main if hasattr(topic, "main") else False
)
session.add(st) session.add(st)
session.commit() session.commit()
# Инвалидируем кеш # Инвалидируем кеш
cache_keys = [f"shouts:{shout.id}", ] cache_keys = [
f"shouts:{shout.id}",
]
await invalidate_shouts_cache(cache_keys) await invalidate_shouts_cache(cache_keys)
await invalidate_shout_related_cache(shout, author_id) await invalidate_shout_related_cache(shout, author_id)
@ -482,67 +483,59 @@ async def publish_draft(_, info, draft_id: int):
async def unpublish_draft(_, info, draft_id: int): async def unpublish_draft(_, info, draft_id: int):
""" """
Снимает с публикации черновик, обновляя связанный Shout. Снимает с публикации черновик, обновляя связанный Shout.
Args: Args:
draft_id (int): ID черновика, публикацию которого нужно снять draft_id (int): ID черновика, публикацию которого нужно снять
Returns: Returns:
dict: Результат операции с информацией о черновике или сообщением об ошибке dict: Результат операции с информацией о черновике или сообщением об ошибке
""" """
author_dict = info.context.get("author", {}) author_dict = info.context.get("author", {})
author_id = author_dict.get("id") author_id = author_dict.get("id")
if author_id: if author_id:
return {"error": "Author ID is required"} return {"error": "Author ID is required"}
try: try:
with local_session() as session: with local_session() as session:
# Загружаем черновик со связанной публикацией # Загружаем черновик со связанной публикацией
draft = ( draft = (
session.query(Draft) session.query(Draft)
.options( .options(joinedload(Draft.publication), joinedload(Draft.authors), joinedload(Draft.topics))
joinedload(Draft.publication),
joinedload(Draft.authors),
joinedload(Draft.topics)
)
.filter(Draft.id == draft_id) .filter(Draft.id == draft_id)
.first() .first()
) )
if not draft: if not draft:
return {"error": "Draft not found"} return {"error": "Draft not found"}
# Проверяем, есть ли публикация # Проверяем, есть ли публикация
if not draft.publication: if not draft.publication:
return {"error": "This draft is not published yet"} return {"error": "This draft is not published yet"}
shout = draft.publication shout = draft.publication
# Снимаем с публикации # Снимаем с публикации
shout.published_at = None shout.published_at = None
shout.updated_at = int(time.time()) shout.updated_at = int(time.time())
shout.updated_by = author_id shout.updated_by = author_id
session.commit() session.commit()
# Инвалидируем кэш # Инвалидируем кэш
cache_keys = [f"shouts:{shout.id}"] cache_keys = [f"shouts:{shout.id}"]
await invalidate_shouts_cache(cache_keys) await invalidate_shouts_cache(cache_keys)
await invalidate_shout_related_cache(shout, author_id) await invalidate_shout_related_cache(shout, author_id)
# Формируем результат # Формируем результат
draft_dict = draft.dict() draft_dict = draft.dict()
# Добавляем информацию о публикации # Добавляем информацию о публикации
draft_dict["publication"] = { draft_dict["publication"] = {"id": shout.id, "slug": shout.slug, "published_at": None}
"id": shout.id,
"slug": shout.slug,
"published_at": None
}
logger.info(f"Successfully unpublished shout #{shout.id} for draft #{draft_id}") logger.info(f"Successfully unpublished shout #{shout.id} for draft #{draft_id}")
return {"draft": draft_dict} return {"draft": draft_dict}
except Exception as e: except Exception as e:
logger.error(f"Failed to unpublish draft {draft_id}: {e}", exc_info=True) logger.error(f"Failed to unpublish draft {draft_id}: {e}", exc_info=True)
return {"error": f"Failed to unpublish draft: {str(e)}"} return {"error": f"Failed to unpublish draft: {str(e)}"}

View File

@ -5,13 +5,13 @@ from sqlalchemy import and_, desc, select
from sqlalchemy.orm import joinedload, selectinload from sqlalchemy.orm import joinedload, selectinload
from sqlalchemy.sql.functions import coalesce from sqlalchemy.sql.functions import coalesce
from auth.orm import Author
from cache.cache import ( from cache.cache import (
cache_author, cache_author,
cache_topic, cache_topic,
invalidate_shout_related_cache, invalidate_shout_related_cache,
invalidate_shouts_cache, invalidate_shouts_cache,
) )
from auth.orm import Author
from orm.draft import Draft from orm.draft import Draft
from orm.shout import Shout, ShoutAuthor, ShoutTopic from orm.shout import Shout, ShoutAuthor, ShoutTopic
from orm.topic import Topic from orm.topic import Topic
@ -179,9 +179,7 @@ async def create_shout(_, info, inp):
lead = inp.get("lead", "") lead = inp.get("lead", "")
body_text = extract_text(body) body_text = extract_text(body)
lead_text = extract_text(lead) lead_text = extract_text(lead)
seo = inp.get( seo = inp.get("seo", lead_text.strip() or body_text.strip()[:300].split(". ")[:-1].join(". "))
"seo", lead_text.strip() or body_text.strip()[:300].split(". ")[:-1].join(". ")
)
new_shout = Shout( new_shout = Shout(
slug=slug, slug=slug,
body=body, body=body,
@ -278,9 +276,7 @@ def patch_main_topic(session, main_topic_slug, shout):
with session.begin(): with session.begin():
# Получаем текущий главный топик # Получаем текущий главный топик
old_main = ( old_main = (
session.query(ShoutTopic) session.query(ShoutTopic).filter(and_(ShoutTopic.shout == shout.id, ShoutTopic.main.is_(True))).first()
.filter(and_(ShoutTopic.shout == shout.id, ShoutTopic.main.is_(True)))
.first()
) )
if old_main: if old_main:
logger.info(f"Found current main topic: {old_main.topic.slug}") logger.info(f"Found current main topic: {old_main.topic.slug}")
@ -314,9 +310,7 @@ def patch_main_topic(session, main_topic_slug, shout):
session.flush() session.flush()
logger.info(f"Main topic updated for shout#{shout.id}") logger.info(f"Main topic updated for shout#{shout.id}")
else: else:
logger.warning( logger.warning(f"No changes needed for main topic (old={old_main is not None}, new={new_main is not None})")
f"No changes needed for main topic (old={old_main is not None}, new={new_main is not None})"
)
def patch_topics(session, shout, topics_input): def patch_topics(session, shout, topics_input):
@ -410,9 +404,7 @@ async def update_shout(_, info, shout_id: int, shout_input=None, publish=False):
logger.info(f"Processing update for shout#{shout_id} by author #{author_id}") logger.info(f"Processing update for shout#{shout_id} by author #{author_id}")
shout_by_id = ( shout_by_id = (
session.query(Shout) session.query(Shout)
.options( .options(joinedload(Shout.topics).joinedload(ShoutTopic.topic), joinedload(Shout.authors))
joinedload(Shout.topics).joinedload(ShoutTopic.topic), joinedload(Shout.authors)
)
.filter(Shout.id == shout_id) .filter(Shout.id == shout_id)
.first() .first()
) )
@ -441,10 +433,7 @@ async def update_shout(_, info, shout_id: int, shout_input=None, publish=False):
shout_input["slug"] = slug shout_input["slug"] = slug
logger.info(f"shout#{shout_id} slug patched") logger.info(f"shout#{shout_id} slug patched")
if ( if filter(lambda x: x.id == author_id, [x for x in shout_by_id.authors]) or "editor" in roles:
filter(lambda x: x.id == author_id, [x for x in shout_by_id.authors])
or "editor" in roles
):
logger.info(f"Author #{author_id} has permission to edit shout#{shout_id}") logger.info(f"Author #{author_id} has permission to edit shout#{shout_id}")
# topics patch # topics patch
@ -558,9 +547,7 @@ async def update_shout(_, info, shout_id: int, shout_input=None, publish=False):
# Получаем полные данные шаута со связями # Получаем полные данные шаута со связями
shout_with_relations = ( shout_with_relations = (
session.query(Shout) session.query(Shout)
.options( .options(joinedload(Shout.topics).joinedload(ShoutTopic.topic), joinedload(Shout.authors))
joinedload(Shout.topics).joinedload(ShoutTopic.topic), joinedload(Shout.authors)
)
.filter(Shout.id == shout_id) .filter(Shout.id == shout_id)
.first() .first()
) )

View File

@ -71,9 +71,7 @@ def shouts_by_follower(info, follower_id: int, options):
q = query_with_stat(info) q = query_with_stat(info)
reader_followed_authors = select(AuthorFollower.author).where(AuthorFollower.follower == follower_id) reader_followed_authors = select(AuthorFollower.author).where(AuthorFollower.follower == follower_id)
reader_followed_topics = select(TopicFollower.topic).where(TopicFollower.follower == follower_id) reader_followed_topics = select(TopicFollower.topic).where(TopicFollower.follower == follower_id)
reader_followed_shouts = select(ShoutReactionsFollower.shout).where( reader_followed_shouts = select(ShoutReactionsFollower.shout).where(ShoutReactionsFollower.follower == follower_id)
ShoutReactionsFollower.follower == follower_id
)
followed_subquery = ( followed_subquery = (
select(Shout.id) select(Shout.id)
.join(ShoutAuthor, ShoutAuthor.shout == Shout.id) .join(ShoutAuthor, ShoutAuthor.shout == Shout.id)
@ -142,9 +140,7 @@ async def load_shouts_authored_by(_, info, slug: str, options) -> List[Shout]:
q = ( q = (
query_with_stat(info) query_with_stat(info)
if has_field(info, "stat") if has_field(info, "stat")
else select(Shout).filter( else select(Shout).filter(and_(Shout.published_at.is_not(None), Shout.deleted_at.is_(None)))
and_(Shout.published_at.is_not(None), Shout.deleted_at.is_(None))
)
) )
q = q.filter(Shout.authors.any(id=author_id)) q = q.filter(Shout.authors.any(id=author_id))
q, limit, offset = apply_options(q, options, author_id) q, limit, offset = apply_options(q, options, author_id)
@ -173,9 +169,7 @@ async def load_shouts_with_topic(_, info, slug: str, options) -> List[Shout]:
q = ( q = (
query_with_stat(info) query_with_stat(info)
if has_field(info, "stat") if has_field(info, "stat")
else select(Shout).filter( else select(Shout).filter(and_(Shout.published_at.is_not(None), Shout.deleted_at.is_(None)))
and_(Shout.published_at.is_not(None), Shout.deleted_at.is_(None))
)
) )
q = q.filter(Shout.topics.any(id=topic_id)) q = q.filter(Shout.topics.any(id=topic_id))
q, limit, offset = apply_options(q, options) q, limit, offset = apply_options(q, options)

View File

@ -4,13 +4,13 @@ from graphql import GraphQLError
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.sql import and_ from sqlalchemy.sql import and_
from auth.orm import Author, AuthorFollower
from cache.cache import ( from cache.cache import (
cache_author, cache_author,
cache_topic, cache_topic,
get_cached_follower_authors, get_cached_follower_authors,
get_cached_follower_topics, get_cached_follower_topics,
) )
from auth.orm import Author, AuthorFollower
from orm.community import Community, CommunityFollower from orm.community import Community, CommunityFollower
from orm.reaction import Reaction from orm.reaction import Reaction
from orm.shout import Shout, ShoutReactionsFollower from orm.shout import Shout, ShoutReactionsFollower
@ -65,14 +65,14 @@ async def follow(_, info, what, slug="", entity_id=0):
return {"error": f"{what.lower()} not found"} return {"error": f"{what.lower()} not found"}
if not entity_id and entity: if not entity_id and entity:
entity_id = entity.id entity_id = entity.id
# Если это автор, учитываем фильтрацию данных # Если это автор, учитываем фильтрацию данных
if what == "AUTHOR": if what == "AUTHOR":
# Полная версия для кэширования # Полная версия для кэширования
entity_dict = entity.dict(is_admin=True) entity_dict = entity.dict(is_admin=True)
else: else:
entity_dict = entity.dict() entity_dict = entity.dict()
logger.debug(f"entity_id: {entity_id}, entity_dict: {entity_dict}") logger.debug(f"entity_id: {entity_id}, entity_dict: {entity_dict}")
if entity_id: if entity_id:
@ -87,9 +87,7 @@ async def follow(_, info, what, slug="", entity_id=0):
.first() .first()
) )
if existing_sub: if existing_sub:
logger.info( logger.info(f"Пользователь {follower_id} уже подписан на {what.lower()} с ID {entity_id}")
f"Пользователь {follower_id} уже подписан на {what.lower()} с ID {entity_id}"
)
else: else:
logger.debug("Добавление новой записи в базу данных") logger.debug("Добавление новой записи в базу данных")
sub = follower_class(follower=follower_id, **{entity_type: entity_id}) sub = follower_class(follower=follower_id, **{entity_type: entity_id})
@ -105,12 +103,12 @@ async def follow(_, info, what, slug="", entity_id=0):
if get_cached_follows_method: if get_cached_follows_method:
logger.debug("Получение подписок из кэша") logger.debug("Получение подписок из кэша")
existing_follows = await get_cached_follows_method(follower_id) existing_follows = await get_cached_follows_method(follower_id)
# Если это авторы, получаем безопасную версию # Если это авторы, получаем безопасную версию
if what == "AUTHOR": if what == "AUTHOR":
# Получаем ID текущего пользователя и фильтруем данные # Получаем ID текущего пользователя и фильтруем данные
follows_filtered = [] follows_filtered = []
for author_data in existing_follows: for author_data in existing_follows:
# Создаем объект автора для использования метода dict # Создаем объект автора для использования метода dict
temp_author = Author() temp_author = Author()
@ -119,7 +117,7 @@ async def follow(_, info, what, slug="", entity_id=0):
setattr(temp_author, key, value) setattr(temp_author, key, value)
# Добавляем отфильтрованную версию # Добавляем отфильтрованную версию
follows_filtered.append(temp_author.dict(viewer_id, False)) follows_filtered.append(temp_author.dict(viewer_id, False))
if not existing_sub: if not existing_sub:
# Создаем объект автора для entity_dict # Создаем объект автора для entity_dict
temp_author = Author() temp_author = Author()
@ -132,7 +130,7 @@ async def follow(_, info, what, slug="", entity_id=0):
follows = follows_filtered follows = follows_filtered
else: else:
follows = [*existing_follows, entity_dict] if not existing_sub else existing_follows follows = [*existing_follows, entity_dict] if not existing_sub else existing_follows
logger.debug("Обновлен список подписок") logger.debug("Обновлен список подписок")
if what == "AUTHOR" and not existing_sub: if what == "AUTHOR" and not existing_sub:
@ -214,20 +212,20 @@ async def unfollow(_, info, what, slug="", entity_id=0):
await cache_method(entity.dict(is_admin=True)) await cache_method(entity.dict(is_admin=True))
else: else:
await cache_method(entity.dict()) await cache_method(entity.dict())
if get_cached_follows_method: if get_cached_follows_method:
logger.debug("Получение подписок из кэша") logger.debug("Получение подписок из кэша")
existing_follows = await get_cached_follows_method(follower_id) existing_follows = await get_cached_follows_method(follower_id)
# Если это авторы, получаем безопасную версию # Если это авторы, получаем безопасную версию
if what == "AUTHOR": if what == "AUTHOR":
# Получаем ID текущего пользователя и фильтруем данные # Получаем ID текущего пользователя и фильтруем данные
follows_filtered = [] follows_filtered = []
for author_data in existing_follows: for author_data in existing_follows:
if author_data["id"] == entity_id: if author_data["id"] == entity_id:
continue continue
# Создаем объект автора для использования метода dict # Создаем объект автора для использования метода dict
temp_author = Author() temp_author = Author()
for key, value in author_data.items(): for key, value in author_data.items():
@ -235,11 +233,11 @@ async def unfollow(_, info, what, slug="", entity_id=0):
setattr(temp_author, key, value) setattr(temp_author, key, value)
# Добавляем отфильтрованную версию # Добавляем отфильтрованную версию
follows_filtered.append(temp_author.dict(viewer_id, False)) follows_filtered.append(temp_author.dict(viewer_id, False))
follows = follows_filtered follows = follows_filtered
else: else:
follows = [item for item in existing_follows if item["id"] != entity_id] follows = [item for item in existing_follows if item["id"] != entity_id]
logger.debug("Обновлен список подписок") logger.debug("Обновлен список подписок")
if what == "AUTHOR": if what == "AUTHOR":

View File

@ -66,9 +66,7 @@ def query_notifications(author_id: int, after: int = 0) -> Tuple[int, int, List[
return total, unread, notifications return total, unread, notifications
def group_notification( def group_notification(thread, authors=None, shout=None, reactions=None, entity="follower", action="follow"):
thread, authors=None, shout=None, reactions=None, entity="follower", action="follow"
):
reactions = reactions or [] reactions = reactions or []
authors = authors or [] authors = authors or []
return { return {

View File

@ -14,11 +14,7 @@ def handle_proposing(kind: ReactionKind, reply_to: int, shout_id: int):
session.query(Reaction).filter(Reaction.id == reply_to, Reaction.shout == shout_id).first() session.query(Reaction).filter(Reaction.id == reply_to, Reaction.shout == shout_id).first()
) )
if ( if replied_reaction and replied_reaction.kind is ReactionKind.PROPOSE.value and replied_reaction.quote:
replied_reaction
and replied_reaction.kind is ReactionKind.PROPOSE.value
and replied_reaction.quote
):
# patch all the proposals' quotes # patch all the proposals' quotes
proposals = ( proposals = (
session.query(Reaction) session.query(Reaction)

View File

@ -186,9 +186,7 @@ def count_author_shouts_rating(session, author_id) -> int:
def get_author_rating_old(session, author: Author): def get_author_rating_old(session, author: Author):
likes_count = ( likes_count = (
session.query(AuthorRating) session.query(AuthorRating).filter(and_(AuthorRating.author == author.id, AuthorRating.plus.is_(True))).count()
.filter(and_(AuthorRating.author == author.id, AuthorRating.plus.is_(True)))
.count()
) )
dislikes_count = ( dislikes_count = (
session.query(AuthorRating) session.query(AuthorRating)

View File

@ -334,9 +334,7 @@ async def create_reaction(_, info, reaction):
with local_session() as session: with local_session() as session:
authors = session.query(ShoutAuthor.author).filter(ShoutAuthor.shout == shout_id).scalar() authors = session.query(ShoutAuthor.author).filter(ShoutAuthor.shout == shout_id).scalar()
is_author = ( is_author = (
bool(list(filter(lambda x: x == int(author_id), authors))) bool(list(filter(lambda x: x == int(author_id), authors))) if isinstance(authors, list) else False
if isinstance(authors, list)
else False
) )
reaction_input["created_by"] = author_id reaction_input["created_by"] = author_id
kind = reaction_input.get("kind") kind = reaction_input.get("kind")

View File

@ -138,9 +138,7 @@ def query_with_stat(info):
select( select(
ShoutTopic.shout, ShoutTopic.shout,
json_array_builder( json_array_builder(
json_builder( json_builder("id", Topic.id, "title", Topic.title, "slug", Topic.slug, "is_main", ShoutTopic.main)
"id", Topic.id, "title", Topic.title, "slug", Topic.slug, "is_main", ShoutTopic.main
)
).label("topics"), ).label("topics"),
) )
.outerjoin(Topic, ShoutTopic.topic == Topic.id) .outerjoin(Topic, ShoutTopic.topic == Topic.id)
@ -227,7 +225,7 @@ def get_shouts_with_links(info, q, limit=20, offset=0):
"slug": a.slug, "slug": a.slug,
"pic": a.pic, "pic": a.pic,
} }
# Обработка поля updated_by # Обработка поля updated_by
if has_field(info, "updated_by"): if has_field(info, "updated_by"):
if shout_dict.get("updated_by"): if shout_dict.get("updated_by"):
@ -246,7 +244,7 @@ def get_shouts_with_links(info, q, limit=20, offset=0):
else: else:
# Если updated_by не указан, устанавливаем поле в null # Если updated_by не указан, устанавливаем поле в null
shout_dict["updated_by"] = None shout_dict["updated_by"] = None
# Обработка поля deleted_by # Обработка поля deleted_by
if has_field(info, "deleted_by"): if has_field(info, "deleted_by"):
if shout_dict.get("deleted_by"): if shout_dict.get("deleted_by"):
@ -287,9 +285,7 @@ def get_shouts_with_links(info, q, limit=20, offset=0):
if hasattr(row, "main_topic"): if hasattr(row, "main_topic"):
# logger.debug(f"Raw main_topic for shout#{shout_id}: {row.main_topic}") # logger.debug(f"Raw main_topic for shout#{shout_id}: {row.main_topic}")
main_topic = ( main_topic = (
orjson.loads(row.main_topic) orjson.loads(row.main_topic) if isinstance(row.main_topic, str) else row.main_topic
if isinstance(row.main_topic, str)
else row.main_topic
) )
# logger.debug(f"Parsed main_topic for shout#{shout_id}: {main_topic}") # logger.debug(f"Parsed main_topic for shout#{shout_id}: {main_topic}")
@ -325,9 +321,7 @@ def get_shouts_with_links(info, q, limit=20, offset=0):
media_data = orjson.loads(media_data) media_data = orjson.loads(media_data)
except orjson.JSONDecodeError: except orjson.JSONDecodeError:
media_data = [] media_data = []
shout_dict["media"] = ( shout_dict["media"] = [media_data] if isinstance(media_data, dict) else media_data
[media_data] if isinstance(media_data, dict) else media_data
)
shouts.append(shout_dict) shouts.append(shout_dict)
@ -415,9 +409,7 @@ def apply_sorting(q, options):
""" """
order_str = options.get("order_by") order_str = options.get("order_by")
if order_str in ["rating", "comments_count", "last_commented_at"]: if order_str in ["rating", "comments_count", "last_commented_at"]:
query_order_by = ( query_order_by = desc(text(order_str)) if options.get("order_by_desc", True) else asc(text(order_str))
desc(text(order_str)) if options.get("order_by_desc", True) else asc(text(order_str))
)
q = q.distinct(text(order_str), Shout.id).order_by( # DISTINCT ON включает поле сортировки q = q.distinct(text(order_str), Shout.id).order_by( # DISTINCT ON включает поле сортировки
nulls_last(query_order_by), Shout.id nulls_last(query_order_by), Shout.id
) )
@ -513,15 +505,11 @@ async def load_shouts_unrated(_, info, options):
q = select(Shout).where(and_(Shout.published_at.is_not(None), Shout.deleted_at.is_(None))) q = select(Shout).where(and_(Shout.published_at.is_not(None), Shout.deleted_at.is_(None)))
q = q.join(Author, Author.id == Shout.created_by) q = q.join(Author, Author.id == Shout.created_by)
q = q.add_columns( q = q.add_columns(
json_builder("id", Author.id, "name", Author.name, "slug", Author.slug, "pic", Author.pic).label( json_builder("id", Author.id, "name", Author.name, "slug", Author.slug, "pic", Author.pic).label("main_author")
"main_author"
)
) )
q = q.join(ShoutTopic, and_(ShoutTopic.shout == Shout.id, ShoutTopic.main.is_(True))) q = q.join(ShoutTopic, and_(ShoutTopic.shout == Shout.id, ShoutTopic.main.is_(True)))
q = q.join(Topic, Topic.id == ShoutTopic.topic) q = q.join(Topic, Topic.id == ShoutTopic.topic)
q = q.add_columns( q = q.add_columns(json_builder("id", Topic.id, "title", Topic.title, "slug", Topic.slug).label("main_topic"))
json_builder("id", Topic.id, "title", Topic.title, "slug", Topic.slug).label("main_topic")
)
q = q.where(Shout.id.not_in(rated_shouts)) q = q.where(Shout.id.not_in(rated_shouts))
q = q.order_by(func.random()) q = q.order_by(func.random())

View File

@ -3,8 +3,8 @@ import asyncio
from sqlalchemy import and_, distinct, func, join, select from sqlalchemy import and_, distinct, func, join, select
from sqlalchemy.orm import aliased from sqlalchemy.orm import aliased
from cache.cache import cache_author
from auth.orm import Author, AuthorFollower from auth.orm import Author, AuthorFollower
from cache.cache import cache_author
from orm.reaction import Reaction, ReactionKind from orm.reaction import Reaction, ReactionKind
from orm.shout import Shout, ShoutAuthor, ShoutTopic from orm.shout import Shout, ShoutAuthor, ShoutTopic
from orm.topic import Topic, TopicFollower from orm.topic import Topic, TopicFollower
@ -177,9 +177,7 @@ def get_topic_comments_stat(topic_id: int) -> int:
.subquery() .subquery()
) )
# Запрос для суммирования количества комментариев по теме # Запрос для суммирования количества комментариев по теме
q = select(func.coalesce(func.sum(sub_comments.c.comments_count), 0)).filter( q = select(func.coalesce(func.sum(sub_comments.c.comments_count), 0)).filter(ShoutTopic.topic == topic_id)
ShoutTopic.topic == topic_id
)
q = q.outerjoin(sub_comments, ShoutTopic.shout == sub_comments.c.shout_id) q = q.outerjoin(sub_comments, ShoutTopic.shout == sub_comments.c.shout_id)
with local_session() as session: with local_session() as session:
result = session.execute(q).first() result = session.execute(q).first()
@ -239,9 +237,7 @@ def get_author_followers_stat(author_id: int) -> int:
:return: Количество уникальных подписчиков автора. :return: Количество уникальных подписчиков автора.
""" """
aliased_followers = aliased(AuthorFollower) aliased_followers = aliased(AuthorFollower)
q = select(func.count(distinct(aliased_followers.follower))).filter( q = select(func.count(distinct(aliased_followers.follower))).filter(aliased_followers.author == author_id)
aliased_followers.author == author_id
)
with local_session() as session: with local_session() as session:
result = session.execute(q).first() result = session.execute(q).first()
return result[0] if result else 0 return result[0] if result else 0
@ -293,9 +289,7 @@ def get_with_stat(q):
stat["shouts"] = cols[1] # Статистика по публикациям stat["shouts"] = cols[1] # Статистика по публикациям
stat["followers"] = cols[2] # Статистика по подписчикам stat["followers"] = cols[2] # Статистика по подписчикам
if is_author: if is_author:
stat["authors"] = get_author_authors_stat( stat["authors"] = get_author_authors_stat(entity.id) # Статистика по подпискам на авторов
entity.id
) # Статистика по подпискам на авторов
stat["comments"] = get_author_comments_stat(entity.id) # Статистика по комментариям stat["comments"] = get_author_comments_stat(entity.id) # Статистика по комментариям
else: else:
stat["authors"] = get_topic_authors_stat(entity.id) # Статистика по авторам темы stat["authors"] = get_topic_authors_stat(entity.id) # Статистика по авторам темы

View File

@ -1,5 +1,6 @@
from sqlalchemy import desc, select, text from sqlalchemy import desc, select, text
from auth.orm import Author
from cache.cache import ( from cache.cache import (
cache_topic, cache_topic,
cached_query, cached_query,
@ -8,9 +9,8 @@ from cache.cache import (
get_cached_topic_followers, get_cached_topic_followers,
invalidate_cache_by_prefix, invalidate_cache_by_prefix,
) )
from auth.orm import Author
from orm.topic import Topic
from orm.reaction import ReactionKind from orm.reaction import ReactionKind
from orm.topic import Topic
from resolvers.stat import get_with_stat from resolvers.stat import get_with_stat
from services.auth import login_required from services.auth import login_required
from services.db import local_session from services.db import local_session

View File

@ -1,16 +1,16 @@
from functools import wraps from functools import wraps
from typing import Tuple from typing import Tuple
from sqlalchemy import exc
from starlette.requests import Request from starlette.requests import Request
from auth.internal import verify_internal_auth
from auth.orm import Author, Role
from cache.cache import get_cached_author_by_id from cache.cache import get_cached_author_by_id
from resolvers.stat import get_with_stat from resolvers.stat import get_with_stat
from utils.logger import root_logger as logger
from auth.internal import verify_internal_auth
from sqlalchemy import exc
from services.db import local_session from services.db import local_session
from auth.orm import Author, Role
from settings import SESSION_TOKEN_HEADER from settings import SESSION_TOKEN_HEADER
from utils.logger import root_logger as logger
# Список разрешенных заголовков # Список разрешенных заголовков
ALLOWED_HEADERS = ["Authorization", "Content-Type"] ALLOWED_HEADERS = ["Authorization", "Content-Type"]
@ -31,21 +31,21 @@ async def check_auth(req: Request) -> Tuple[str, list[str], bool]:
- is_admin: bool - Флаг наличия у пользователя административных прав - is_admin: bool - Флаг наличия у пользователя административных прав
""" """
logger.debug(f"[check_auth] Проверка авторизации...") logger.debug(f"[check_auth] Проверка авторизации...")
# Получаем заголовок авторизации # Получаем заголовок авторизации
token = None token = None
# Проверяем заголовок с учетом регистра # Проверяем заголовок с учетом регистра
headers_dict = dict(req.headers.items()) headers_dict = dict(req.headers.items())
logger.debug(f"[check_auth] Все заголовки: {headers_dict}") logger.debug(f"[check_auth] Все заголовки: {headers_dict}")
# Ищем заголовок Authorization независимо от регистра # Ищем заголовок Authorization независимо от регистра
for header_name, header_value in headers_dict.items(): for header_name, header_value in headers_dict.items():
if header_name.lower() == SESSION_TOKEN_HEADER.lower(): if header_name.lower() == SESSION_TOKEN_HEADER.lower():
token = header_value token = header_value
logger.debug(f"[check_auth] Найден заголовок {header_name}: {token[:10]}...") logger.debug(f"[check_auth] Найден заголовок {header_name}: {token[:10]}...")
break break
if not token: if not token:
logger.debug(f"[check_auth] Токен не найден в заголовках") logger.debug(f"[check_auth] Токен не найден в заголовках")
return "", [], False return "", [], False
@ -57,8 +57,10 @@ async def check_auth(req: Request) -> Tuple[str, list[str], bool]:
# Проверяем авторизацию внутренним механизмом # Проверяем авторизацию внутренним механизмом
logger.debug("[check_auth] Вызов verify_internal_auth...") logger.debug("[check_auth] Вызов verify_internal_auth...")
user_id, user_roles, is_admin = await verify_internal_auth(token) user_id, user_roles, is_admin = await verify_internal_auth(token)
logger.debug(f"[check_auth] Результат verify_internal_auth: user_id={user_id}, roles={user_roles}, is_admin={is_admin}") logger.debug(
f"[check_auth] Результат verify_internal_auth: user_id={user_id}, roles={user_roles}, is_admin={is_admin}"
)
# Если в ролях нет админа, но есть ID - проверяем в БД # Если в ролях нет админа, но есть ID - проверяем в БД
if user_id and not is_admin: if user_id and not is_admin:
try: try:
@ -71,16 +73,19 @@ async def check_auth(req: Request) -> Tuple[str, list[str], bool]:
else: else:
# Проверяем наличие админских прав через БД # Проверяем наличие админских прав через БД
from auth.orm import AuthorRole from auth.orm import AuthorRole
admin_role = session.query(AuthorRole).filter(
AuthorRole.author == user_id_int, admin_role = (
AuthorRole.role.in_(["admin", "super"]) session.query(AuthorRole)
).first() .filter(AuthorRole.author == user_id_int, AuthorRole.role.in_(["admin", "super"]))
.first()
)
is_admin = admin_role is not None is_admin = admin_role is not None
except Exception as e: except Exception as e:
logger.error(f"Ошибка при проверке прав администратора: {e}") logger.error(f"Ошибка при проверке прав администратора: {e}")
return user_id, user_roles, is_admin return user_id, user_roles, is_admin
async def add_user_role(user_id: str, roles: list[str] = None): async def add_user_role(user_id: str, roles: list[str] = None):
""" """
Добавление ролей пользователю в локальной БД. Добавление ролей пользователю в локальной БД.
@ -131,32 +136,32 @@ def login_required(f):
info = args[1] info = args[1]
req = info.context.get("request") req = info.context.get("request")
logger.debug(f"[login_required] Проверка авторизации для запроса: {req.method} {req.url.path}") logger.debug(f"[login_required] Проверка авторизации для запроса: {req.method} {req.url.path}")
logger.debug(f"[login_required] Заголовки: {req.headers}") logger.debug(f"[login_required] Заголовки: {req.headers}")
user_id, user_roles, is_admin = await check_auth(req) user_id, user_roles, is_admin = await check_auth(req)
if not user_id: if not user_id:
logger.debug(f"[login_required] Пользователь не авторизован, {dict(req)}, {info}") logger.debug(f"[login_required] Пользователь не авторизован, {dict(req)}, {info}")
raise GraphQLError("Требуется авторизация") raise GraphQLError("Требуется авторизация")
# Проверяем наличие роли reader # Проверяем наличие роли reader
if 'reader' not in user_roles: if "reader" not in user_roles:
logger.error(f"Пользователь {user_id} не имеет роли 'reader'") logger.error(f"Пользователь {user_id} не имеет роли 'reader'")
raise GraphQLError("У вас нет необходимых прав для доступа") raise GraphQLError("У вас нет необходимых прав для доступа")
logger.info(f"Авторизован пользователь {user_id} с ролями: {user_roles}") logger.info(f"Авторизован пользователь {user_id} с ролями: {user_roles}")
info.context["roles"] = user_roles info.context["roles"] = user_roles
# Проверяем права администратора # Проверяем права администратора
info.context["is_admin"] = is_admin info.context["is_admin"] = is_admin
author = await get_cached_author_by_id(user_id, get_with_stat) author = await get_cached_author_by_id(user_id, get_with_stat)
if not author: if not author:
logger.error(f"Профиль автора не найден для пользователя {user_id}") logger.error(f"Профиль автора не найден для пользователя {user_id}")
info.context["author"] = author info.context["author"] = author
return await f(*args, **kwargs) return await f(*args, **kwargs)
return decorated_function return decorated_function
@ -177,7 +182,7 @@ def login_accepted(f):
if user_id and user_roles: if user_id and user_roles:
logger.info(f"login_accepted: Пользователь авторизован: {user_id} с ролями {user_roles}") logger.info(f"login_accepted: Пользователь авторизован: {user_id} с ролями {user_roles}")
info.context["roles"] = user_roles info.context["roles"] = user_roles
# Проверяем права администратора # Проверяем права администратора
info.context["is_admin"] = is_admin info.context["is_admin"] = is_admin

View File

@ -200,9 +200,7 @@ class Base(declarative_base()):
data[column_name] = value data[column_name] = value
else: else:
# Пропускаем атрибут, если его нет в объекте (может быть добавлен после миграции) # Пропускаем атрибут, если его нет в объекте (может быть добавлен после миграции)
logger.debug( logger.debug(f"Skipping missing attribute '{column_name}' for {self.__class__.__name__}")
f"Skipping missing attribute '{column_name}' for {self.__class__.__name__}"
)
except AttributeError as e: except AttributeError as e:
logger.warning(f"Attribute error for column '{column_name}': {e}") logger.warning(f"Attribute error for column '{column_name}': {e}")
# Добавляем синтетическое поле .stat если оно существует # Добавляем синтетическое поле .stat если оно существует
@ -223,9 +221,7 @@ class Base(declarative_base()):
# Функция для вывода полного трейсбека при предупреждениях # Функция для вывода полного трейсбека при предупреждениях
def warning_with_traceback( def warning_with_traceback(message: Warning | str, category, filename: str, lineno: int, file=None, line=None):
message: Warning | str, category, filename: str, lineno: int, file=None, line=None
):
tb = traceback.format_stack() tb = traceback.format_stack()
tb_str = "".join(tb) tb_str = "".join(tb)
return f"{message} ({filename}, {lineno}): {category.__name__}\n{tb_str}" return f"{message} ({filename}, {lineno}): {category.__name__}\n{tb_str}"
@ -302,22 +298,22 @@ json_builder, json_array_builder, json_cast = get_json_builder()
# Fetch all shouts, with authors preloaded # Fetch all shouts, with authors preloaded
# This function is used for search indexing # This function is used for search indexing
async def fetch_all_shouts(session=None): async def fetch_all_shouts(session=None):
"""Fetch all published shouts for search indexing with authors preloaded""" """Fetch all published shouts for search indexing with authors preloaded"""
from orm.shout import Shout from orm.shout import Shout
close_session = False close_session = False
if session is None: if session is None:
session = local_session() session = local_session()
close_session = True close_session = True
try: try:
# Fetch only published and non-deleted shouts with authors preloaded # Fetch only published and non-deleted shouts with authors preloaded
query = session.query(Shout).options( query = (
joinedload(Shout.authors) session.query(Shout)
).filter( .options(joinedload(Shout.authors))
Shout.published_at.is_not(None), .filter(Shout.published_at.is_not(None), Shout.deleted_at.is_(None))
Shout.deleted_at.is_(None)
) )
shouts = query.all() shouts = query.all()
return shouts return shouts
@ -326,4 +322,4 @@ async def fetch_all_shouts(session=None):
return [] return []
finally: finally:
if close_session: if close_session:
session.close() session.close()

View File

@ -1,9 +1,11 @@
from typing import Dict, List, Optional, Set
from dataclasses import dataclass
import os import os
import re import re
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Set
from redis import Redis from redis import Redis
from settings import REDIS_URL, ROOT_DIR from settings import REDIS_URL, ROOT_DIR
from utils.logger import root_logger as logger from utils.logger import root_logger as logger
@ -31,12 +33,37 @@ class EnvManager:
# Стандартные переменные окружения, которые следует исключить # Стандартные переменные окружения, которые следует исключить
EXCLUDED_ENV_VARS: Set[str] = { EXCLUDED_ENV_VARS: Set[str] = {
"PATH", "SHELL", "USER", "HOME", "PWD", "TERM", "LANG", "PATH",
"PYTHONPATH", "_", "TMPDIR", "TERM_PROGRAM", "TERM_SESSION_ID", "SHELL",
"XPC_SERVICE_NAME", "XPC_FLAGS", "SHLVL", "SECURITYSESSIONID", "USER",
"LOGNAME", "OLDPWD", "ZSH", "PAGER", "LESS", "LC_CTYPE", "LSCOLORS", "HOME",
"SSH_AUTH_SOCK", "DISPLAY", "COLORTERM", "EDITOR", "VISUAL", "PWD",
"PYTHONDONTWRITEBYTECODE", "VIRTUAL_ENV", "PYTHONUNBUFFERED" "TERM",
"LANG",
"PYTHONPATH",
"_",
"TMPDIR",
"TERM_PROGRAM",
"TERM_SESSION_ID",
"XPC_SERVICE_NAME",
"XPC_FLAGS",
"SHLVL",
"SECURITYSESSIONID",
"LOGNAME",
"OLDPWD",
"ZSH",
"PAGER",
"LESS",
"LC_CTYPE",
"LSCOLORS",
"SSH_AUTH_SOCK",
"DISPLAY",
"COLORTERM",
"EDITOR",
"VISUAL",
"PYTHONDONTWRITEBYTECODE",
"VIRTUAL_ENV",
"PYTHONUNBUFFERED",
} }
# Секции для группировки переменных # Секции для группировки переменных
@ -44,57 +71,67 @@ class EnvManager:
"AUTH": { "AUTH": {
"pattern": r"^(JWT|AUTH|SESSION|OAUTH|GITHUB|GOOGLE|FACEBOOK)_", "pattern": r"^(JWT|AUTH|SESSION|OAUTH|GITHUB|GOOGLE|FACEBOOK)_",
"name": "Авторизация", "name": "Авторизация",
"description": "Настройки системы авторизации" "description": "Настройки системы авторизации",
}, },
"DATABASE": { "DATABASE": {
"pattern": r"^(DB|DATABASE|POSTGRES|MYSQL|SQL)_", "pattern": r"^(DB|DATABASE|POSTGRES|MYSQL|SQL)_",
"name": "База данных", "name": "База данных",
"description": "Настройки подключения к базам данных" "description": "Настройки подключения к базам данных",
}, },
"CACHE": { "CACHE": {
"pattern": r"^(REDIS|CACHE|MEMCACHED)_", "pattern": r"^(REDIS|CACHE|MEMCACHED)_",
"name": "Кэширование", "name": "Кэширование",
"description": "Настройки систем кэширования" "description": "Настройки систем кэширования",
}, },
"SEARCH": { "SEARCH": {
"pattern": r"^(ELASTIC|SEARCH|OPENSEARCH)_", "pattern": r"^(ELASTIC|SEARCH|OPENSEARCH)_",
"name": "Поиск", "name": "Поиск",
"description": "Настройки поисковых систем" "description": "Настройки поисковых систем",
}, },
"APP": { "APP": {
"pattern": r"^(APP|PORT|HOST|DEBUG|DOMAIN|ENVIRONMENT|ENV|FRONTEND)_", "pattern": r"^(APP|PORT|HOST|DEBUG|DOMAIN|ENVIRONMENT|ENV|FRONTEND)_",
"name": "Общие настройки", "name": "Общие настройки",
"description": "Общие настройки приложения" "description": "Общие настройки приложения",
}, },
"LOGGING": { "LOGGING": {
"pattern": r"^(LOG|LOGGING|SENTRY|GLITCH|GLITCHTIP)_", "pattern": r"^(LOG|LOGGING|SENTRY|GLITCH|GLITCHTIP)_",
"name": "Мониторинг", "name": "Мониторинг",
"description": "Настройки логирования и мониторинга" "description": "Настройки логирования и мониторинга",
}, },
"EMAIL": { "EMAIL": {
"pattern": r"^(MAIL|EMAIL|SMTP|IMAP|POP3|POST)_", "pattern": r"^(MAIL|EMAIL|SMTP|IMAP|POP3|POST)_",
"name": "Электронная почта", "name": "Электронная почта",
"description": "Настройки отправки электронной почты" "description": "Настройки отправки электронной почты",
}, },
"ANALYTICS": { "ANALYTICS": {
"pattern": r"^(GA|GOOGLE_ANALYTICS|ANALYTICS)_", "pattern": r"^(GA|GOOGLE_ANALYTICS|ANALYTICS)_",
"name": "Аналитика", "name": "Аналитика",
"description": "Настройки систем аналитики" "description": "Настройки систем аналитики",
}, },
} }
# Переменные, которые следует всегда помечать как секретные # Переменные, которые следует всегда помечать как секретные
SECRET_VARS_PATTERNS = [ SECRET_VARS_PATTERNS = [
r".*TOKEN.*", r".*SECRET.*", r".*PASSWORD.*", r".*KEY.*", r".*TOKEN.*",
r".*PWD.*", r".*PASS.*", r".*CRED.*", r".*_DSN.*", r".*SECRET.*",
r".*JWT.*", r".*SESSION.*", r".*OAUTH.*", r".*PASSWORD.*",
r".*GITHUB.*", r".*GOOGLE.*", r".*FACEBOOK.*" r".*KEY.*",
r".*PWD.*",
r".*PASS.*",
r".*CRED.*",
r".*_DSN.*",
r".*JWT.*",
r".*SESSION.*",
r".*OAUTH.*",
r".*GITHUB.*",
r".*GOOGLE.*",
r".*FACEBOOK.*",
] ]
def __init__(self): def __init__(self):
self.redis = Redis.from_url(REDIS_URL) self.redis = Redis.from_url(REDIS_URL)
self.prefix = "env:" self.prefix = "env:"
self.env_file_path = os.path.join(ROOT_DIR, '.env') self.env_file_path = os.path.join(ROOT_DIR, ".env")
def get_all_variables(self) -> List[EnvSection]: def get_all_variables(self) -> List[EnvSection]:
""" """
@ -142,15 +179,15 @@ class EnvManager:
env_vars = {} env_vars = {}
if os.path.exists(self.env_file_path): if os.path.exists(self.env_file_path):
try: try:
with open(self.env_file_path, 'r') as f: with open(self.env_file_path, "r") as f:
for line in f: for line in f:
line = line.strip() line = line.strip()
# Пропускаем пустые строки и комментарии # Пропускаем пустые строки и комментарии
if not line or line.startswith('#'): if not line or line.startswith("#"):
continue continue
# Разделяем строку на ключ и значение # Разделяем строку на ключ и значение
if '=' in line: if "=" in line:
key, value = line.split('=', 1) key, value = line.split("=", 1)
key = key.strip() key = key.strip()
value = value.strip() value = value.strip()
# Удаляем кавычки, если они есть # Удаляем кавычки, если они есть
@ -207,17 +244,17 @@ class EnvManager:
""" """
Определяет тип переменной на основе ее значения Определяет тип переменной на основе ее значения
""" """
if value.lower() in ('true', 'false'): if value.lower() in ("true", "false"):
return "boolean" return "boolean"
if value.isdigit(): if value.isdigit():
return "integer" return "integer"
if re.match(r"^\d+\.\d+$", value): if re.match(r"^\d+\.\d+$", value):
return "float" return "float"
# Проверяем на JSON объект или массив # Проверяем на JSON объект или массив
if (value.startswith('{') and value.endswith('}')) or (value.startswith('[') and value.endswith(']')): if (value.startswith("{") and value.endswith("}")) or (value.startswith("[") and value.endswith("]")):
return "json" return "json"
# Проверяем на URL # Проверяем на URL
if value.startswith(('http://', 'https://', 'redis://', 'postgresql://')): if value.startswith(("http://", "https://", "redis://", "postgresql://")):
return "url" return "url"
return "string" return "string"
@ -233,14 +270,9 @@ class EnvManager:
for key, value in variables.items(): for key, value in variables.items():
is_secret = self._is_secret_variable(key) is_secret = self._is_secret_variable(key)
var_type = self._determine_variable_type(value) var_type = self._determine_variable_type(value)
var = EnvVariable( var = EnvVariable(key=key, value=value, type=var_type, is_secret=is_secret)
key=key,
value=value,
type=var_type,
is_secret=is_secret
)
# Определяем секцию для переменной # Определяем секцию для переменной
placed = False placed = False
for section_id, section_config in self.SECTIONS.items(): for section_id, section_config in self.SECTIONS.items():
@ -248,7 +280,7 @@ class EnvManager:
sections_dict[section_id].append(var) sections_dict[section_id].append(var)
placed = True placed = True
break break
# Если переменная не попала ни в одну секцию # Если переменная не попала ни в одну секцию
# if not placed: # if not placed:
# other_variables.append(var) # other_variables.append(var)
@ -260,22 +292,20 @@ class EnvManager:
section_config = self.SECTIONS[section_id] section_config = self.SECTIONS[section_id]
result.append( result.append(
EnvSection( EnvSection(
name=section_config["name"], name=section_config["name"], description=section_config["description"], variables=variables
description=section_config["description"],
variables=variables
) )
) )
# Добавляем прочие переменные, если они есть # Добавляем прочие переменные, если они есть
if other_variables: if other_variables:
result.append( result.append(
EnvSection( EnvSection(
name="Прочие переменные", name="Прочие переменные",
description="Переменные, не вошедшие в основные категории", description="Переменные, не вошедшие в основные категории",
variables=other_variables variables=other_variables,
) )
) )
return result return result
def update_variable(self, key: str, value: str) -> bool: def update_variable(self, key: str, value: str) -> bool:
@ -286,13 +316,13 @@ class EnvManager:
# Сохраняем в Redis # Сохраняем в Redis
full_key = f"{self.prefix}{key}" full_key = f"{self.prefix}{key}"
self.redis.set(full_key, value) self.redis.set(full_key, value)
# Обновляем значение в .env файле # Обновляем значение в .env файле
self._update_dotenv_var(key, value) self._update_dotenv_var(key, value)
# Обновляем переменную в текущем процессе # Обновляем переменную в текущем процессе
os.environ[key] = value os.environ[key] = value
return True return True
except Exception as e: except Exception as e:
logger.error(f"Ошибка обновления переменной {key}: {e}") logger.error(f"Ошибка обновления переменной {key}: {e}")
@ -305,20 +335,20 @@ class EnvManager:
try: try:
# Если файл .env не существует, создаем его # Если файл .env не существует, создаем его
if not os.path.exists(self.env_file_path): if not os.path.exists(self.env_file_path):
with open(self.env_file_path, 'w') as f: with open(self.env_file_path, "w") as f:
f.write(f"{key}={value}\n") f.write(f"{key}={value}\n")
return True return True
# Если файл существует, читаем его содержимое # Если файл существует, читаем его содержимое
lines = [] lines = []
found = False found = False
with open(self.env_file_path, 'r') as f: with open(self.env_file_path, "r") as f:
for line in f: for line in f:
if line.strip() and not line.strip().startswith('#'): if line.strip() and not line.strip().startswith("#"):
if line.strip().startswith(f"{key}="): if line.strip().startswith(f"{key}="):
# Экранируем значение, если необходимо # Экранируем значение, если необходимо
if ' ' in value or ',' in value or '"' in value or "'" in value: if " " in value or "," in value or '"' in value or "'" in value:
escaped_value = f'"{value}"' escaped_value = f'"{value}"'
else: else:
escaped_value = value escaped_value = value
@ -328,20 +358,20 @@ class EnvManager:
lines.append(line) lines.append(line)
else: else:
lines.append(line) lines.append(line)
# Если переменной не было в файле, добавляем ее # Если переменной не было в файле, добавляем ее
if not found: if not found:
# Экранируем значение, если необходимо # Экранируем значение, если необходимо
if ' ' in value or ',' in value or '"' in value or "'" in value: if " " in value or "," in value or '"' in value or "'" in value:
escaped_value = f'"{value}"' escaped_value = f'"{value}"'
else: else:
escaped_value = value escaped_value = value
lines.append(f"{key}={escaped_value}\n") lines.append(f"{key}={escaped_value}\n")
# Записываем обновленный файл # Записываем обновленный файл
with open(self.env_file_path, 'w') as f: with open(self.env_file_path, "w") as f:
f.writelines(lines) f.writelines(lines)
return True return True
except Exception as e: except Exception as e:
logger.error(f"Ошибка обновления .env файла: {e}") logger.error(f"Ошибка обновления .env файла: {e}")
@ -358,14 +388,14 @@ class EnvManager:
full_key = f"{self.prefix}{var.key}" full_key = f"{self.prefix}{var.key}"
pipe.set(full_key, var.value) pipe.set(full_key, var.value)
pipe.execute() pipe.execute()
# Обновляем переменные в .env файле # Обновляем переменные в .env файле
for var in variables: for var in variables:
self._update_dotenv_var(var.key, var.value) self._update_dotenv_var(var.key, var.value)
# Обновляем переменную в текущем процессе # Обновляем переменную в текущем процессе
os.environ[var.key] = var.value os.environ[var.key] = var.value
return True return True
except Exception as e: except Exception as e:
logger.error(f"Ошибка массового обновления переменных: {e}") logger.error(f"Ошибка массового обновления переменных: {e}")

View File

@ -93,9 +93,7 @@ async def notify_draft(draft_data, action: str = "publish"):
# Если переданы связанные атрибуты, добавим их # Если переданы связанные атрибуты, добавим их
if hasattr(draft_data, "topics") and draft_data.topics is not None: if hasattr(draft_data, "topics") and draft_data.topics is not None:
draft_payload["topics"] = [ draft_payload["topics"] = [{"id": t.id, "name": t.name, "slug": t.slug} for t in draft_data.topics]
{"id": t.id, "name": t.name, "slug": t.slug} for t in draft_data.topics
]
if hasattr(draft_data, "authors") and draft_data.authors is not None: if hasattr(draft_data, "authors") and draft_data.authors is not None:
draft_payload["authors"] = [ draft_payload["authors"] = [

View File

@ -30,7 +30,7 @@ class RedisService:
if self._client is None: if self._client is None:
await self.connect() await self.connect()
logger.info(f"[redis] Автоматически установлено соединение при выполнении команды {command}") logger.info(f"[redis] Автоматически установлено соединение при выполнении команды {command}")
if self._client: if self._client:
try: try:
logger.debug(f"{command}") # {args[0]}") # {args} {kwargs}") logger.debug(f"{command}") # {args[0]}") # {args} {kwargs}")
@ -55,14 +55,14 @@ class RedisService:
if self._client is None: if self._client is None:
# Выбрасываем исключение, так как pipeline нельзя создать до подключения # Выбрасываем исключение, так как pipeline нельзя создать до подключения
raise Exception("Redis client is not initialized. Call redis.connect() first.") raise Exception("Redis client is not initialized. Call redis.connect() first.")
return self._client.pipeline() return self._client.pipeline()
async def subscribe(self, *channels): async def subscribe(self, *channels):
# Автоматически подключаемся к Redis, если соединение не установлено # Автоматически подключаемся к Redis, если соединение не установлено
if self._client is None: if self._client is None:
await self.connect() await self.connect()
async with self._client.pubsub() as pubsub: async with self._client.pubsub() as pubsub:
for channel in channels: for channel in channels:
await pubsub.subscribe(channel) await pubsub.subscribe(channel)
@ -71,7 +71,7 @@ class RedisService:
async def unsubscribe(self, *channels): async def unsubscribe(self, *channels):
if self._client is None: if self._client is None:
return return
async with self._client.pubsub() as pubsub: async with self._client.pubsub() as pubsub:
for channel in channels: for channel in channels:
await pubsub.unsubscribe(channel) await pubsub.unsubscribe(channel)
@ -81,14 +81,14 @@ class RedisService:
# Автоматически подключаемся к Redis, если соединение не установлено # Автоматически подключаемся к Redis, если соединение не установлено
if self._client is None: if self._client is None:
await self.connect() await self.connect()
await self._client.publish(channel, data) await self._client.publish(channel, data)
async def set(self, key, data, ex=None): async def set(self, key, data, ex=None):
# Автоматически подключаемся к Redis, если соединение не установлено # Автоматически подключаемся к Redis, если соединение не установлено
if self._client is None: if self._client is None:
await self.connect() await self.connect()
# Prepare the command arguments # Prepare the command arguments
args = [key, data] args = [key, data]
@ -104,7 +104,7 @@ class RedisService:
# Автоматически подключаемся к Redis, если соединение не установлено # Автоматически подключаемся к Redis, если соединение не установлено
if self._client is None: if self._client is None:
await self.connect() await self.connect()
return await self.execute("get", key) return await self.execute("get", key)
async def delete(self, *keys): async def delete(self, *keys):
@ -119,11 +119,11 @@ class RedisService:
""" """
if not keys: if not keys:
return 0 return 0
# Автоматически подключаемся к Redis, если соединение не установлено # Автоматически подключаемся к Redis, если соединение не установлено
if self._client is None: if self._client is None:
await self.connect() await self.connect()
return await self._client.delete(*keys) return await self._client.delete(*keys)
async def hmset(self, key, mapping): async def hmset(self, key, mapping):
@ -137,7 +137,7 @@ class RedisService:
# Автоматически подключаемся к Redis, если соединение не установлено # Автоматически подключаемся к Redis, если соединение не установлено
if self._client is None: if self._client is None:
await self.connect() await self.connect()
await self._client.hset(key, mapping=mapping) await self._client.hset(key, mapping=mapping)
async def expire(self, key, seconds): async def expire(self, key, seconds):
@ -151,7 +151,7 @@ class RedisService:
# Автоматически подключаемся к Redis, если соединение не установлено # Автоматически подключаемся к Redis, если соединение не установлено
if self._client is None: if self._client is None:
await self.connect() await self.connect()
await self._client.expire(key, seconds) await self._client.expire(key, seconds)
async def sadd(self, key, *values): async def sadd(self, key, *values):
@ -165,7 +165,7 @@ class RedisService:
# Автоматически подключаемся к Redis, если соединение не установлено # Автоматически подключаемся к Redis, если соединение не установлено
if self._client is None: if self._client is None:
await self.connect() await self.connect()
await self._client.sadd(key, *values) await self._client.sadd(key, *values)
async def srem(self, key, *values): async def srem(self, key, *values):
@ -179,7 +179,7 @@ class RedisService:
# Автоматически подключаемся к Redis, если соединение не установлено # Автоматически подключаемся к Redis, если соединение не установлено
if self._client is None: if self._client is None:
await self.connect() await self.connect()
await self._client.srem(key, *values) await self._client.srem(key, *values)
async def smembers(self, key): async def smembers(self, key):
@ -195,9 +195,9 @@ class RedisService:
# Автоматически подключаемся к Redis, если соединение не установлено # Автоматически подключаемся к Redis, если соединение не установлено
if self._client is None: if self._client is None:
await self.connect() await self.connect()
return await self._client.smembers(key) return await self._client.smembers(key)
async def exists(self, key): async def exists(self, key):
""" """
Проверяет, существует ли ключ в Redis. Проверяет, существует ли ключ в Redis.
@ -210,10 +210,10 @@ class RedisService:
""" """
# Автоматически подключаемся к Redis, если соединение не установлено # Автоматически подключаемся к Redis, если соединение не установлено
if self._client is None: if self._client is None:
await self.connect() await self.connect()
return await self._client.exists(key) return await self._client.exists(key)
async def expire(self, key, seconds): async def expire(self, key, seconds):
""" """
Устанавливает время жизни ключа. Устанавливает время жизни ключа.
@ -225,7 +225,7 @@ class RedisService:
# Автоматически подключаемся к Redis, если соединение не установлено # Автоматически подключаемся к Redis, если соединение не установлено
if self._client is None: if self._client is None:
await self.connect() await self.connect()
return await self._client.expire(key, seconds) return await self._client.expire(key, seconds)
async def keys(self, pattern): async def keys(self, pattern):
@ -238,10 +238,8 @@ class RedisService:
# Автоматически подключаемся к Redis, если соединение не установлено # Автоматически подключаемся к Redis, если соединение не установлено
if self._client is None: if self._client is None:
await self.connect() await self.connect()
return await self._client.keys(pattern) return await self._client.keys(pattern)
redis = RedisService() redis = RedisService()

View File

@ -12,7 +12,7 @@ resolvers = [query, mutation, type_draft]
def create_all_tables(): def create_all_tables():
"""Create all database tables in the correct order.""" """Create all database tables in the correct order."""
from auth.orm import Author, AuthorFollower, AuthorBookmark, AuthorRating from auth.orm import Author, AuthorBookmark, AuthorFollower, AuthorRating
from orm import community, draft, notification, reaction, shout, topic from orm import community, draft, notification, reaction, shout, topic
# Порядок важен - сначала таблицы без внешних ключей, затем зависимые таблицы # Порядок важен - сначала таблицы без внешних ключей, затем зависимые таблицы

View File

@ -2,9 +2,11 @@ import asyncio
import json import json
import logging import logging
import os import os
import httpx
import time
import random import random
import time
import httpx
from settings import TXTAI_SERVICE_URL from settings import TXTAI_SERVICE_URL
# Set up proper logging # Set up proper logging
@ -15,23 +17,15 @@ logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("httpcore").setLevel(logging.WARNING) logging.getLogger("httpcore").setLevel(logging.WARNING)
# Configuration for search service # Configuration for search service
SEARCH_ENABLED = bool( SEARCH_ENABLED = bool(os.environ.get("SEARCH_ENABLED", "true").lower() in ["true", "1", "yes"])
os.environ.get("SEARCH_ENABLED", "true").lower() in ["true", "1", "yes"]
)
MAX_BATCH_SIZE = int(os.environ.get("SEARCH_MAX_BATCH_SIZE", "25")) MAX_BATCH_SIZE = int(os.environ.get("SEARCH_MAX_BATCH_SIZE", "25"))
# Search cache configuration # Search cache configuration
SEARCH_CACHE_ENABLED = bool( SEARCH_CACHE_ENABLED = bool(os.environ.get("SEARCH_CACHE_ENABLED", "true").lower() in ["true", "1", "yes"])
os.environ.get("SEARCH_CACHE_ENABLED", "true").lower() in ["true", "1", "yes"] SEARCH_CACHE_TTL_SECONDS = int(os.environ.get("SEARCH_CACHE_TTL_SECONDS", "300")) # Default: 15 minutes
)
SEARCH_CACHE_TTL_SECONDS = int(
os.environ.get("SEARCH_CACHE_TTL_SECONDS", "300")
) # Default: 15 minutes
SEARCH_PREFETCH_SIZE = int(os.environ.get("SEARCH_PREFETCH_SIZE", "200")) SEARCH_PREFETCH_SIZE = int(os.environ.get("SEARCH_PREFETCH_SIZE", "200"))
SEARCH_USE_REDIS = bool( SEARCH_USE_REDIS = bool(os.environ.get("SEARCH_USE_REDIS", "true").lower() in ["true", "1", "yes"])
os.environ.get("SEARCH_USE_REDIS", "true").lower() in ["true", "1", "yes"]
)
search_offset = 0 search_offset = 0
@ -68,9 +62,7 @@ class SearchCache:
serialized_results, serialized_results,
ex=self.ttl, ex=self.ttl,
) )
logger.info( logger.info(f"Stored {len(results)} search results for query '{query}' in Redis")
f"Stored {len(results)} search results for query '{query}' in Redis"
)
return True return True
except Exception as e: except Exception as e:
logger.error(f"Error storing search results in Redis: {e}") logger.error(f"Error storing search results in Redis: {e}")
@ -83,9 +75,7 @@ class SearchCache:
# Store results and update timestamp # Store results and update timestamp
self.cache[normalized_query] = results self.cache[normalized_query] = results
self.last_accessed[normalized_query] = time.time() self.last_accessed[normalized_query] = time.time()
logger.info( logger.info(f"Cached {len(results)} search results for query '{query}' in memory")
f"Cached {len(results)} search results for query '{query}' in memory"
)
return True return True
async def get(self, query, limit=10, offset=0): async def get(self, query, limit=10, offset=0):
@ -117,14 +107,10 @@ class SearchCache:
# Return paginated subset # Return paginated subset
end_idx = min(offset + limit, len(all_results)) end_idx = min(offset + limit, len(all_results))
if offset >= len(all_results): if offset >= len(all_results):
logger.warning( logger.warning(f"Requested offset {offset} exceeds result count {len(all_results)}")
f"Requested offset {offset} exceeds result count {len(all_results)}"
)
return [] return []
logger.info( logger.info(f"Cache hit for '{query}': serving {offset}:{end_idx} of {len(all_results)} results")
f"Cache hit for '{query}': serving {offset}:{end_idx} of {len(all_results)} results"
)
return all_results[offset:end_idx] return all_results[offset:end_idx]
async def has_query(self, query): async def has_query(self, query):
@ -174,11 +160,7 @@ class SearchCache:
"""Remove oldest entries if memory cache is full""" """Remove oldest entries if memory cache is full"""
now = time.time() now = time.time()
# First remove expired entries # First remove expired entries
expired_keys = [ expired_keys = [key for key, last_access in self.last_accessed.items() if now - last_access > self.ttl]
key
for key, last_access in self.last_accessed.items()
if now - last_access > self.ttl
]
for key in expired_keys: for key in expired_keys:
if key in self.cache: if key in self.cache:
@ -217,9 +199,7 @@ class SearchService:
if SEARCH_CACHE_ENABLED: if SEARCH_CACHE_ENABLED:
cache_location = "Redis" if SEARCH_USE_REDIS else "Memory" cache_location = "Redis" if SEARCH_USE_REDIS else "Memory"
logger.info( logger.info(f"Search caching enabled using {cache_location} cache with TTL={SEARCH_CACHE_TTL_SECONDS}s")
f"Search caching enabled using {cache_location} cache with TTL={SEARCH_CACHE_TTL_SECONDS}s"
)
async def info(self): async def info(self):
"""Return information about search service""" """Return information about search service"""
@ -270,9 +250,7 @@ class SearchService:
logger.info( logger.info(
f"Document verification complete: {bodies_missing_count} bodies missing, {titles_missing_count} titles missing" f"Document verification complete: {bodies_missing_count} bodies missing, {titles_missing_count} titles missing"
) )
logger.info( logger.info(f"Total unique missing documents: {total_missing_count} out of {len(doc_ids)} total")
f"Total unique missing documents: {total_missing_count} out of {len(doc_ids)} total"
)
# Return in a backwards-compatible format plus the detailed breakdown # Return in a backwards-compatible format plus the detailed breakdown
return { return {
@ -308,9 +286,7 @@ class SearchService:
# 1. Index title if available # 1. Index title if available
if hasattr(shout, "title") and shout.title and isinstance(shout.title, str): if hasattr(shout, "title") and shout.title and isinstance(shout.title, str):
title_doc = {"id": str(shout.id), "title": shout.title.strip()} title_doc = {"id": str(shout.id), "title": shout.title.strip()}
indexing_tasks.append( indexing_tasks.append(self.index_client.post("/index-title", json=title_doc))
self.index_client.post("/index-title", json=title_doc)
)
# 2. Index body content (subtitle, lead, body) # 2. Index body content (subtitle, lead, body)
body_text_parts = [] body_text_parts = []
@ -346,9 +322,7 @@ class SearchService:
body_text = body_text[:MAX_TEXT_LENGTH] body_text = body_text[:MAX_TEXT_LENGTH]
body_doc = {"id": str(shout.id), "body": body_text} body_doc = {"id": str(shout.id), "body": body_text}
indexing_tasks.append( indexing_tasks.append(self.index_client.post("/index-body", json=body_doc))
self.index_client.post("/index-body", json=body_doc)
)
# 3. Index authors # 3. Index authors
authors = getattr(shout, "authors", []) authors = getattr(shout, "authors", [])
@ -373,30 +347,22 @@ class SearchService:
if name: if name:
author_doc = {"id": author_id, "name": name, "bio": combined_bio} author_doc = {"id": author_id, "name": name, "bio": combined_bio}
indexing_tasks.append( indexing_tasks.append(self.index_client.post("/index-author", json=author_doc))
self.index_client.post("/index-author", json=author_doc)
)
# Run all indexing tasks in parallel # Run all indexing tasks in parallel
if indexing_tasks: if indexing_tasks:
responses = await asyncio.gather( responses = await asyncio.gather(*indexing_tasks, return_exceptions=True)
*indexing_tasks, return_exceptions=True
)
# Check for errors in responses # Check for errors in responses
for i, response in enumerate(responses): for i, response in enumerate(responses):
if isinstance(response, Exception): if isinstance(response, Exception):
logger.error(f"Error in indexing task {i}: {response}") logger.error(f"Error in indexing task {i}: {response}")
elif ( elif hasattr(response, "status_code") and response.status_code >= 400:
hasattr(response, "status_code") and response.status_code >= 400
):
logger.error( logger.error(
f"Error response in indexing task {i}: {response.status_code}, {await response.text()}" f"Error response in indexing task {i}: {response.status_code}, {await response.text()}"
) )
logger.info( logger.info(f"Document {shout.id} indexed across {len(indexing_tasks)} endpoints")
f"Document {shout.id} indexed across {len(indexing_tasks)} endpoints"
)
else: else:
logger.warning(f"No content to index for shout {shout.id}") logger.warning(f"No content to index for shout {shout.id}")
@ -424,24 +390,14 @@ class SearchService:
for shout in shouts: for shout in shouts:
try: try:
# 1. Process title documents # 1. Process title documents
if ( if hasattr(shout, "title") and shout.title and isinstance(shout.title, str):
hasattr(shout, "title") title_docs.append({"id": str(shout.id), "title": shout.title.strip()})
and shout.title
and isinstance(shout.title, str)
):
title_docs.append(
{"id": str(shout.id), "title": shout.title.strip()}
)
# 2. Process body documents (subtitle, lead, body) # 2. Process body documents (subtitle, lead, body)
body_text_parts = [] body_text_parts = []
for field_name in ["subtitle", "lead", "body"]: for field_name in ["subtitle", "lead", "body"]:
field_value = getattr(shout, field_name, None) field_value = getattr(shout, field_name, None)
if ( if field_value and isinstance(field_value, str) and field_value.strip():
field_value
and isinstance(field_value, str)
and field_value.strip()
):
body_text_parts.append(field_value.strip()) body_text_parts.append(field_value.strip())
# Process media content if available # Process media content if available
@ -507,9 +463,7 @@ class SearchService:
} }
except Exception as e: except Exception as e:
logger.error( logger.error(f"Error processing shout {getattr(shout, 'id', 'unknown')} for indexing: {e}")
f"Error processing shout {getattr(shout, 'id', 'unknown')} for indexing: {e}"
)
total_skipped += 1 total_skipped += 1
# Convert author dict to list # Convert author dict to list
@ -543,9 +497,7 @@ class SearchService:
logger.info(f"Indexing {len(documents)} {doc_type} documents") logger.info(f"Indexing {len(documents)} {doc_type} documents")
# Categorize documents by size # Categorize documents by size
small_docs, medium_docs, large_docs = self._categorize_by_size( small_docs, medium_docs, large_docs = self._categorize_by_size(documents, doc_type)
documents, doc_type
)
# Process each category with appropriate batch sizes # Process each category with appropriate batch sizes
batch_sizes = { batch_sizes = {
@ -561,9 +513,7 @@ class SearchService:
]: ]:
if docs: if docs:
batch_size = batch_sizes[category] batch_size = batch_sizes[category]
await self._process_batches( await self._process_batches(docs, batch_size, endpoint, f"{doc_type}-{category}")
docs, batch_size, endpoint, f"{doc_type}-{category}"
)
def _categorize_by_size(self, documents, doc_type): def _categorize_by_size(self, documents, doc_type):
"""Categorize documents by size for optimized batch processing""" """Categorize documents by size for optimized batch processing"""
@ -599,7 +549,7 @@ class SearchService:
"""Process document batches with retry logic""" """Process document batches with retry logic"""
for i in range(0, len(documents), batch_size): for i in range(0, len(documents), batch_size):
batch = documents[i : i + batch_size] batch = documents[i : i + batch_size]
batch_id = f"{batch_prefix}-{i//batch_size + 1}" batch_id = f"{batch_prefix}-{i // batch_size + 1}"
retry_count = 0 retry_count = 0
max_retries = 3 max_retries = 3
@ -607,9 +557,7 @@ class SearchService:
while not success and retry_count < max_retries: while not success and retry_count < max_retries:
try: try:
response = await self.index_client.post( response = await self.index_client.post(endpoint, json=batch, timeout=90.0)
endpoint, json=batch, timeout=90.0
)
if response.status_code == 422: if response.status_code == 422:
error_detail = response.json() error_detail = response.json()
@ -630,13 +578,13 @@ class SearchService:
batch[:mid], batch[:mid],
batch_size // 2, batch_size // 2,
endpoint, endpoint,
f"{batch_prefix}-{i//batch_size}-A", f"{batch_prefix}-{i // batch_size}-A",
) )
await self._process_batches( await self._process_batches(
batch[mid:], batch[mid:],
batch_size // 2, batch_size // 2,
endpoint, endpoint,
f"{batch_prefix}-{i//batch_size}-B", f"{batch_prefix}-{i // batch_size}-B",
) )
else: else:
logger.error( logger.error(
@ -649,9 +597,7 @@ class SearchService:
def _truncate_error_detail(self, error_detail): def _truncate_error_detail(self, error_detail):
"""Truncate error details for logging""" """Truncate error details for logging"""
truncated_detail = ( truncated_detail = error_detail.copy() if isinstance(error_detail, dict) else error_detail
error_detail.copy() if isinstance(error_detail, dict) else error_detail
)
if ( if (
isinstance(truncated_detail, dict) isinstance(truncated_detail, dict)
@ -660,30 +606,22 @@ class SearchService:
): ):
for i, item in enumerate(truncated_detail["detail"]): for i, item in enumerate(truncated_detail["detail"]):
if isinstance(item, dict) and "input" in item: if isinstance(item, dict) and "input" in item:
if isinstance(item["input"], dict) and any( if isinstance(item["input"], dict) and any(k in item["input"] for k in ["documents", "text"]):
k in item["input"] for k in ["documents", "text"] if "documents" in item["input"] and isinstance(item["input"]["documents"], list):
):
if "documents" in item["input"] and isinstance(
item["input"]["documents"], list
):
for j, doc in enumerate(item["input"]["documents"]): for j, doc in enumerate(item["input"]["documents"]):
if ( if "text" in doc and isinstance(doc["text"], str) and len(doc["text"]) > 100:
"text" in doc item["input"]["documents"][j]["text"] = (
and isinstance(doc["text"], str) f"{doc['text'][:100]}... [truncated, total {len(doc['text'])} chars]"
and len(doc["text"]) > 100 )
):
item["input"]["documents"][j][
"text"
] = f"{doc['text'][:100]}... [truncated, total {len(doc['text'])} chars]"
if ( if (
"text" in item["input"] "text" in item["input"]
and isinstance(item["input"]["text"], str) and isinstance(item["input"]["text"], str)
and len(item["input"]["text"]) > 100 and len(item["input"]["text"]) > 100
): ):
item["input"][ item["input"]["text"] = (
"text" f"{item['input']['text'][:100]}... [truncated, total {len(item['input']['text'])} chars]"
] = f"{item['input']['text'][:100]}... [truncated, total {len(item['input']['text'])} chars]" )
return truncated_detail return truncated_detail
@ -711,9 +649,9 @@ class SearchService:
search_limit = SEARCH_PREFETCH_SIZE search_limit = SEARCH_PREFETCH_SIZE
else: else:
search_limit = limit search_limit = limit
logger.info(f"Searching for: '{text}' (limit={limit}, offset={offset}, search_limit={search_limit})") logger.info(f"Searching for: '{text}' (limit={limit}, offset={offset}, search_limit={search_limit})")
response = await self.client.post( response = await self.client.post(
"/search-combined", "/search-combined",
json={"text": text, "limit": search_limit}, json={"text": text, "limit": search_limit},
@ -767,9 +705,7 @@ class SearchService:
logger.info( logger.info(
f"Searching authors for: '{text}' (limit={limit}, offset={offset}, search_limit={search_limit})" f"Searching authors for: '{text}' (limit={limit}, offset={offset}, search_limit={search_limit})"
) )
response = await self.client.post( response = await self.client.post("/search-author", json={"text": text, "limit": search_limit})
"/search-author", json={"text": text, "limit": search_limit}
)
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
@ -784,7 +720,7 @@ class SearchService:
# Store the full prefetch batch, then page it # Store the full prefetch batch, then page it
await self.cache.store(cache_key, author_results) await self.cache.store(cache_key, author_results)
return await self.cache.get(cache_key, limit, offset) return await self.cache.get(cache_key, limit, offset)
return author_results[offset : offset + limit] return author_results[offset : offset + limit]
except Exception as e: except Exception as e:
@ -802,9 +738,7 @@ class SearchService:
result = response.json() result = response.json()
if result.get("consistency", {}).get("status") != "ok": if result.get("consistency", {}).get("status") != "ok":
null_count = result.get("consistency", {}).get( null_count = result.get("consistency", {}).get("null_embeddings_count", 0)
"null_embeddings_count", 0
)
if null_count > 0: if null_count > 0:
logger.warning(f"Found {null_count} documents with NULL embeddings") logger.warning(f"Found {null_count} documents with NULL embeddings")
@ -877,14 +811,10 @@ async def initialize_search_index(shouts_data):
index_status = await search_service.check_index_status() index_status = await search_service.check_index_status()
if index_status.get("status") == "inconsistent": if index_status.get("status") == "inconsistent":
problem_ids = index_status.get("consistency", {}).get( problem_ids = index_status.get("consistency", {}).get("null_embeddings_sample", [])
"null_embeddings_sample", []
)
if problem_ids: if problem_ids:
problem_docs = [ problem_docs = [shout for shout in shouts_data if str(shout.id) in problem_ids]
shout for shout in shouts_data if str(shout.id) in problem_ids
]
if problem_docs: if problem_docs:
await search_service.bulk_index(problem_docs) await search_service.bulk_index(problem_docs)
@ -902,9 +832,7 @@ async def initialize_search_index(shouts_data):
if isinstance(media, str): if isinstance(media, str):
try: try:
media_json = json.loads(media) media_json = json.loads(media)
if isinstance(media_json, dict) and ( if isinstance(media_json, dict) and (media_json.get("title") or media_json.get("body")):
media_json.get("title") or media_json.get("body")
):
return True return True
except Exception: except Exception:
return True return True
@ -922,13 +850,9 @@ async def initialize_search_index(shouts_data):
if verification.get("status") == "error": if verification.get("status") == "error":
return return
# Only reindex missing docs that actually have body content # Only reindex missing docs that actually have body content
missing_ids = [ missing_ids = [mid for mid in verification.get("missing", []) if mid in body_ids]
mid for mid in verification.get("missing", []) if mid in body_ids
]
if missing_ids: if missing_ids:
missing_docs = [ missing_docs = [shout for shout in shouts_with_body if str(shout.id) in missing_ids]
shout for shout in shouts_with_body if str(shout.id) in missing_ids
]
await search_service.bulk_index(missing_docs) await search_service.bulk_index(missing_docs)
else: else:
pass pass
@ -955,35 +879,35 @@ async def check_search_service():
print(f"[WARNING] Search service unavailable: {info.get('message', 'unknown reason')}") print(f"[WARNING] Search service unavailable: {info.get('message', 'unknown reason')}")
else: else:
print(f"[INFO] Search service is available: {info}") print(f"[INFO] Search service is available: {info}")
# Initialize search index in the background # Initialize search index in the background
async def initialize_search_index_background(): async def initialize_search_index_background():
""" """
Запускает индексацию поиска в фоновом режиме с низким приоритетом. Запускает индексацию поиска в фоновом режиме с низким приоритетом.
Эта функция: Эта функция:
1. Загружает все shouts из базы данных 1. Загружает все shouts из базы данных
2. Индексирует их в поисковом сервисе 2. Индексирует их в поисковом сервисе
3. Выполняется асинхронно, не блокируя основной поток 3. Выполняется асинхронно, не блокируя основной поток
4. Обрабатывает возможные ошибки, не прерывая работу приложения 4. Обрабатывает возможные ошибки, не прерывая работу приложения
Индексация запускается с задержкой после инициализации сервера, Индексация запускается с задержкой после инициализации сервера,
чтобы не создавать дополнительную нагрузку при запуске. чтобы не создавать дополнительную нагрузку при запуске.
""" """
try: try:
print("[search] Starting background search indexing process") print("[search] Starting background search indexing process")
from services.db import fetch_all_shouts from services.db import fetch_all_shouts
# Get total count first (optional) # Get total count first (optional)
all_shouts = await fetch_all_shouts() all_shouts = await fetch_all_shouts()
total_count = len(all_shouts) if all_shouts else 0 total_count = len(all_shouts) if all_shouts else 0
print(f"[search] Fetched {total_count} shouts for background indexing") print(f"[search] Fetched {total_count} shouts for background indexing")
if not all_shouts: if not all_shouts:
print("[search] No shouts found for indexing, skipping search index initialization") print("[search] No shouts found for indexing, skipping search index initialization")
return return
# Start the indexing process with the fetched shouts # Start the indexing process with the fetched shouts
print("[search] Beginning background search index initialization...") print("[search] Beginning background search index initialization...")
await initialize_search_index(all_shouts) await initialize_search_index(all_shouts)

View File

@ -80,12 +80,12 @@ class ViewedStorage:
# Получаем список всех ключей migrated_views_* и находим самый последний # Получаем список всех ключей migrated_views_* и находим самый последний
keys = await redis.execute("KEYS", "migrated_views_*") keys = await redis.execute("KEYS", "migrated_views_*")
logger.info(f" * Raw Redis result for 'KEYS migrated_views_*': {len(keys)}") logger.info(f" * Raw Redis result for 'KEYS migrated_views_*': {len(keys)}")
# Декодируем байтовые строки, если есть # Декодируем байтовые строки, если есть
if keys and isinstance(keys[0], bytes): if keys and isinstance(keys[0], bytes):
keys = [k.decode('utf-8') for k in keys] keys = [k.decode("utf-8") for k in keys]
logger.info(f" * Decoded keys: {keys}") logger.info(f" * Decoded keys: {keys}")
if not keys: if not keys:
logger.warning(" * No migrated_views keys found in Redis") logger.warning(" * No migrated_views keys found in Redis")
return return
@ -93,7 +93,7 @@ class ViewedStorage:
# Фильтруем только ключи timestamp формата (исключаем migrated_views_slugs) # Фильтруем только ключи timestamp формата (исключаем migrated_views_slugs)
timestamp_keys = [k for k in keys if k != "migrated_views_slugs"] timestamp_keys = [k for k in keys if k != "migrated_views_slugs"]
logger.info(f" * Timestamp keys after filtering: {timestamp_keys}") logger.info(f" * Timestamp keys after filtering: {timestamp_keys}")
if not timestamp_keys: if not timestamp_keys:
logger.warning(" * No migrated_views timestamp keys found in Redis") logger.warning(" * No migrated_views timestamp keys found in Redis")
return return
@ -243,20 +243,12 @@ class ViewedStorage:
# Обновление тем и авторов с использованием вспомогательной функции # Обновление тем и авторов с использованием вспомогательной функции
for [_st, topic] in ( for [_st, topic] in (
session.query(ShoutTopic, Topic) session.query(ShoutTopic, Topic).join(Topic).join(Shout).where(Shout.slug == shout_slug).all()
.join(Topic)
.join(Shout)
.where(Shout.slug == shout_slug)
.all()
): ):
update_groups(self.shouts_by_topic, topic.slug, shout_slug) update_groups(self.shouts_by_topic, topic.slug, shout_slug)
for [_st, author] in ( for [_st, author] in (
session.query(ShoutAuthor, Author) session.query(ShoutAuthor, Author).join(Author).join(Shout).where(Shout.slug == shout_slug).all()
.join(Author)
.join(Shout)
.where(Shout.slug == shout_slug)
.all()
): ):
update_groups(self.shouts_by_author, author.slug, shout_slug) update_groups(self.shouts_by_author, author.slug, shout_slug)
@ -289,9 +281,7 @@ class ViewedStorage:
if failed == 0: if failed == 0:
when = datetime.now(timezone.utc) + timedelta(seconds=self.period) when = datetime.now(timezone.utc) + timedelta(seconds=self.period)
t = format(when.astimezone().isoformat()) t = format(when.astimezone().isoformat())
logger.info( logger.info(" ⎩ next update: %s" % (t.split("T")[0] + " " + t.split("T")[1].split(".")[0]))
" ⎩ next update: %s" % (t.split("T")[0] + " " + t.split("T")[1].split(".")[0])
)
await asyncio.sleep(self.period) await asyncio.sleep(self.period)
else: else:
await asyncio.sleep(10) await asyncio.sleep(10)

View File

@ -72,4 +72,4 @@ MAILGUN_API_KEY = os.getenv("MAILGUN_API_KEY", "")
MAILGUN_DOMAIN = os.getenv("MAILGUN_DOMAIN", "discours.io") MAILGUN_DOMAIN = os.getenv("MAILGUN_DOMAIN", "discours.io")
TXTAI_SERVICE_URL = os.environ.get("TXTAI_SERVICE_URL", "none") TXTAI_SERVICE_URL = os.environ.get("TXTAI_SERVICE_URL", "none")

View File

@ -1,6 +1,7 @@
import pytest
from typing import Dict from typing import Dict
import pytest
@pytest.fixture @pytest.fixture
def oauth_settings() -> Dict[str, Dict[str, str]]: def oauth_settings() -> Dict[str, Dict[str, str]]:

View File

@ -1,8 +1,9 @@
import pytest
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from starlette.responses import JSONResponse, RedirectResponse from starlette.responses import JSONResponse, RedirectResponse
from auth.oauth import get_user_profile, oauth_login, oauth_callback from auth.oauth import get_user_profile, oauth_callback, oauth_login
# Подменяем настройки для тестов # Подменяем настройки для тестов
with ( with (

View File

@ -1,5 +1,7 @@
import asyncio import asyncio
import pytest import pytest
from services.redis import redis from services.redis import redis
from tests.test_config import get_test_client from tests.test_config import get_test_client