Compare commits
71 Commits
dev
...
feat/sv-se
Author | SHA1 | Date | |
---|---|---|---|
![]() |
e1d1096674 | ||
![]() |
82870a4e47 | ||
![]() |
80b909d801 | ||
![]() |
1ada0a02f9 | ||
![]() |
44aef147b5 | ||
![]() |
2bebfbd4df | ||
![]() |
f19248184a | ||
![]() |
7df9361daa | ||
![]() |
e38a1c1338 | ||
![]() |
1281157d93 | ||
![]() |
0018749905 | ||
![]() |
c344fcee2d | ||
![]() |
a1a61a6731 | ||
![]() |
8d6ad2c84f | ||
![]() |
beba1992e9 | ||
![]() |
b0296d7747 | ||
![]() |
98e3dff35e | ||
![]() |
3782a9dffb | ||
![]() |
93c00b3dd1 | ||
![]() |
fac43e5997 | ||
![]() |
e7facf8d87 | ||
![]() |
3062a2b7de | ||
![]() |
c0406dbbf2 | ||
![]() |
ab4610575f | ||
![]() |
5425dbf832 | ||
![]() |
a10db2d38a | ||
![]() |
83e70856cd | ||
![]() |
11654dba68 | ||
![]() |
ec9465ad40 | ||
![]() |
4d965fb27b | ||
![]() |
e382cc1ea5 | ||
83d61ca76d | |||
![]() |
106222b0e0 | ||
![]() |
c533241d1e | ||
![]() |
78326047bf | ||
![]() |
bc4ec79240 | ||
![]() |
a0db5707c4 | ||
![]() |
ecc443c3ad | ||
![]() |
9a02ca74ad | ||
![]() |
9ebb81cbd3 | ||
![]() |
0bc55977ac | ||
![]() |
ff3a4debce | ||
![]() |
ae85b32f69 | ||
![]() |
34a354e9e3 | ||
![]() |
e405fb527b | ||
![]() |
7f36f93d92 | ||
![]() |
f089a32394 | ||
![]() |
1fd623a660 | ||
![]() |
88012f1b8c | ||
![]() |
6e284640c0 | ||
![]() |
077cb46482 | ||
![]() |
60a13a9097 | ||
![]() |
316375bf18 | ||
![]() |
fb820f67fd | ||
![]() |
f1d9f4e036 | ||
![]() |
ebb67eb311 | ||
![]() |
50a8c24ead | ||
![]() |
eb4b9363ab | ||
![]() |
19c5028a0c | ||
![]() |
57e1e8e6bd | ||
![]() |
385057ffcd | ||
![]() |
90699768ff | ||
![]() |
ad0ca75aa9 | ||
![]() |
39242d5e6c | ||
![]() |
24cca7f2cb | ||
![]() |
a9c7ac49d6 | ||
![]() |
f249752db5 | ||
![]() |
c0b2116da2 | ||
![]() |
59e71c8144 | ||
![]() |
e6a416383d | ||
![]() |
d55448398d |
|
@ -29,7 +29,16 @@ jobs:
|
||||||
if: github.ref == 'refs/heads/dev'
|
if: github.ref == 'refs/heads/dev'
|
||||||
uses: dokku/github-action@master
|
uses: dokku/github-action@master
|
||||||
with:
|
with:
|
||||||
branch: 'dev'
|
branch: 'main'
|
||||||
force: true
|
force: true
|
||||||
git_remote_url: 'ssh://dokku@v2.discours.io:22/core'
|
git_remote_url: 'ssh://dokku@v2.discours.io:22/core'
|
||||||
ssh_private_key: ${{ secrets.SSH_PRIVATE_KEY }}
|
ssh_private_key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||||
|
|
||||||
|
- name: Push to dokku for staging branch
|
||||||
|
if: github.ref == 'refs/heads/staging'
|
||||||
|
uses: dokku/github-action@master
|
||||||
|
with:
|
||||||
|
branch: 'dev'
|
||||||
|
git_remote_url: 'ssh://dokku@staging.discours.io:22/core'
|
||||||
|
ssh_private_key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||||
|
git_push_flags: '--force'
|
6
.gitignore
vendored
6
.gitignore
vendored
|
@ -128,6 +128,9 @@ dmypy.json
|
||||||
.idea
|
.idea
|
||||||
temp.*
|
temp.*
|
||||||
|
|
||||||
|
# Debug
|
||||||
|
DEBUG.log
|
||||||
|
|
||||||
discours.key
|
discours.key
|
||||||
discours.crt
|
discours.crt
|
||||||
discours.pem
|
discours.pem
|
||||||
|
@ -162,5 +165,4 @@ views.json
|
||||||
*.crt
|
*.crt
|
||||||
*cache.json
|
*cache.json
|
||||||
.cursor
|
.cursor
|
||||||
|
.devcontainer/
|
||||||
node_modules/
|
|
||||||
|
|
|
@ -3,6 +3,7 @@ FROM python:slim
|
||||||
RUN apt-get update && apt-get install -y \
|
RUN apt-get update && apt-get install -y \
|
||||||
postgresql-client \
|
postgresql-client \
|
||||||
curl \
|
curl \
|
||||||
|
build-essential \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
1
app/resolvers/draft.py
Normal file
1
app/resolvers/draft.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
|
119
auth/usermodel.py
Normal file
119
auth/usermodel.py
Normal file
|
@ -0,0 +1,119 @@
|
||||||
|
import time
|
||||||
|
|
||||||
|
from sqlalchemy import (
|
||||||
|
JSON,
|
||||||
|
Boolean,
|
||||||
|
Column,
|
||||||
|
DateTime,
|
||||||
|
ForeignKey,
|
||||||
|
Integer,
|
||||||
|
String,
|
||||||
|
func,
|
||||||
|
)
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
from services.db import Base
|
||||||
|
|
||||||
|
|
||||||
|
class Permission(Base):
|
||||||
|
__tablename__ = "permission"
|
||||||
|
|
||||||
|
id = Column(String, primary_key=True, unique=True, nullable=False, default=None)
|
||||||
|
resource = Column(String, nullable=False)
|
||||||
|
operation = Column(String, nullable=False)
|
||||||
|
|
||||||
|
|
||||||
|
class Role(Base):
|
||||||
|
__tablename__ = "role"
|
||||||
|
|
||||||
|
id = Column(String, primary_key=True, unique=True, nullable=False, default=None)
|
||||||
|
name = Column(String, nullable=False)
|
||||||
|
permissions = relationship(Permission)
|
||||||
|
|
||||||
|
|
||||||
|
class AuthorizerUser(Base):
|
||||||
|
__tablename__ = "authorizer_users"
|
||||||
|
|
||||||
|
id = Column(String, primary_key=True, unique=True, nullable=False, default=None)
|
||||||
|
key = Column(String)
|
||||||
|
email = Column(String, unique=True)
|
||||||
|
email_verified_at = Column(Integer)
|
||||||
|
family_name = Column(String)
|
||||||
|
gender = Column(String)
|
||||||
|
given_name = Column(String)
|
||||||
|
is_multi_factor_auth_enabled = Column(Boolean)
|
||||||
|
middle_name = Column(String)
|
||||||
|
nickname = Column(String)
|
||||||
|
password = Column(String)
|
||||||
|
phone_number = Column(String, unique=True)
|
||||||
|
phone_number_verified_at = Column(Integer)
|
||||||
|
# preferred_username = Column(String, nullable=False)
|
||||||
|
picture = Column(String)
|
||||||
|
revoked_timestamp = Column(Integer)
|
||||||
|
roles = Column(String, default="author,reader")
|
||||||
|
signup_methods = Column(String, default="magic_link_login")
|
||||||
|
created_at = Column(Integer, default=lambda: int(time.time()))
|
||||||
|
updated_at = Column(Integer, default=lambda: int(time.time()))
|
||||||
|
|
||||||
|
|
||||||
|
class UserRating(Base):
|
||||||
|
__tablename__ = "user_rating"
|
||||||
|
|
||||||
|
id = None
|
||||||
|
rater: Column = Column(ForeignKey("user.id"), primary_key=True, index=True)
|
||||||
|
user: Column = Column(ForeignKey("user.id"), primary_key=True, index=True)
|
||||||
|
value: Column = Column(Integer)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def init_table():
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class UserRole(Base):
|
||||||
|
__tablename__ = "user_role"
|
||||||
|
|
||||||
|
id = None
|
||||||
|
user = Column(ForeignKey("user.id"), primary_key=True, index=True)
|
||||||
|
role = Column(ForeignKey("role.id"), primary_key=True, index=True)
|
||||||
|
|
||||||
|
|
||||||
|
class User(Base):
|
||||||
|
__tablename__ = "user"
|
||||||
|
default_user = None
|
||||||
|
|
||||||
|
email = Column(String, unique=True, nullable=False, comment="Email")
|
||||||
|
username = Column(String, nullable=False, comment="Login")
|
||||||
|
password = Column(String, nullable=True, comment="Password")
|
||||||
|
bio = Column(String, nullable=True, comment="Bio") # status description
|
||||||
|
about = Column(String, nullable=True, comment="About") # long and formatted
|
||||||
|
userpic = Column(String, nullable=True, comment="Userpic")
|
||||||
|
name = Column(String, nullable=True, comment="Display name")
|
||||||
|
slug = Column(String, unique=True, comment="User's slug")
|
||||||
|
links = Column(JSON, nullable=True, comment="Links")
|
||||||
|
oauth = Column(String, nullable=True)
|
||||||
|
oid = Column(String, nullable=True)
|
||||||
|
|
||||||
|
muted = Column(Boolean, default=False)
|
||||||
|
confirmed = Column(Boolean, default=False)
|
||||||
|
|
||||||
|
created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now(), comment="Created at")
|
||||||
|
updated_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now(), comment="Updated at")
|
||||||
|
last_seen = Column(DateTime(timezone=True), nullable=False, server_default=func.now(), comment="Was online at")
|
||||||
|
deleted_at = Column(DateTime(timezone=True), nullable=True, comment="Deleted at")
|
||||||
|
|
||||||
|
ratings = relationship(UserRating, foreign_keys=UserRating.user)
|
||||||
|
roles = relationship(lambda: Role, secondary=UserRole.__tablename__)
|
||||||
|
|
||||||
|
def get_permission(self):
|
||||||
|
scope = {}
|
||||||
|
for role in self.roles:
|
||||||
|
for p in role.permissions:
|
||||||
|
if p.resource not in scope:
|
||||||
|
scope[p.resource] = set()
|
||||||
|
scope[p.resource].add(p.operation)
|
||||||
|
print(scope)
|
||||||
|
return scope
|
||||||
|
|
||||||
|
|
||||||
|
# if __name__ == "__main__":
|
||||||
|
# print(User.get_permission(user_id=1))
|
12
cache/precache.py
vendored
12
cache/precache.py
vendored
|
@ -77,11 +77,15 @@ async def precache_topics_followers(topic_id: int, session):
|
||||||
|
|
||||||
async def precache_data():
|
async def precache_data():
|
||||||
logger.info("precaching...")
|
logger.info("precaching...")
|
||||||
|
logger.debug("Entering precache_data")
|
||||||
try:
|
try:
|
||||||
key = "authorizer_env"
|
key = "authorizer_env"
|
||||||
|
logger.debug(f"Fetching existing hash for key '{key}' from Redis")
|
||||||
# cache reset
|
# cache reset
|
||||||
value = await redis.execute("HGETALL", key)
|
value = await redis.execute("HGETALL", key)
|
||||||
|
logger.debug(f"Fetched value for '{key}': {value}")
|
||||||
await redis.execute("FLUSHDB")
|
await redis.execute("FLUSHDB")
|
||||||
|
logger.debug("Redis database flushed")
|
||||||
logger.info("redis: FLUSHDB")
|
logger.info("redis: FLUSHDB")
|
||||||
|
|
||||||
# Преобразуем словарь в список аргументов для HSET
|
# Преобразуем словарь в список аргументов для HSET
|
||||||
|
@ -97,21 +101,27 @@ async def precache_data():
|
||||||
await redis.execute("HSET", key, *value)
|
await redis.execute("HSET", key, *value)
|
||||||
logger.info(f"redis hash '{key}' was restored")
|
logger.info(f"redis hash '{key}' was restored")
|
||||||
|
|
||||||
|
logger.info("Beginning topic precache phase")
|
||||||
with local_session() as session:
|
with local_session() as session:
|
||||||
# topics
|
# topics
|
||||||
q = select(Topic).where(Topic.community == 1)
|
q = select(Topic).where(Topic.community == 1)
|
||||||
topics = get_with_stat(q)
|
topics = get_with_stat(q)
|
||||||
|
logger.info(f"Found {len(topics)} topics to precache")
|
||||||
for topic in topics:
|
for topic in topics:
|
||||||
topic_dict = topic.dict() if hasattr(topic, "dict") else topic
|
topic_dict = topic.dict() if hasattr(topic, "dict") else topic
|
||||||
|
logger.debug(f"Precaching topic id={topic_dict.get('id')}")
|
||||||
await cache_topic(topic_dict)
|
await cache_topic(topic_dict)
|
||||||
|
logger.debug(f"Cached topic id={topic_dict.get('id')}")
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
precache_topics_followers(topic_dict["id"], session),
|
precache_topics_followers(topic_dict["id"], session),
|
||||||
precache_topics_authors(topic_dict["id"], session),
|
precache_topics_authors(topic_dict["id"], session),
|
||||||
)
|
)
|
||||||
|
logger.debug(f"Finished precaching followers and authors for topic id={topic_dict.get('id')}")
|
||||||
logger.info(f"{len(topics)} topics and their followings precached")
|
logger.info(f"{len(topics)} topics and their followings precached")
|
||||||
|
|
||||||
# authors
|
# authors
|
||||||
authors = get_with_stat(select(Author).where(Author.user.is_not(None)))
|
authors = get_with_stat(select(Author).where(Author.user.is_not(None)))
|
||||||
|
logger.info(f"Found {len(authors)} authors to precache")
|
||||||
logger.info(f"{len(authors)} authors found in database")
|
logger.info(f"{len(authors)} authors found in database")
|
||||||
for author in authors:
|
for author in authors:
|
||||||
if isinstance(author, Author):
|
if isinstance(author, Author):
|
||||||
|
@ -119,10 +129,12 @@ async def precache_data():
|
||||||
author_id = profile.get("id")
|
author_id = profile.get("id")
|
||||||
user_id = profile.get("user", "").strip()
|
user_id = profile.get("user", "").strip()
|
||||||
if author_id and user_id:
|
if author_id and user_id:
|
||||||
|
logger.debug(f"Precaching author id={author_id}")
|
||||||
await cache_author(profile)
|
await cache_author(profile)
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
precache_authors_followers(author_id, session), precache_authors_follows(author_id, session)
|
precache_authors_followers(author_id, session), precache_authors_follows(author_id, session)
|
||||||
)
|
)
|
||||||
|
logger.debug(f"Finished precaching followers and follows for author id={author_id}")
|
||||||
else:
|
else:
|
||||||
logger.error(f"fail caching {author}")
|
logger.error(f"fail caching {author}")
|
||||||
logger.info(f"{len(authors)} authors and their followings precached")
|
logger.info(f"{len(authors)} authors and their followings precached")
|
||||||
|
|
24
cache/triggers.py
vendored
24
cache/triggers.py
vendored
|
@ -88,11 +88,7 @@ def after_reaction_handler(mapper, connection, target):
|
||||||
with local_session() as session:
|
with local_session() as session:
|
||||||
shout = (
|
shout = (
|
||||||
session.query(Shout)
|
session.query(Shout)
|
||||||
.filter(
|
.filter(Shout.id == shout_id, Shout.published_at.is_not(None), Shout.deleted_at.is_(None))
|
||||||
Shout.id == shout_id,
|
|
||||||
Shout.published_at.is_not(None),
|
|
||||||
Shout.deleted_at.is_(None),
|
|
||||||
)
|
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -112,27 +108,15 @@ def events_register():
|
||||||
|
|
||||||
event.listen(AuthorFollower, "after_insert", after_follower_handler)
|
event.listen(AuthorFollower, "after_insert", after_follower_handler)
|
||||||
event.listen(AuthorFollower, "after_update", after_follower_handler)
|
event.listen(AuthorFollower, "after_update", after_follower_handler)
|
||||||
event.listen(
|
event.listen(AuthorFollower, "after_delete", lambda *args: after_follower_handler(*args, is_delete=True))
|
||||||
AuthorFollower,
|
|
||||||
"after_delete",
|
|
||||||
lambda *args: after_follower_handler(*args, is_delete=True),
|
|
||||||
)
|
|
||||||
|
|
||||||
event.listen(TopicFollower, "after_insert", after_follower_handler)
|
event.listen(TopicFollower, "after_insert", after_follower_handler)
|
||||||
event.listen(TopicFollower, "after_update", after_follower_handler)
|
event.listen(TopicFollower, "after_update", after_follower_handler)
|
||||||
event.listen(
|
event.listen(TopicFollower, "after_delete", lambda *args: after_follower_handler(*args, is_delete=True))
|
||||||
TopicFollower,
|
|
||||||
"after_delete",
|
|
||||||
lambda *args: after_follower_handler(*args, is_delete=True),
|
|
||||||
)
|
|
||||||
|
|
||||||
event.listen(ShoutReactionsFollower, "after_insert", after_follower_handler)
|
event.listen(ShoutReactionsFollower, "after_insert", after_follower_handler)
|
||||||
event.listen(ShoutReactionsFollower, "after_update", after_follower_handler)
|
event.listen(ShoutReactionsFollower, "after_update", after_follower_handler)
|
||||||
event.listen(
|
event.listen(ShoutReactionsFollower, "after_delete", lambda *args: after_follower_handler(*args, is_delete=True))
|
||||||
ShoutReactionsFollower,
|
|
||||||
"after_delete",
|
|
||||||
lambda *args: after_follower_handler(*args, is_delete=True),
|
|
||||||
)
|
|
||||||
|
|
||||||
event.listen(Reaction, "after_update", mark_for_revalidation)
|
event.listen(Reaction, "after_update", mark_for_revalidation)
|
||||||
event.listen(Author, "after_update", mark_for_revalidation)
|
event.listen(Author, "after_update", mark_for_revalidation)
|
||||||
|
|
81
main.py
81
main.py
|
@ -7,7 +7,6 @@ from os.path import exists
|
||||||
from ariadne import load_schema_from_path, make_executable_schema
|
from ariadne import load_schema_from_path, make_executable_schema
|
||||||
from ariadne.asgi import GraphQL
|
from ariadne.asgi import GraphQL
|
||||||
from starlette.applications import Starlette
|
from starlette.applications import Starlette
|
||||||
from starlette.middleware import Middleware
|
|
||||||
from starlette.middleware.cors import CORSMiddleware
|
from starlette.middleware.cors import CORSMiddleware
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import JSONResponse, Response
|
from starlette.responses import JSONResponse, Response
|
||||||
|
@ -18,7 +17,7 @@ from cache.revalidator import revalidation_manager
|
||||||
from services.exception import ExceptionHandlerMiddleware
|
from services.exception import ExceptionHandlerMiddleware
|
||||||
from services.redis import redis
|
from services.redis import redis
|
||||||
from services.schema import create_all_tables, resolvers
|
from services.schema import create_all_tables, resolvers
|
||||||
from services.search import search_service
|
from services.search import search_service, initialize_search_index
|
||||||
from services.viewed import ViewedStorage
|
from services.viewed import ViewedStorage
|
||||||
from services.webhook import WebhookEndpoint, create_webhook_endpoint
|
from services.webhook import WebhookEndpoint, create_webhook_endpoint
|
||||||
from settings import DEV_SERVER_PID_FILE_NAME, MODE
|
from settings import DEV_SERVER_PID_FILE_NAME, MODE
|
||||||
|
@ -35,24 +34,79 @@ async def start():
|
||||||
f.write(str(os.getpid()))
|
f.write(str(os.getpid()))
|
||||||
print(f"[main] process started in {MODE} mode")
|
print(f"[main] process started in {MODE} mode")
|
||||||
|
|
||||||
|
async def check_search_service():
|
||||||
|
"""Check if search service is available and log result"""
|
||||||
|
info = await search_service.info()
|
||||||
|
if info.get("status") in ["error", "unavailable"]:
|
||||||
|
print(f"[WARNING] Search service unavailable: {info.get('message', 'unknown reason')}")
|
||||||
|
else:
|
||||||
|
print(f"[INFO] Search service is available: {info}")
|
||||||
|
|
||||||
|
# Helper to run precache with timeout and catch errors
|
||||||
|
async def precache_with_timeout():
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(precache_data(), timeout=60)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
print("[precache] Precache timed out after 60 seconds")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[precache] Error during precache: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
# indexing DB data
|
||||||
|
# async def indexing():
|
||||||
|
# from services.db import fetch_all_shouts
|
||||||
|
# all_shouts = await fetch_all_shouts()
|
||||||
|
# await initialize_search_index(all_shouts)
|
||||||
async def lifespan(_app):
|
async def lifespan(_app):
|
||||||
try:
|
try:
|
||||||
|
print("[lifespan] Starting application initialization")
|
||||||
create_all_tables()
|
create_all_tables()
|
||||||
|
|
||||||
|
# schedule precaching in background with timeout and error handling
|
||||||
|
asyncio.create_task(precache_with_timeout())
|
||||||
|
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
redis.connect(),
|
redis.connect(),
|
||||||
precache_data(),
|
|
||||||
ViewedStorage.init(),
|
ViewedStorage.init(),
|
||||||
create_webhook_endpoint(),
|
create_webhook_endpoint(),
|
||||||
search_service.info(),
|
check_search_service(),
|
||||||
start(),
|
start(),
|
||||||
revalidation_manager.start(),
|
revalidation_manager.start(),
|
||||||
)
|
)
|
||||||
|
print("[lifespan] Basic initialization complete")
|
||||||
|
|
||||||
|
# Add a delay before starting the intensive search indexing
|
||||||
|
print("[lifespan] Waiting for system stabilization before search indexing...")
|
||||||
|
await asyncio.sleep(10) # 10-second delay to let the system stabilize
|
||||||
|
|
||||||
|
# Start search indexing as a background task with lower priority
|
||||||
|
asyncio.create_task(initialize_search_index_background())
|
||||||
|
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
|
print("[lifespan] Shutting down application services")
|
||||||
tasks = [redis.disconnect(), ViewedStorage.stop(), revalidation_manager.stop()]
|
tasks = [redis.disconnect(), ViewedStorage.stop(), revalidation_manager.stop()]
|
||||||
await asyncio.gather(*tasks, return_exceptions=True)
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
print("[lifespan] Shutdown complete")
|
||||||
|
|
||||||
|
# Initialize search index in the background
|
||||||
|
async def initialize_search_index_background():
|
||||||
|
"""Run search indexing as a background task with low priority"""
|
||||||
|
try:
|
||||||
|
print("[search] Starting background search indexing process")
|
||||||
|
from services.db import fetch_all_shouts
|
||||||
|
|
||||||
|
# Get total count first (optional)
|
||||||
|
all_shouts = await fetch_all_shouts()
|
||||||
|
total_count = len(all_shouts) if all_shouts else 0
|
||||||
|
print(f"[search] Fetched {total_count} shouts for background indexing")
|
||||||
|
|
||||||
|
# Start the indexing process with the fetched shouts
|
||||||
|
print("[search] Beginning background search index initialization...")
|
||||||
|
await initialize_search_index(all_shouts)
|
||||||
|
print("[search] Background search index initialization complete")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[search] Error in background search indexing: {str(e)}")
|
||||||
|
|
||||||
# Создаем экземпляр GraphQL
|
# Создаем экземпляр GraphQL
|
||||||
graphql_app = GraphQL(schema, debug=True)
|
graphql_app = GraphQL(schema, debug=True)
|
||||||
|
@ -74,24 +128,6 @@ async def graphql_handler(request: Request):
|
||||||
print(f"GraphQL error: {str(e)}")
|
print(f"GraphQL error: {str(e)}")
|
||||||
return JSONResponse({"error": str(e)}, status_code=500)
|
return JSONResponse({"error": str(e)}, status_code=500)
|
||||||
|
|
||||||
middleware = [
|
|
||||||
# Начинаем с обработки ошибок
|
|
||||||
Middleware(ExceptionHandlerMiddleware),
|
|
||||||
# CORS должен быть перед другими middleware для корректной обработки preflight-запросов
|
|
||||||
Middleware(
|
|
||||||
CORSMiddleware,
|
|
||||||
allow_origins=[
|
|
||||||
"https://localhost:3000",
|
|
||||||
"https://testing.discours.io",
|
|
||||||
"https://testing3.discours.io",
|
|
||||||
"https://discours.io",
|
|
||||||
"https://new.discours.io"
|
|
||||||
],
|
|
||||||
allow_methods=["GET", "POST", "OPTIONS"], # Явно указываем OPTIONS
|
|
||||||
allow_headers=["*"],
|
|
||||||
allow_credentials=True,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
# Обновляем маршрут в Starlette
|
# Обновляем маршрут в Starlette
|
||||||
app = Starlette(
|
app = Starlette(
|
||||||
|
@ -99,7 +135,6 @@ app = Starlette(
|
||||||
Route("/", graphql_handler, methods=["GET", "POST"]),
|
Route("/", graphql_handler, methods=["GET", "POST"]),
|
||||||
Route("/new-author", WebhookEndpoint),
|
Route("/new-author", WebhookEndpoint),
|
||||||
],
|
],
|
||||||
middleware=middleware,
|
|
||||||
lifespan=lifespan,
|
lifespan=lifespan,
|
||||||
debug=True,
|
debug=True,
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
log_format custom '$remote_addr - $remote_user [$time_local] "$request" '
|
log_format custom '$remote_addr - $remote_user [$time_local] "$request" '
|
||||||
'origin=$http_origin status=$status '
|
'origin=$http_origin allow_origin=$allow_origin status=$status '
|
||||||
'"$http_referer" "$http_user_agent"';
|
'"$http_referer" "$http_user_agent"';
|
||||||
|
|
||||||
{{ $proxy_settings := "proxy_http_version 1.1; proxy_set_header Upgrade $http_upgrade; proxy_set_header Connection $http_connection; proxy_set_header Host $http_host; proxy_set_header X-Request-Start $msec;" }}
|
{{ $proxy_settings := "proxy_http_version 1.1; proxy_set_header Upgrade $http_upgrade; proxy_set_header Connection $http_connection; proxy_set_header Host $http_host; proxy_set_header X-Request-Start $msec;" }}
|
||||||
|
@ -49,6 +49,34 @@ server {
|
||||||
{{ $proxy_settings }}
|
{{ $proxy_settings }}
|
||||||
{{ $gzip_settings }}
|
{{ $gzip_settings }}
|
||||||
|
|
||||||
|
# Handle CORS for OPTIONS method
|
||||||
|
if ($request_method = 'OPTIONS') {
|
||||||
|
add_header 'Access-Control-Allow-Origin' $allow_origin always;
|
||||||
|
add_header 'Access-Control-Allow-Methods' 'POST, GET, OPTIONS';
|
||||||
|
add_header 'Access-Control-Allow-Headers' 'Content-Type, Authorization' always;
|
||||||
|
add_header 'Access-Control-Allow-Credentials' 'true' always;
|
||||||
|
add_header 'Access-Control-Max-Age' 1728000;
|
||||||
|
add_header 'Content-Type' 'text/plain; charset=utf-8';
|
||||||
|
add_header 'Content-Length' 0;
|
||||||
|
return 204;
|
||||||
|
}
|
||||||
|
|
||||||
|
# Handle CORS for POST method
|
||||||
|
if ($request_method = 'POST') {
|
||||||
|
add_header 'Access-Control-Allow-Origin' $allow_origin always;
|
||||||
|
add_header 'Access-Control-Allow-Methods' 'POST, GET, OPTIONS' always;
|
||||||
|
add_header 'Access-Control-Allow-Headers' 'Content-Type, Authorization' always;
|
||||||
|
add_header 'Access-Control-Allow-Credentials' 'true' always;
|
||||||
|
}
|
||||||
|
|
||||||
|
# Handle CORS for GET method
|
||||||
|
if ($request_method = 'GET') {
|
||||||
|
add_header 'Access-Control-Allow-Origin' $allow_origin always;
|
||||||
|
add_header 'Access-Control-Allow-Methods' 'POST, GET, OPTIONS' always;
|
||||||
|
add_header 'Access-Control-Allow-Headers' 'Content-Type, Authorization' always;
|
||||||
|
add_header 'Access-Control-Allow-Credentials' 'true' always;
|
||||||
|
}
|
||||||
|
|
||||||
proxy_cache my_cache;
|
proxy_cache my_cache;
|
||||||
proxy_cache_revalidate on;
|
proxy_cache_revalidate on;
|
||||||
proxy_cache_min_uses 2;
|
proxy_cache_min_uses 2;
|
||||||
|
@ -68,6 +96,13 @@ server {
|
||||||
|
|
||||||
location ~* \.(mp3|wav|ogg|flac|aac|aif|webm)$ {
|
location ~* \.(mp3|wav|ogg|flac|aac|aif|webm)$ {
|
||||||
proxy_pass http://{{ $.APP }}-{{ $upstream_port }};
|
proxy_pass http://{{ $.APP }}-{{ $upstream_port }};
|
||||||
|
if ($request_method = 'GET') {
|
||||||
|
add_header 'Access-Control-Allow-Origin' $allow_origin always;
|
||||||
|
add_header 'Access-Control-Allow-Methods' 'GET, POST, OPTIONS' always;
|
||||||
|
add_header 'Access-Control-Allow-Headers' 'DNT,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Range,Authorization' always;
|
||||||
|
add_header 'Access-Control-Expose-Headers' 'Content-Length,Content-Range' always;
|
||||||
|
add_header 'Access-Control-Allow-Credentials' 'true' always;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -90,6 +90,7 @@ class Author(Base):
|
||||||
Модель автора в системе.
|
Модель автора в системе.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
|
user (str): Идентификатор пользователя в системе авторизации
|
||||||
name (str): Отображаемое имя
|
name (str): Отображаемое имя
|
||||||
slug (str): Уникальный строковый идентификатор
|
slug (str): Уникальный строковый идентификатор
|
||||||
bio (str): Краткая биография/статус
|
bio (str): Краткая биография/статус
|
||||||
|
@ -104,6 +105,8 @@ class Author(Base):
|
||||||
|
|
||||||
__tablename__ = "author"
|
__tablename__ = "author"
|
||||||
|
|
||||||
|
user = Column(String) # unbounded link with authorizer's User type
|
||||||
|
|
||||||
name = Column(String, nullable=True, comment="Display name")
|
name = Column(String, nullable=True, comment="Display name")
|
||||||
slug = Column(String, unique=True, comment="Author's slug")
|
slug = Column(String, unique=True, comment="Author's slug")
|
||||||
bio = Column(String, nullable=True, comment="Bio") # status description
|
bio = Column(String, nullable=True, comment="Bio") # status description
|
||||||
|
@ -121,14 +124,12 @@ class Author(Base):
|
||||||
|
|
||||||
# Определяем индексы
|
# Определяем индексы
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
# Индекс для быстрого поиска по имени
|
|
||||||
Index("idx_author_name", "name"),
|
|
||||||
# Индекс для быстрого поиска по slug
|
# Индекс для быстрого поиска по slug
|
||||||
Index("idx_author_slug", "slug"),
|
Index("idx_author_slug", "slug"),
|
||||||
|
# Индекс для быстрого поиска по идентификатору пользователя
|
||||||
|
Index("idx_author_user", "user"),
|
||||||
# Индекс для фильтрации неудаленных авторов
|
# Индекс для фильтрации неудаленных авторов
|
||||||
Index(
|
Index("idx_author_deleted_at", "deleted_at", postgresql_where=deleted_at.is_(None)),
|
||||||
"idx_author_deleted_at", "deleted_at", postgresql_where=deleted_at.is_(None)
|
|
||||||
),
|
|
||||||
# Индекс для сортировки по времени создания (для новых авторов)
|
# Индекс для сортировки по времени создания (для новых авторов)
|
||||||
Index("idx_author_created_at", "created_at"),
|
Index("idx_author_created_at", "created_at"),
|
||||||
# Индекс для сортировки по времени последнего посещения
|
# Индекс для сортировки по времени последнего посещения
|
||||||
|
|
|
@ -6,6 +6,7 @@ from sqlalchemy.orm import relationship
|
||||||
from orm.author import Author
|
from orm.author import Author
|
||||||
from orm.topic import Topic
|
from orm.topic import Topic
|
||||||
from services.db import Base
|
from services.db import Base
|
||||||
|
from orm.shout import Shout
|
||||||
|
|
||||||
|
|
||||||
class DraftTopic(Base):
|
class DraftTopic(Base):
|
||||||
|
|
28
orm/shout.py
28
orm/shout.py
|
@ -71,6 +71,34 @@ class ShoutAuthor(Base):
|
||||||
class Shout(Base):
|
class Shout(Base):
|
||||||
"""
|
"""
|
||||||
Публикация в системе.
|
Публикация в системе.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
body (str)
|
||||||
|
slug (str)
|
||||||
|
cover (str) : "Cover image url"
|
||||||
|
cover_caption (str) : "Cover image alt caption"
|
||||||
|
lead (str)
|
||||||
|
title (str)
|
||||||
|
subtitle (str)
|
||||||
|
layout (str)
|
||||||
|
media (dict)
|
||||||
|
authors (list[Author])
|
||||||
|
topics (list[Topic])
|
||||||
|
reactions (list[Reaction])
|
||||||
|
lang (str)
|
||||||
|
version_of (int)
|
||||||
|
oid (str)
|
||||||
|
seo (str) : JSON
|
||||||
|
draft (int)
|
||||||
|
created_at (int)
|
||||||
|
updated_at (int)
|
||||||
|
published_at (int)
|
||||||
|
featured_at (int)
|
||||||
|
deleted_at (int)
|
||||||
|
created_by (int)
|
||||||
|
updated_by (int)
|
||||||
|
deleted_by (int)
|
||||||
|
community (int)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__tablename__ = "shout"
|
__tablename__ = "shout"
|
||||||
|
|
|
@ -1,2 +0,0 @@
|
||||||
[tool.ruff]
|
|
||||||
line-length = 108
|
|
|
@ -13,6 +13,10 @@ starlette
|
||||||
gql
|
gql
|
||||||
ariadne
|
ariadne
|
||||||
granian
|
granian
|
||||||
|
|
||||||
|
# NLP and search
|
||||||
|
httpx
|
||||||
|
|
||||||
orjson
|
orjson
|
||||||
pydantic
|
pydantic
|
||||||
trafilatura
|
trafilatura
|
|
@ -8,6 +8,7 @@ from resolvers.author import ( # search_authors,
|
||||||
get_author_id,
|
get_author_id,
|
||||||
get_authors_all,
|
get_authors_all,
|
||||||
load_authors_by,
|
load_authors_by,
|
||||||
|
load_authors_search,
|
||||||
update_author,
|
update_author,
|
||||||
)
|
)
|
||||||
from resolvers.community import get_communities_all, get_community
|
from resolvers.community import get_communities_all, get_community
|
||||||
|
@ -73,6 +74,7 @@ __all__ = [
|
||||||
"get_author_follows_authors",
|
"get_author_follows_authors",
|
||||||
"get_authors_all",
|
"get_authors_all",
|
||||||
"load_authors_by",
|
"load_authors_by",
|
||||||
|
"load_authors_search",
|
||||||
"update_author",
|
"update_author",
|
||||||
## "search_authors",
|
## "search_authors",
|
||||||
# community
|
# community
|
||||||
|
|
|
@ -20,6 +20,7 @@ from services.auth import login_required
|
||||||
from services.db import local_session
|
from services.db import local_session
|
||||||
from services.redis import redis
|
from services.redis import redis
|
||||||
from services.schema import mutation, query
|
from services.schema import mutation, query
|
||||||
|
from services.search import search_service
|
||||||
from utils.logger import root_logger as logger
|
from utils.logger import root_logger as logger
|
||||||
|
|
||||||
DEFAULT_COMMUNITIES = [1]
|
DEFAULT_COMMUNITIES = [1]
|
||||||
|
@ -301,6 +302,46 @@ async def load_authors_by(_, _info, by, limit, offset):
|
||||||
return await get_authors_with_stats(limit, offset, by)
|
return await get_authors_with_stats(limit, offset, by)
|
||||||
|
|
||||||
|
|
||||||
|
@query.field("load_authors_search")
|
||||||
|
async def load_authors_search(_, info, text: str, limit: int = 10, offset: int = 0):
|
||||||
|
"""
|
||||||
|
Resolver for searching authors by text. Works with txt-ai search endpony.
|
||||||
|
Args:
|
||||||
|
text: Search text
|
||||||
|
limit: Maximum number of authors to return
|
||||||
|
offset: Offset for pagination
|
||||||
|
Returns:
|
||||||
|
list: List of authors matching the search criteria
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Get author IDs from search engine (already sorted by relevance)
|
||||||
|
search_results = await search_service.search_authors(text, limit, offset)
|
||||||
|
|
||||||
|
if not search_results:
|
||||||
|
return []
|
||||||
|
|
||||||
|
author_ids = [result.get("id") for result in search_results if result.get("id")]
|
||||||
|
if not author_ids:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Fetch full author objects from DB
|
||||||
|
with local_session() as session:
|
||||||
|
# Simple query to get authors by IDs - no need for stats here
|
||||||
|
authors_query = select(Author).filter(Author.id.in_(author_ids))
|
||||||
|
db_authors = session.execute(authors_query).scalars().all()
|
||||||
|
|
||||||
|
if not db_authors:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Create a dictionary for quick lookup
|
||||||
|
authors_dict = {str(author.id): author for author in db_authors}
|
||||||
|
|
||||||
|
# Keep the order from search results (maintains the relevance sorting)
|
||||||
|
ordered_authors = [authors_dict[author_id] for author_id in author_ids if author_id in authors_dict]
|
||||||
|
|
||||||
|
return ordered_authors
|
||||||
|
|
||||||
|
|
||||||
def get_author_id_from(slug="", user=None, author_id=None):
|
def get_author_id_from(slug="", user=None, author_id=None):
|
||||||
try:
|
try:
|
||||||
author_id = None
|
author_id = None
|
||||||
|
|
25
resolvers/pyrightconfig.json
Normal file
25
resolvers/pyrightconfig.json
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
{
|
||||||
|
"include": [
|
||||||
|
"."
|
||||||
|
],
|
||||||
|
"exclude": [
|
||||||
|
"**/node_modules",
|
||||||
|
"**/__pycache__",
|
||||||
|
"**/.*"
|
||||||
|
],
|
||||||
|
"defineConstant": {
|
||||||
|
"DEBUG": true
|
||||||
|
},
|
||||||
|
"venvPath": ".",
|
||||||
|
"venv": ".venv",
|
||||||
|
"pythonVersion": "3.11",
|
||||||
|
"typeCheckingMode": "strict",
|
||||||
|
"reportMissingImports": true,
|
||||||
|
"reportMissingTypeStubs": false,
|
||||||
|
"reportUnknownMemberType": false,
|
||||||
|
"reportUnknownParameterType": false,
|
||||||
|
"reportUnknownVariableType": false,
|
||||||
|
"reportUnknownArgumentType": false,
|
||||||
|
"reportPrivateUsage": false,
|
||||||
|
"reportUntypedFunctionDecorator": false
|
||||||
|
}
|
|
@ -10,7 +10,7 @@ from orm.shout import Shout, ShoutAuthor, ShoutTopic
|
||||||
from orm.topic import Topic
|
from orm.topic import Topic
|
||||||
from services.db import json_array_builder, json_builder, local_session
|
from services.db import json_array_builder, json_builder, local_session
|
||||||
from services.schema import query
|
from services.schema import query
|
||||||
from services.search import search_text
|
from services.search import search_text, get_search_count
|
||||||
from services.viewed import ViewedStorage
|
from services.viewed import ViewedStorage
|
||||||
from utils.logger import root_logger as logger
|
from utils.logger import root_logger as logger
|
||||||
|
|
||||||
|
@ -187,12 +187,10 @@ def get_shouts_with_links(info, q, limit=20, offset=0):
|
||||||
"""
|
"""
|
||||||
shouts = []
|
shouts = []
|
||||||
try:
|
try:
|
||||||
# logger.info(f"Starting get_shouts_with_links with limit={limit}, offset={offset}")
|
|
||||||
q = q.limit(limit).offset(offset)
|
q = q.limit(limit).offset(offset)
|
||||||
|
|
||||||
with local_session() as session:
|
with local_session() as session:
|
||||||
shouts_result = session.execute(q).all()
|
shouts_result = session.execute(q).all()
|
||||||
# logger.info(f"Got {len(shouts_result) if shouts_result else 0} shouts from query")
|
|
||||||
|
|
||||||
if not shouts_result:
|
if not shouts_result:
|
||||||
logger.warning("No shouts found in query result")
|
logger.warning("No shouts found in query result")
|
||||||
|
@ -203,7 +201,6 @@ def get_shouts_with_links(info, q, limit=20, offset=0):
|
||||||
shout = None
|
shout = None
|
||||||
if hasattr(row, "Shout"):
|
if hasattr(row, "Shout"):
|
||||||
shout = row.Shout
|
shout = row.Shout
|
||||||
# logger.debug(f"Processing shout#{shout.id} at index {idx}")
|
|
||||||
if shout:
|
if shout:
|
||||||
shout_id = int(f"{shout.id}")
|
shout_id = int(f"{shout.id}")
|
||||||
shout_dict = shout.dict()
|
shout_dict = shout.dict()
|
||||||
|
@ -231,20 +228,16 @@ def get_shouts_with_links(info, q, limit=20, offset=0):
|
||||||
topics = None
|
topics = None
|
||||||
if has_field(info, "topics") and hasattr(row, "topics"):
|
if has_field(info, "topics") and hasattr(row, "topics"):
|
||||||
topics = orjson.loads(row.topics) if isinstance(row.topics, str) else row.topics
|
topics = orjson.loads(row.topics) if isinstance(row.topics, str) else row.topics
|
||||||
# logger.debug(f"Shout#{shout_id} topics: {topics}")
|
|
||||||
shout_dict["topics"] = topics
|
shout_dict["topics"] = topics
|
||||||
|
|
||||||
if has_field(info, "main_topic"):
|
if has_field(info, "main_topic"):
|
||||||
main_topic = None
|
main_topic = None
|
||||||
if hasattr(row, "main_topic"):
|
if hasattr(row, "main_topic"):
|
||||||
# logger.debug(f"Raw main_topic for shout#{shout_id}: {row.main_topic}")
|
|
||||||
main_topic = (
|
main_topic = (
|
||||||
orjson.loads(row.main_topic) if isinstance(row.main_topic, str) else row.main_topic
|
orjson.loads(row.main_topic) if isinstance(row.main_topic, str) else row.main_topic
|
||||||
)
|
)
|
||||||
# logger.debug(f"Parsed main_topic for shout#{shout_id}: {main_topic}")
|
|
||||||
|
|
||||||
if not main_topic and topics and len(topics) > 0:
|
if not main_topic and topics and len(topics) > 0:
|
||||||
# logger.info(f"No main_topic found for shout#{shout_id}, using first topic from list")
|
|
||||||
main_topic = {
|
main_topic = {
|
||||||
"id": topics[0]["id"],
|
"id": topics[0]["id"],
|
||||||
"title": topics[0]["title"],
|
"title": topics[0]["title"],
|
||||||
|
@ -252,10 +245,8 @@ def get_shouts_with_links(info, q, limit=20, offset=0):
|
||||||
"is_main": True,
|
"is_main": True,
|
||||||
}
|
}
|
||||||
elif not main_topic:
|
elif not main_topic:
|
||||||
logger.warning(f"No main_topic and no topics found for shout#{shout_id}")
|
|
||||||
main_topic = {"id": 0, "title": "no topic", "slug": "notopic", "is_main": True}
|
main_topic = {"id": 0, "title": "no topic", "slug": "notopic", "is_main": True}
|
||||||
shout_dict["main_topic"] = main_topic
|
shout_dict["main_topic"] = main_topic
|
||||||
# logger.debug(f"Final main_topic for shout#{shout_id}: {main_topic}")
|
|
||||||
|
|
||||||
if has_field(info, "authors") and hasattr(row, "authors"):
|
if has_field(info, "authors") and hasattr(row, "authors"):
|
||||||
shout_dict["authors"] = (
|
shout_dict["authors"] = (
|
||||||
|
@ -282,7 +273,6 @@ def get_shouts_with_links(info, q, limit=20, offset=0):
|
||||||
logger.error(f"Fatal error in get_shouts_with_links: {e}", exc_info=True)
|
logger.error(f"Fatal error in get_shouts_with_links: {e}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
logger.info(f"Returning {len(shouts)} shouts from get_shouts_with_links")
|
|
||||||
return shouts
|
return shouts
|
||||||
|
|
||||||
|
|
||||||
|
@ -401,32 +391,48 @@ async def load_shouts_search(_, info, text, options):
|
||||||
"""
|
"""
|
||||||
limit = options.get("limit", 10)
|
limit = options.get("limit", 10)
|
||||||
offset = options.get("offset", 0)
|
offset = options.get("offset", 0)
|
||||||
if isinstance(text, str) and len(text) > 2:
|
|
||||||
results = await search_text(text, limit, offset)
|
|
||||||
scores = {}
|
|
||||||
hits_ids = []
|
|
||||||
for sr in results:
|
|
||||||
shout_id = sr.get("id")
|
|
||||||
if shout_id:
|
|
||||||
shout_id = str(shout_id)
|
|
||||||
scores[shout_id] = sr.get("score")
|
|
||||||
hits_ids.append(shout_id)
|
|
||||||
|
|
||||||
q = (
|
if isinstance(text, str) and len(text) > 2:
|
||||||
query_with_stat(info)
|
# Get search results with pagination
|
||||||
if has_field(info, "stat")
|
results = await search_text(text, limit, offset)
|
||||||
else select(Shout).filter(and_(Shout.published_at.is_not(None), Shout.deleted_at.is_(None)))
|
|
||||||
)
|
if not results:
|
||||||
q = q.filter(Shout.id.in_(hits_ids))
|
logger.info(f"No search results found for '{text}'")
|
||||||
q = apply_filters(q, options)
|
|
||||||
q = apply_sorting(q, options)
|
|
||||||
shouts = get_shouts_with_links(info, q, limit, offset)
|
|
||||||
for shout in shouts:
|
|
||||||
shout.score = scores[f"{shout.id}"]
|
|
||||||
shouts.sort(key=lambda x: x.score, reverse=True)
|
|
||||||
return shouts
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
# Extract IDs in the order from the search engine
|
||||||
|
hits_ids = [str(sr.get("id")) for sr in results if sr.get("id")]
|
||||||
|
|
||||||
|
# Query DB for only the IDs in the current page
|
||||||
|
q = query_with_stat(info)
|
||||||
|
q = q.filter(Shout.id.in_(hits_ids))
|
||||||
|
q = apply_filters(q, options.get("filters", {}))
|
||||||
|
|
||||||
|
shouts = get_shouts_with_links(info, q, len(hits_ids), 0)
|
||||||
|
|
||||||
|
# Reorder shouts to match the order from hits_ids
|
||||||
|
shouts_dict = {str(shout['id']): shout for shout in shouts}
|
||||||
|
ordered_shouts = [shouts_dict[shout_id] for shout_id in hits_ids if shout_id in shouts_dict]
|
||||||
|
|
||||||
|
return ordered_shouts
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
@query.field("get_search_results_count")
|
||||||
|
async def get_search_results_count(_, info, text):
|
||||||
|
"""
|
||||||
|
Returns the total count of search results for a search query.
|
||||||
|
|
||||||
|
:param _: Root query object (unused)
|
||||||
|
:param info: GraphQL context information
|
||||||
|
:param text: Search query text
|
||||||
|
:return: Total count of results
|
||||||
|
"""
|
||||||
|
if isinstance(text, str) and len(text) > 2:
|
||||||
|
count = await get_search_count(text)
|
||||||
|
return {"count": count}
|
||||||
|
return {"count": 0}
|
||||||
|
|
||||||
|
|
||||||
@query.field("load_shouts_unrated")
|
@query.field("load_shouts_unrated")
|
||||||
async def load_shouts_unrated(_, info, options):
|
async def load_shouts_unrated(_, info, options):
|
||||||
|
|
|
@ -4,7 +4,7 @@ type Query {
|
||||||
get_author_id(user: String!): Author
|
get_author_id(user: String!): Author
|
||||||
get_authors_all: [Author]
|
get_authors_all: [Author]
|
||||||
load_authors_by(by: AuthorsBy!, limit: Int, offset: Int): [Author]
|
load_authors_by(by: AuthorsBy!, limit: Int, offset: Int): [Author]
|
||||||
# search_authors(what: String!): [Author]
|
load_authors_search(text: String!, limit: Int, offset: Int): [Author!] # Search for authors by name or bio
|
||||||
|
|
||||||
# community
|
# community
|
||||||
get_community: Community
|
get_community: Community
|
||||||
|
@ -33,6 +33,7 @@ type Query {
|
||||||
get_shout(slug: String, shout_id: Int): Shout
|
get_shout(slug: String, shout_id: Int): Shout
|
||||||
load_shouts_by(options: LoadShoutsOptions): [Shout]
|
load_shouts_by(options: LoadShoutsOptions): [Shout]
|
||||||
load_shouts_search(text: String!, options: LoadShoutsOptions): [SearchResult]
|
load_shouts_search(text: String!, options: LoadShoutsOptions): [SearchResult]
|
||||||
|
get_search_results_count(text: String!): CountResult!
|
||||||
load_shouts_bookmarked(options: LoadShoutsOptions): [Shout]
|
load_shouts_bookmarked(options: LoadShoutsOptions): [Shout]
|
||||||
|
|
||||||
# rating
|
# rating
|
||||||
|
|
|
@ -213,6 +213,7 @@ type CommonResult {
|
||||||
}
|
}
|
||||||
|
|
||||||
type SearchResult {
|
type SearchResult {
|
||||||
|
id: Int!
|
||||||
slug: String!
|
slug: String!
|
||||||
title: String!
|
title: String!
|
||||||
cover: String
|
cover: String
|
||||||
|
@ -280,3 +281,7 @@ type MyRateComment {
|
||||||
my_rate: ReactionKind
|
my_rate: ReactionKind
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type CountResult {
|
||||||
|
count: Int!
|
||||||
|
}
|
||||||
|
|
||||||
|
|
|
@ -1 +0,0 @@
|
||||||
# This file makes services a Python package
|
|
|
@ -19,7 +19,7 @@ from sqlalchemy import (
|
||||||
inspect,
|
inspect,
|
||||||
text,
|
text,
|
||||||
)
|
)
|
||||||
from sqlalchemy.orm import Session, configure_mappers, declarative_base
|
from sqlalchemy.orm import Session, configure_mappers, declarative_base, joinedload
|
||||||
from sqlalchemy.sql.schema import Table
|
from sqlalchemy.sql.schema import Table
|
||||||
|
|
||||||
from settings import DB_URL
|
from settings import DB_URL
|
||||||
|
@ -259,3 +259,32 @@ def get_json_builder():
|
||||||
|
|
||||||
# Используем их в коде
|
# Используем их в коде
|
||||||
json_builder, json_array_builder, json_cast = get_json_builder()
|
json_builder, json_array_builder, json_cast = get_json_builder()
|
||||||
|
|
||||||
|
# Fetch all shouts, with authors preloaded
|
||||||
|
# This function is used for search indexing
|
||||||
|
|
||||||
|
async def fetch_all_shouts(session=None):
|
||||||
|
"""Fetch all published shouts for search indexing with authors preloaded"""
|
||||||
|
from orm.shout import Shout
|
||||||
|
|
||||||
|
close_session = False
|
||||||
|
if session is None:
|
||||||
|
session = local_session()
|
||||||
|
close_session = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Fetch only published and non-deleted shouts with authors preloaded
|
||||||
|
query = session.query(Shout).options(
|
||||||
|
joinedload(Shout.authors)
|
||||||
|
).filter(
|
||||||
|
Shout.published_at.is_not(None),
|
||||||
|
Shout.deleted_at.is_(None)
|
||||||
|
)
|
||||||
|
shouts = query.all()
|
||||||
|
return shouts
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching shouts for search indexing: {e}")
|
||||||
|
return []
|
||||||
|
finally:
|
||||||
|
if close_session:
|
||||||
|
session.close()
|
|
@ -88,7 +88,7 @@ async def notify_draft(draft_data, action: str = "publish"):
|
||||||
"subtitle": getattr(draft_data, "subtitle", None),
|
"subtitle": getattr(draft_data, "subtitle", None),
|
||||||
"media": getattr(draft_data, "media", None),
|
"media": getattr(draft_data, "media", None),
|
||||||
"created_at": getattr(draft_data, "created_at", None),
|
"created_at": getattr(draft_data, "created_at", None),
|
||||||
"updated_at": getattr(draft_data, "updated_at", None),
|
"updated_at": getattr(draft_data, "updated_at", None)
|
||||||
}
|
}
|
||||||
|
|
||||||
# Если переданы связанные атрибуты, добавим их
|
# Если переданы связанные атрибуты, добавим их
|
||||||
|
@ -100,12 +100,7 @@ async def notify_draft(draft_data, action: str = "publish"):
|
||||||
|
|
||||||
if hasattr(draft_data, "authors") and draft_data.authors is not None:
|
if hasattr(draft_data, "authors") and draft_data.authors is not None:
|
||||||
draft_payload["authors"] = [
|
draft_payload["authors"] = [
|
||||||
{
|
{"id": a.id, "name": a.name, "slug": a.slug, "pic": getattr(a, "pic", None)}
|
||||||
"id": a.id,
|
|
||||||
"name": a.name,
|
|
||||||
"slug": a.slug,
|
|
||||||
"pic": getattr(a, "pic", None),
|
|
||||||
}
|
|
||||||
for a in draft_data.authors
|
for a in draft_data.authors
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
@ -1,19 +0,0 @@
|
||||||
{
|
|
||||||
"include": ["."],
|
|
||||||
"exclude": ["**/node_modules", "**/__pycache__", "**/.*"],
|
|
||||||
"defineConstant": {
|
|
||||||
"DEBUG": true
|
|
||||||
},
|
|
||||||
"venvPath": ".",
|
|
||||||
"venv": ".venv",
|
|
||||||
"pythonVersion": "3.11",
|
|
||||||
"typeCheckingMode": "strict",
|
|
||||||
"reportMissingImports": true,
|
|
||||||
"reportMissingTypeStubs": false,
|
|
||||||
"reportUnknownMemberType": false,
|
|
||||||
"reportUnknownParameterType": false,
|
|
||||||
"reportUnknownVariableType": false,
|
|
||||||
"reportUnknownArgumentType": false,
|
|
||||||
"reportPrivateUsage": false,
|
|
||||||
"reportUntypedFunctionDecorator": false
|
|
||||||
}
|
|
|
@ -29,12 +29,19 @@ async def request_graphql_data(gql, url=AUTH_URL, headers=None):
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
response = await client.post(url, json=gql, headers=headers)
|
response = await client.post(url, json=gql, headers=headers)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
|
# Check if the response has content before parsing
|
||||||
|
if response.content and len(response.content.strip()) > 0:
|
||||||
|
try:
|
||||||
data = response.json()
|
data = response.json()
|
||||||
errors = data.get("errors")
|
errors = data.get("errors")
|
||||||
if errors:
|
if errors:
|
||||||
logger.error(f"{url} response: {data}")
|
logger.error(f"{url} response: {data}")
|
||||||
else:
|
else:
|
||||||
return data
|
return data
|
||||||
|
except Exception as json_err:
|
||||||
|
logger.error(f"JSON decode error: {json_err}, Response content: {response.text[:100]}")
|
||||||
|
else:
|
||||||
|
logger.error(f"{url}: Response is empty")
|
||||||
else:
|
else:
|
||||||
logger.error(f"{url}: {response.status_code} {response.text}")
|
logger.error(f"{url}: {response.status_code} {response.text}")
|
||||||
except Exception as _e:
|
except Exception as _e:
|
||||||
|
|
1070
services/search.py
1070
services/search.py
File diff suppressed because it is too large
Load Diff
|
@ -1,25 +0,0 @@
|
||||||
import pytest
|
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def oauth_settings() -> Dict[str, Dict[str, str]]:
|
|
||||||
"""Тестовые настройки OAuth"""
|
|
||||||
return {
|
|
||||||
"GOOGLE": {"id": "test_google_id", "key": "test_google_secret"},
|
|
||||||
"GITHUB": {"id": "test_github_id", "key": "test_github_secret"},
|
|
||||||
"FACEBOOK": {"id": "test_facebook_id", "key": "test_facebook_secret"},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def frontend_url() -> str:
|
|
||||||
"""URL фронтенда для тестов"""
|
|
||||||
return "https://localhost:3000"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def mock_settings(monkeypatch, oauth_settings, frontend_url):
|
|
||||||
"""Подменяем настройки для тестов"""
|
|
||||||
monkeypatch.setattr("auth.oauth.OAUTH_CLIENTS", oauth_settings)
|
|
||||||
monkeypatch.setattr("auth.oauth.FRONTEND_URL", frontend_url)
|
|
|
@ -1,222 +0,0 @@
|
||||||
import pytest
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
from starlette.responses import JSONResponse, RedirectResponse
|
|
||||||
|
|
||||||
from auth.oauth import get_user_profile, oauth_login, oauth_callback
|
|
||||||
|
|
||||||
# Подменяем настройки для тестов
|
|
||||||
with (
|
|
||||||
patch("auth.oauth.FRONTEND_URL", "https://localhost:3000"),
|
|
||||||
patch(
|
|
||||||
"auth.oauth.OAUTH_CLIENTS",
|
|
||||||
{
|
|
||||||
"GOOGLE": {"id": "test_google_id", "key": "test_google_secret"},
|
|
||||||
"GITHUB": {"id": "test_github_id", "key": "test_github_secret"},
|
|
||||||
"FACEBOOK": {"id": "test_facebook_id", "key": "test_facebook_secret"},
|
|
||||||
},
|
|
||||||
),
|
|
||||||
):
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_request():
|
|
||||||
"""Фикстура для мока запроса"""
|
|
||||||
request = MagicMock()
|
|
||||||
request.session = {}
|
|
||||||
request.path_params = {}
|
|
||||||
request.query_params = {}
|
|
||||||
return request
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_oauth_client():
|
|
||||||
"""Фикстура для мока OAuth клиента"""
|
|
||||||
client = AsyncMock()
|
|
||||||
client.authorize_redirect = AsyncMock()
|
|
||||||
client.authorize_access_token = AsyncMock()
|
|
||||||
client.get = AsyncMock()
|
|
||||||
return client
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_user_profile_google():
|
|
||||||
"""Тест получения профиля из Google"""
|
|
||||||
client = AsyncMock()
|
|
||||||
token = {
|
|
||||||
"userinfo": {
|
|
||||||
"sub": "123",
|
|
||||||
"email": "test@gmail.com",
|
|
||||||
"name": "Test User",
|
|
||||||
"picture": "https://lh3.googleusercontent.com/photo=s96",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
profile = await get_user_profile("google", client, token)
|
|
||||||
|
|
||||||
assert profile["id"] == "123"
|
|
||||||
assert profile["email"] == "test@gmail.com"
|
|
||||||
assert profile["name"] == "Test User"
|
|
||||||
assert profile["picture"] == "https://lh3.googleusercontent.com/photo=s600"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_user_profile_github():
|
|
||||||
"""Тест получения профиля из GitHub"""
|
|
||||||
client = AsyncMock()
|
|
||||||
client.get.side_effect = [
|
|
||||||
MagicMock(
|
|
||||||
json=lambda: {
|
|
||||||
"id": 456,
|
|
||||||
"login": "testuser",
|
|
||||||
"name": "Test User",
|
|
||||||
"avatar_url": "https://github.com/avatar.jpg",
|
|
||||||
}
|
|
||||||
),
|
|
||||||
MagicMock(
|
|
||||||
json=lambda: [
|
|
||||||
{"email": "other@github.com", "primary": False},
|
|
||||||
{"email": "test@github.com", "primary": True},
|
|
||||||
]
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
profile = await get_user_profile("github", client, {})
|
|
||||||
|
|
||||||
assert profile["id"] == "456"
|
|
||||||
assert profile["email"] == "test@github.com"
|
|
||||||
assert profile["name"] == "Test User"
|
|
||||||
assert profile["picture"] == "https://github.com/avatar.jpg"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_user_profile_facebook():
|
|
||||||
"""Тест получения профиля из Facebook"""
|
|
||||||
client = AsyncMock()
|
|
||||||
client.get.return_value = MagicMock(
|
|
||||||
json=lambda: {
|
|
||||||
"id": "789",
|
|
||||||
"name": "Test User",
|
|
||||||
"email": "test@facebook.com",
|
|
||||||
"picture": {"data": {"url": "https://facebook.com/photo.jpg"}},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
profile = await get_user_profile("facebook", client, {})
|
|
||||||
|
|
||||||
assert profile["id"] == "789"
|
|
||||||
assert profile["email"] == "test@facebook.com"
|
|
||||||
assert profile["name"] == "Test User"
|
|
||||||
assert profile["picture"] == "https://facebook.com/photo.jpg"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_oauth_login_success(mock_request, mock_oauth_client):
|
|
||||||
"""Тест успешного начала OAuth авторизации"""
|
|
||||||
mock_request.path_params["provider"] = "google"
|
|
||||||
|
|
||||||
# Настраиваем мок для authorize_redirect
|
|
||||||
redirect_response = RedirectResponse(url="http://example.com")
|
|
||||||
mock_oauth_client.authorize_redirect.return_value = redirect_response
|
|
||||||
|
|
||||||
with patch("auth.oauth.oauth.create_client", return_value=mock_oauth_client):
|
|
||||||
response = await oauth_login(mock_request)
|
|
||||||
|
|
||||||
assert isinstance(response, RedirectResponse)
|
|
||||||
assert mock_request.session["provider"] == "google"
|
|
||||||
assert "code_verifier" in mock_request.session
|
|
||||||
assert "state" in mock_request.session
|
|
||||||
|
|
||||||
mock_oauth_client.authorize_redirect.assert_called_once()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_oauth_login_invalid_provider(mock_request):
|
|
||||||
"""Тест с неправильным провайдером"""
|
|
||||||
mock_request.path_params["provider"] = "invalid"
|
|
||||||
|
|
||||||
response = await oauth_login(mock_request)
|
|
||||||
|
|
||||||
assert isinstance(response, JSONResponse)
|
|
||||||
assert response.status_code == 400
|
|
||||||
assert "Invalid provider" in response.body.decode()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_oauth_callback_success(mock_request, mock_oauth_client):
|
|
||||||
"""Тест успешного OAuth callback"""
|
|
||||||
mock_request.session = {
|
|
||||||
"provider": "google",
|
|
||||||
"code_verifier": "test_verifier",
|
|
||||||
"state": "test_state",
|
|
||||||
}
|
|
||||||
mock_request.query_params["state"] = "test_state"
|
|
||||||
|
|
||||||
mock_oauth_client.authorize_access_token.return_value = {
|
|
||||||
"userinfo": {"sub": "123", "email": "test@gmail.com", "name": "Test User"}
|
|
||||||
}
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch("auth.oauth.oauth.create_client", return_value=mock_oauth_client),
|
|
||||||
patch("auth.oauth.local_session") as mock_session,
|
|
||||||
patch("auth.oauth.TokenStorage.create_session", return_value="test_token"),
|
|
||||||
):
|
|
||||||
# Мокаем сессию базы данных
|
|
||||||
session = MagicMock()
|
|
||||||
session.query.return_value.filter.return_value.first.return_value = None
|
|
||||||
mock_session.return_value.__enter__.return_value = session
|
|
||||||
|
|
||||||
response = await oauth_callback(mock_request)
|
|
||||||
|
|
||||||
assert isinstance(response, RedirectResponse)
|
|
||||||
assert response.status_code == 307
|
|
||||||
assert "auth/success" in response.headers["location"]
|
|
||||||
|
|
||||||
# Проверяем cookie
|
|
||||||
cookies = response.headers.getlist("set-cookie")
|
|
||||||
assert any("session_token=test_token" in cookie for cookie in cookies)
|
|
||||||
assert any("httponly" in cookie.lower() for cookie in cookies)
|
|
||||||
assert any("secure" in cookie.lower() for cookie in cookies)
|
|
||||||
|
|
||||||
# Проверяем очистку сессии
|
|
||||||
assert "code_verifier" not in mock_request.session
|
|
||||||
assert "provider" not in mock_request.session
|
|
||||||
assert "state" not in mock_request.session
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_oauth_callback_invalid_state(mock_request):
|
|
||||||
"""Тест с неправильным state параметром"""
|
|
||||||
mock_request.session = {"provider": "google", "state": "correct_state"}
|
|
||||||
mock_request.query_params["state"] = "wrong_state"
|
|
||||||
|
|
||||||
response = await oauth_callback(mock_request)
|
|
||||||
|
|
||||||
assert isinstance(response, JSONResponse)
|
|
||||||
assert response.status_code == 400
|
|
||||||
assert "Invalid state" in response.body.decode()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_oauth_callback_existing_user(mock_request, mock_oauth_client):
|
|
||||||
"""Тест OAuth callback с существующим пользователем"""
|
|
||||||
mock_request.session = {
|
|
||||||
"provider": "google",
|
|
||||||
"code_verifier": "test_verifier",
|
|
||||||
"state": "test_state",
|
|
||||||
}
|
|
||||||
mock_request.query_params["state"] = "test_state"
|
|
||||||
|
|
||||||
mock_oauth_client.authorize_access_token.return_value = {
|
|
||||||
"userinfo": {"sub": "123", "email": "test@gmail.com", "name": "Test User"}
|
|
||||||
}
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch("auth.oauth.oauth.create_client", return_value=mock_oauth_client),
|
|
||||||
patch("auth.oauth.local_session") as mock_session,
|
|
||||||
patch("auth.oauth.TokenStorage.create_session", return_value="test_token"),
|
|
||||||
):
|
|
||||||
# Мокаем существующего пользователя
|
|
||||||
existing_user = MagicMock()
|
|
||||||
session = MagicMock()
|
|
||||||
session.query.return_value.filter.return_value.first.return_value = existing_user
|
|
||||||
mock_session.return_value.__enter__.return_value = session
|
|
||||||
|
|
||||||
response = await oauth_callback(mock_request)
|
|
||||||
|
|
||||||
assert isinstance(response, RedirectResponse)
|
|
||||||
assert response.status_code == 307
|
|
||||||
|
|
||||||
# Проверяем обновление существующего пользователя
|
|
||||||
assert existing_user.name == "Test User"
|
|
||||||
assert existing_user.oauth == "google:123"
|
|
||||||
assert existing_user.email_verified is True
|
|
|
@ -1,9 +0,0 @@
|
||||||
"""Тестовые настройки для OAuth"""
|
|
||||||
|
|
||||||
FRONTEND_URL = "https://localhost:3000"
|
|
||||||
|
|
||||||
OAUTH_CLIENTS = {
|
|
||||||
"GOOGLE": {"id": "test_google_id", "key": "test_google_secret"},
|
|
||||||
"GITHUB": {"id": "test_github_id", "key": "test_github_secret"},
|
|
||||||
"FACEBOOK": {"id": "test_facebook_id", "key": "test_facebook_secret"},
|
|
||||||
}
|
|
|
@ -1,7 +1,17 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import os
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
|
from main import app
|
||||||
|
from services.db import Base
|
||||||
from services.redis import redis
|
from services.redis import redis
|
||||||
from tests.test_config import get_test_client
|
|
||||||
|
# Use SQLite for testing
|
||||||
|
TEST_DB_URL = "sqlite:///test.db"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
|
@ -13,36 +23,38 @@ def event_loop():
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def test_app():
|
def test_engine():
|
||||||
"""Create a test client and session factory."""
|
"""Create a test database engine."""
|
||||||
client, SessionLocal = get_test_client()
|
engine = create_engine(TEST_DB_URL)
|
||||||
return client, SessionLocal
|
Base.metadata.create_all(engine)
|
||||||
|
yield engine
|
||||||
|
Base.metadata.drop_all(engine)
|
||||||
|
os.remove("test.db")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def db_session(test_app):
|
def db_session(test_engine):
|
||||||
"""Create a new database session for a test."""
|
"""Create a new database session for a test."""
|
||||||
_, SessionLocal = test_app
|
connection = test_engine.connect()
|
||||||
session = SessionLocal()
|
transaction = connection.begin()
|
||||||
|
session = Session(bind=connection)
|
||||||
|
|
||||||
yield session
|
yield session
|
||||||
|
|
||||||
session.rollback()
|
|
||||||
session.close()
|
session.close()
|
||||||
|
transaction.rollback()
|
||||||
|
connection.close()
|
||||||
@pytest.fixture
|
|
||||||
def test_client(test_app):
|
|
||||||
"""Get the test client."""
|
|
||||||
client, _ = test_app
|
|
||||||
return client
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def redis_client():
|
async def redis_client():
|
||||||
"""Create a test Redis client."""
|
"""Create a test Redis client."""
|
||||||
await redis.connect()
|
await redis.connect()
|
||||||
await redis.flushall() # Очищаем Redis перед каждым тестом
|
|
||||||
yield redis
|
yield redis
|
||||||
await redis.flushall() # Очищаем после теста
|
|
||||||
await redis.disconnect()
|
await redis.disconnect()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_client():
|
||||||
|
"""Create a TestClient instance."""
|
||||||
|
return TestClient(app)
|
||||||
|
|
|
@ -1,67 +0,0 @@
|
||||||
"""
|
|
||||||
Конфигурация для тестов
|
|
||||||
"""
|
|
||||||
|
|
||||||
from sqlalchemy import create_engine
|
|
||||||
from sqlalchemy.orm import sessionmaker
|
|
||||||
from sqlalchemy.pool import StaticPool
|
|
||||||
from starlette.applications import Starlette
|
|
||||||
from starlette.middleware import Middleware
|
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
|
||||||
from starlette.testclient import TestClient
|
|
||||||
|
|
||||||
# Используем in-memory SQLite для тестов
|
|
||||||
TEST_DB_URL = "sqlite:///:memory:"
|
|
||||||
|
|
||||||
|
|
||||||
class DatabaseMiddleware(BaseHTTPMiddleware):
|
|
||||||
"""Middleware для внедрения сессии БД"""
|
|
||||||
|
|
||||||
def __init__(self, app, session_maker):
|
|
||||||
super().__init__(app)
|
|
||||||
self.session_maker = session_maker
|
|
||||||
|
|
||||||
async def dispatch(self, request, call_next):
|
|
||||||
session = self.session_maker()
|
|
||||||
request.state.db = session
|
|
||||||
try:
|
|
||||||
response = await call_next(request)
|
|
||||||
finally:
|
|
||||||
session.close()
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
def create_test_app():
|
|
||||||
"""Create a test Starlette application."""
|
|
||||||
from services.db import Base
|
|
||||||
|
|
||||||
# Создаем движок и таблицы
|
|
||||||
engine = create_engine(
|
|
||||||
TEST_DB_URL,
|
|
||||||
connect_args={"check_same_thread": False},
|
|
||||||
poolclass=StaticPool,
|
|
||||||
echo=False,
|
|
||||||
)
|
|
||||||
Base.metadata.drop_all(bind=engine)
|
|
||||||
Base.metadata.create_all(bind=engine)
|
|
||||||
|
|
||||||
# Создаем фабрику сессий
|
|
||||||
SessionLocal = sessionmaker(bind=engine)
|
|
||||||
|
|
||||||
# Создаем middleware для сессий
|
|
||||||
middleware = [Middleware(DatabaseMiddleware, session_maker=SessionLocal)]
|
|
||||||
|
|
||||||
# Создаем тестовое приложение
|
|
||||||
app = Starlette(
|
|
||||||
debug=True,
|
|
||||||
middleware=middleware,
|
|
||||||
routes=[], # Здесь можно добавить тестовые маршруты если нужно
|
|
||||||
)
|
|
||||||
|
|
||||||
return app, SessionLocal
|
|
||||||
|
|
||||||
|
|
||||||
def get_test_client():
|
|
||||||
"""Get a test client with initialized database."""
|
|
||||||
app, SessionLocal = create_test_app()
|
|
||||||
return TestClient(app), SessionLocal
|
|
|
@ -53,11 +53,7 @@ async def test_create_reaction(test_client, db_session, test_setup):
|
||||||
}
|
}
|
||||||
""",
|
""",
|
||||||
"variables": {
|
"variables": {
|
||||||
"reaction": {
|
"reaction": {"shout": test_setup["shout"].id, "kind": ReactionKind.LIKE.value, "body": "Great post!"}
|
||||||
"shout": test_setup["shout"].id,
|
|
||||||
"kind": ReactionKind.LIKE.value,
|
|
||||||
"body": "Great post!",
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
70
tests/test_validations.py
Normal file
70
tests/test_validations.py
Normal file
|
@ -0,0 +1,70 @@
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from auth.validations import (
|
||||||
|
AuthInput,
|
||||||
|
AuthResponse,
|
||||||
|
TokenPayload,
|
||||||
|
UserRegistrationInput,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuthValidations:
|
||||||
|
def test_auth_input(self):
|
||||||
|
"""Test basic auth input validation"""
|
||||||
|
# Valid case
|
||||||
|
auth = AuthInput(user_id="123", username="testuser", token="1234567890abcdef1234567890abcdef")
|
||||||
|
assert auth.user_id == "123"
|
||||||
|
assert auth.username == "testuser"
|
||||||
|
|
||||||
|
# Invalid cases
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
AuthInput(user_id="", username="test", token="x" * 32)
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
AuthInput(user_id="123", username="t", token="x" * 32)
|
||||||
|
|
||||||
|
def test_user_registration(self):
|
||||||
|
"""Test user registration validation"""
|
||||||
|
# Valid case
|
||||||
|
user = UserRegistrationInput(email="test@example.com", password="SecurePass123!", name="Test User")
|
||||||
|
assert user.email == "test@example.com"
|
||||||
|
assert user.name == "Test User"
|
||||||
|
|
||||||
|
# Test email validation
|
||||||
|
with pytest.raises(ValidationError) as exc:
|
||||||
|
UserRegistrationInput(email="invalid-email", password="SecurePass123!", name="Test")
|
||||||
|
assert "Invalid email format" in str(exc.value)
|
||||||
|
|
||||||
|
# Test password validation
|
||||||
|
with pytest.raises(ValidationError) as exc:
|
||||||
|
UserRegistrationInput(email="test@example.com", password="weak", name="Test")
|
||||||
|
assert "String should have at least 8 characters" in str(exc.value)
|
||||||
|
|
||||||
|
def test_token_payload(self):
|
||||||
|
"""Test token payload validation"""
|
||||||
|
now = datetime.utcnow()
|
||||||
|
exp = now + timedelta(hours=1)
|
||||||
|
|
||||||
|
payload = TokenPayload(user_id="123", username="testuser", exp=exp, iat=now)
|
||||||
|
assert payload.user_id == "123"
|
||||||
|
assert payload.username == "testuser"
|
||||||
|
assert payload.scopes == [] # Default empty list
|
||||||
|
|
||||||
|
def test_auth_response(self):
|
||||||
|
"""Test auth response validation"""
|
||||||
|
# Success case
|
||||||
|
success_resp = AuthResponse(success=True, token="valid_token", user={"id": "123", "name": "Test"})
|
||||||
|
assert success_resp.success is True
|
||||||
|
assert success_resp.token == "valid_token"
|
||||||
|
|
||||||
|
# Error case
|
||||||
|
error_resp = AuthResponse(success=False, error="Invalid credentials")
|
||||||
|
assert error_resp.success is False
|
||||||
|
assert error_resp.error == "Invalid credentials"
|
||||||
|
|
||||||
|
# Invalid case - отсутствует обязательное поле token при success=True
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
AuthResponse(success=True, user={"id": "123", "name": "Test"})
|
Loading…
Reference in New Issue
Block a user