tests-passed
This commit is contained in:
@@ -1,25 +1,20 @@
|
||||
import logging
|
||||
import math
|
||||
import time
|
||||
import traceback
|
||||
import warnings
|
||||
from io import TextIOWrapper
|
||||
from typing import Any, TypeVar
|
||||
from typing import Any, Type, TypeVar
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy import create_engine, event, exc, func, inspect
|
||||
from sqlalchemy.dialects.sqlite import insert
|
||||
from sqlalchemy.engine import Connection, Engine
|
||||
from sqlalchemy.orm import Session, configure_mappers, joinedload
|
||||
from sqlalchemy.orm import DeclarativeBase, Session, configure_mappers
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from orm.base import BaseModel
|
||||
from settings import DB_URL
|
||||
from utils.logger import root_logger as logger
|
||||
|
||||
# Global variables
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Database configuration
|
||||
engine = create_engine(DB_URL, echo=False, poolclass=StaticPool if "sqlite" in DB_URL else None)
|
||||
ENGINE = engine # Backward compatibility alias
|
||||
@@ -64,8 +59,8 @@ def get_statement_from_context(context: Connection) -> str | None:
|
||||
try:
|
||||
# Безопасное форматирование параметров
|
||||
query = compiled_statement % compiled_parameters
|
||||
except Exception as e:
|
||||
logger.exception(f"Error formatting query: {e}")
|
||||
except Exception:
|
||||
logger.exception("Error formatting query")
|
||||
else:
|
||||
query = compiled_statement
|
||||
if query:
|
||||
@@ -130,41 +125,28 @@ def get_json_builder() -> tuple[Any, Any, Any]:
|
||||
# Используем их в коде
|
||||
json_builder, json_array_builder, json_cast = get_json_builder()
|
||||
|
||||
# Fetch all shouts, with authors preloaded
|
||||
# This function is used for search indexing
|
||||
|
||||
|
||||
def fetch_all_shouts(session: Session | None = None) -> list[Any]:
|
||||
"""Fetch all published shouts for search indexing with authors preloaded"""
|
||||
from orm.shout import Shout
|
||||
|
||||
close_session = False
|
||||
if session is None:
|
||||
session = local_session()
|
||||
close_session = True
|
||||
def create_table_if_not_exists(connection_or_engine: Connection | Engine, model_cls: Type[DeclarativeBase]) -> None:
|
||||
"""Creates table for the given model if it doesn't exist"""
|
||||
# If an Engine is passed, get a connection from it
|
||||
connection = connection_or_engine.connect() if isinstance(connection_or_engine, Engine) else connection_or_engine
|
||||
|
||||
try:
|
||||
# Fetch only published and non-deleted shouts with authors preloaded
|
||||
query = (
|
||||
session.query(Shout)
|
||||
.options(joinedload(Shout.authors))
|
||||
.filter(Shout.published_at is not None, Shout.deleted_at is None)
|
||||
)
|
||||
return query.all()
|
||||
except Exception as e:
|
||||
logger.exception(f"Error fetching shouts for search indexing: {e}")
|
||||
return []
|
||||
inspector = inspect(connection)
|
||||
if not inspector.has_table(model_cls.__tablename__):
|
||||
# Use SQLAlchemy's built-in table creation instead of manual SQL generation
|
||||
from sqlalchemy.schema import CreateTable
|
||||
|
||||
create_stmt = CreateTable(model_cls.__table__) # type: ignore[arg-type]
|
||||
connection.execute(create_stmt)
|
||||
logger.info(f"Created table: {model_cls.__tablename__}")
|
||||
finally:
|
||||
if close_session:
|
||||
# Подавляем SQLAlchemy deprecated warning для синхронной сессии
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", DeprecationWarning)
|
||||
session.close()
|
||||
# If we created a connection from an Engine, close it
|
||||
if isinstance(connection_or_engine, Engine):
|
||||
connection.close()
|
||||
|
||||
|
||||
def get_column_names_without_virtual(model_cls: type[BaseModel]) -> list[str]:
|
||||
def get_column_names_without_virtual(model_cls: Type[DeclarativeBase]) -> list[str]:
|
||||
"""Получает имена колонок модели без виртуальных полей"""
|
||||
try:
|
||||
column_names: list[str] = [
|
||||
@@ -175,23 +157,6 @@ def get_column_names_without_virtual(model_cls: type[BaseModel]) -> list[str]:
|
||||
return []
|
||||
|
||||
|
||||
def get_primary_key_columns(model_cls: type[BaseModel]) -> list[str]:
|
||||
"""Получает имена первичных ключей модели"""
|
||||
try:
|
||||
return [col.name for col in model_cls.__table__.primary_key.columns]
|
||||
except AttributeError:
|
||||
return ["id"]
|
||||
|
||||
|
||||
def create_table_if_not_exists(engine: Engine, model_cls: type[BaseModel]) -> None:
|
||||
"""Creates table for the given model if it doesn't exist"""
|
||||
if hasattr(model_cls, "__tablename__"):
|
||||
inspector = inspect(engine)
|
||||
if not inspector.has_table(model_cls.__tablename__):
|
||||
model_cls.__table__.create(engine)
|
||||
logger.info(f"Created table: {model_cls.__tablename__}")
|
||||
|
||||
|
||||
def format_sql_warning(
|
||||
message: str | Warning,
|
||||
category: type[Warning],
|
||||
@@ -207,19 +172,11 @@ def format_sql_warning(
|
||||
# Apply the custom warning formatter
|
||||
def _set_warning_formatter() -> None:
|
||||
"""Set custom warning formatter"""
|
||||
import warnings
|
||||
|
||||
original_formatwarning = warnings.formatwarning
|
||||
|
||||
def custom_formatwarning(
|
||||
message: Warning | str,
|
||||
category: type[Warning],
|
||||
filename: str,
|
||||
lineno: int,
|
||||
file: TextIOWrapper | None = None,
|
||||
line: str | None = None,
|
||||
message: str, category: type[Warning], filename: str, lineno: int, line: str | None = None
|
||||
) -> str:
|
||||
return format_sql_warning(message, category, filename, lineno, file, line)
|
||||
return f"{category.__name__}: {message}\n"
|
||||
|
||||
warnings.formatwarning = custom_formatwarning # type: ignore[assignment]
|
||||
|
||||
|
Reference in New Issue
Block a user