This commit is contained in:
parent
37a9a284ef
commit
4a835bbfba
|
@ -2,7 +2,11 @@
|
||||||
- `Shout.draft` field added
|
- `Shout.draft` field added
|
||||||
- `Draft` entity added
|
- `Draft` entity added
|
||||||
- `create_draft`, `update_draft`, `delete_draft` mutations and resolvers added
|
- `create_draft`, `update_draft`, `delete_draft` mutations and resolvers added
|
||||||
- `get_shout_drafts` resolver updated
|
- `create_shout`, `update_shout`, `delete_shout` mutations removed from GraphQL API
|
||||||
|
- `load_drafts` resolver implemented
|
||||||
|
- `publish_` and `unpublish_` mutations and resolvers added
|
||||||
|
- `create_`, `update_`, `delete_` mutations and resolvers added for `Draft` entity
|
||||||
|
- tests with pytest for auth, shouts, drafts
|
||||||
|
|
||||||
#### [0.4.8] - 2025-02-03
|
#### [0.4.8] - 2025-02-03
|
||||||
- `Reaction.deleted_at` filter on `update_reaction` resolver added
|
- `Reaction.deleted_at` filter on `update_reaction` resolver added
|
||||||
|
|
|
@ -2,13 +2,13 @@ from binascii import hexlify
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
|
|
||||||
# from base.exceptions import InvalidPassword, InvalidToken
|
# from base.exceptions import InvalidPassword, InvalidToken
|
||||||
from base.orm import local_session
|
from services.db import local_session
|
||||||
from jwt import DecodeError, ExpiredSignatureError
|
from auth.exceptions import ExpiredToken, InvalidToken
|
||||||
from passlib.hash import bcrypt
|
from passlib.hash import bcrypt
|
||||||
|
|
||||||
from auth.jwtcodec import JWTCodec
|
from auth.jwtcodec import JWTCodec
|
||||||
from auth.tokenstorage import TokenStorage
|
from auth.tokenstorage import TokenStorage
|
||||||
from orm import User
|
from orm.user import User
|
||||||
|
|
||||||
|
|
||||||
class Password:
|
class Password:
|
||||||
|
@ -79,10 +79,10 @@ class Identity:
|
||||||
if not await TokenStorage.exist(f"{payload.user_id}-{payload.username}-{token}"):
|
if not await TokenStorage.exist(f"{payload.user_id}-{payload.username}-{token}"):
|
||||||
# raise InvalidToken("Login token has expired, please login again")
|
# raise InvalidToken("Login token has expired, please login again")
|
||||||
return {"error": "Token has expired"}
|
return {"error": "Token has expired"}
|
||||||
except ExpiredSignatureError:
|
except ExpiredToken:
|
||||||
# raise InvalidToken("Login token has expired, please try again")
|
# raise InvalidToken("Login token has expired, please try again")
|
||||||
return {"error": "Token has expired"}
|
return {"error": "Token has expired"}
|
||||||
except DecodeError:
|
except InvalidToken:
|
||||||
# raise InvalidToken("token format error") from e
|
# raise InvalidToken("token format error") from e
|
||||||
return {"error": "Token format error"}
|
return {"error": "Token format error"}
|
||||||
with local_session() as session:
|
with local_session() as session:
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
from base.redis import redis
|
from services.redis import redis
|
||||||
from validations.auth import AuthInput
|
from auth.validations import AuthInput
|
||||||
|
|
||||||
from auth.jwtcodec import JWTCodec
|
from auth.jwtcodec import JWTCodec
|
||||||
from settings import ONETIME_TOKEN_LIFE_SPAN, SESSION_TOKEN_LIFE_SPAN
|
from settings import ONETIME_TOKEN_LIFE_SPAN, SESSION_TOKEN_LIFE_SPAN
|
||||||
|
|
103
auth/validations.py
Normal file
103
auth/validations.py
Normal file
|
@ -0,0 +1,103 @@
|
||||||
|
import re
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Dict, List, Optional, Union
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
|
# RFC 5322 compliant email regex pattern
|
||||||
|
EMAIL_PATTERN = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
|
||||||
|
|
||||||
|
class AuthInput(BaseModel):
|
||||||
|
"""Base model for authentication input validation"""
|
||||||
|
user_id: str = Field(description="Unique user identifier")
|
||||||
|
username: str = Field(min_length=2, max_length=50)
|
||||||
|
token: str = Field(min_length=32)
|
||||||
|
|
||||||
|
@field_validator('user_id')
|
||||||
|
@classmethod
|
||||||
|
def validate_user_id(cls, v: str) -> str:
|
||||||
|
if not v.strip():
|
||||||
|
raise ValueError("user_id cannot be empty")
|
||||||
|
return v
|
||||||
|
|
||||||
|
class UserRegistrationInput(BaseModel):
|
||||||
|
"""Validation model for user registration"""
|
||||||
|
email: str = Field(max_length=254) # Max email length per RFC 5321
|
||||||
|
password: str = Field(min_length=8, max_length=100)
|
||||||
|
name: str = Field(min_length=2, max_length=50)
|
||||||
|
|
||||||
|
@field_validator('email')
|
||||||
|
@classmethod
|
||||||
|
def validate_email(cls, v: str) -> str:
|
||||||
|
"""Validate email format"""
|
||||||
|
if not re.match(EMAIL_PATTERN, v):
|
||||||
|
raise ValueError("Invalid email format")
|
||||||
|
return v.lower()
|
||||||
|
|
||||||
|
@field_validator('password')
|
||||||
|
@classmethod
|
||||||
|
def validate_password_strength(cls, v: str) -> str:
|
||||||
|
"""Validate password meets security requirements"""
|
||||||
|
if not any(c.isupper() for c in v):
|
||||||
|
raise ValueError("Password must contain at least one uppercase letter")
|
||||||
|
if not any(c.islower() for c in v):
|
||||||
|
raise ValueError("Password must contain at least one lowercase letter")
|
||||||
|
if not any(c.isdigit() for c in v):
|
||||||
|
raise ValueError("Password must contain at least one number")
|
||||||
|
if not any(c in "!@#$%^&*()_+-=[]{}|;:,.<>?" for c in v):
|
||||||
|
raise ValueError("Password must contain at least one special character")
|
||||||
|
return v
|
||||||
|
|
||||||
|
class UserLoginInput(BaseModel):
|
||||||
|
"""Validation model for user login"""
|
||||||
|
email: str = Field(max_length=254)
|
||||||
|
password: str = Field(min_length=8, max_length=100)
|
||||||
|
|
||||||
|
@field_validator('email')
|
||||||
|
@classmethod
|
||||||
|
def validate_email(cls, v: str) -> str:
|
||||||
|
if not re.match(EMAIL_PATTERN, v):
|
||||||
|
raise ValueError("Invalid email format")
|
||||||
|
return v.lower()
|
||||||
|
|
||||||
|
class TokenPayload(BaseModel):
|
||||||
|
"""Validation model for JWT token payload"""
|
||||||
|
user_id: str
|
||||||
|
username: str
|
||||||
|
exp: datetime
|
||||||
|
iat: datetime
|
||||||
|
scopes: Optional[List[str]] = []
|
||||||
|
|
||||||
|
class OAuthInput(BaseModel):
|
||||||
|
"""Validation model for OAuth input"""
|
||||||
|
provider: str = Field(pattern='^(google|github|facebook)$')
|
||||||
|
code: str
|
||||||
|
redirect_uri: Optional[str] = None
|
||||||
|
|
||||||
|
@field_validator('provider')
|
||||||
|
@classmethod
|
||||||
|
def validate_provider(cls, v: str) -> str:
|
||||||
|
valid_providers = ['google', 'github', 'facebook']
|
||||||
|
if v.lower() not in valid_providers:
|
||||||
|
raise ValueError(f"Provider must be one of: {', '.join(valid_providers)}")
|
||||||
|
return v.lower()
|
||||||
|
|
||||||
|
class AuthResponse(BaseModel):
|
||||||
|
"""Validation model for authentication responses"""
|
||||||
|
success: bool
|
||||||
|
token: Optional[str] = None
|
||||||
|
error: Optional[str] = None
|
||||||
|
user: Optional[Dict[str, Union[str, int, bool]]] = None
|
||||||
|
|
||||||
|
@field_validator('error')
|
||||||
|
@classmethod
|
||||||
|
def validate_error_if_not_success(cls, v: Optional[str], info) -> Optional[str]:
|
||||||
|
if not info.data.get('success') and not v:
|
||||||
|
raise ValueError("Error message required when success is False")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@field_validator('token')
|
||||||
|
@classmethod
|
||||||
|
def validate_token_if_success(cls, v: Optional[str], info) -> Optional[str]:
|
||||||
|
if info.data.get('success') and not v:
|
||||||
|
raise ValueError("Token required when success is True")
|
||||||
|
return v
|
176
orm/rbac.py
Normal file
176
orm/rbac.py
Normal file
|
@ -0,0 +1,176 @@
|
||||||
|
from services.db import REGISTRY, Base, local_session
|
||||||
|
from utils.logger import root_logger as logger
|
||||||
|
|
||||||
|
from sqlalchemy.types import TypeDecorator
|
||||||
|
from sqlalchemy.types import String
|
||||||
|
from sqlalchemy import Column, ForeignKey, String, UniqueConstraint
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
class ClassType(TypeDecorator):
|
||||||
|
impl = String
|
||||||
|
|
||||||
|
@property
|
||||||
|
def python_type(self):
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
def process_literal_param(self, value, dialect):
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
def process_bind_param(self, value, dialect):
|
||||||
|
return value.__name__ if isinstance(value, type) else str(value)
|
||||||
|
|
||||||
|
def process_result_value(self, value, dialect):
|
||||||
|
class_ = REGISTRY.get(value)
|
||||||
|
if class_ is None:
|
||||||
|
logger.warn(f"Can't find class <{value}>,find it yourself!", stacklevel=2)
|
||||||
|
return class_
|
||||||
|
|
||||||
|
|
||||||
|
class Role(Base):
|
||||||
|
__tablename__ = "role"
|
||||||
|
|
||||||
|
name = Column(String, nullable=False, comment="Role Name")
|
||||||
|
desc = Column(String, nullable=True, comment="Role Description")
|
||||||
|
community = Column(
|
||||||
|
ForeignKey("community.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
comment="Community",
|
||||||
|
)
|
||||||
|
permissions = relationship(lambda: Permission)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def init_table():
|
||||||
|
with local_session() as session:
|
||||||
|
r = session.query(Role).filter(Role.name == "author").first()
|
||||||
|
if r:
|
||||||
|
Role.default_role = r
|
||||||
|
return
|
||||||
|
|
||||||
|
r1 = Role.create(
|
||||||
|
name="author",
|
||||||
|
desc="Role for an author",
|
||||||
|
community=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
session.add(r1)
|
||||||
|
|
||||||
|
Role.default_role = r1
|
||||||
|
|
||||||
|
r2 = Role.create(
|
||||||
|
name="reader",
|
||||||
|
desc="Role for a reader",
|
||||||
|
community=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
session.add(r2)
|
||||||
|
|
||||||
|
r3 = Role.create(
|
||||||
|
name="expert",
|
||||||
|
desc="Role for an expert",
|
||||||
|
community=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
session.add(r3)
|
||||||
|
|
||||||
|
r4 = Role.create(
|
||||||
|
name="editor",
|
||||||
|
desc="Role for an editor",
|
||||||
|
community=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
session.add(r4)
|
||||||
|
|
||||||
|
|
||||||
|
class Operation(Base):
|
||||||
|
__tablename__ = "operation"
|
||||||
|
name = Column(String, nullable=False, unique=True, comment="Operation Name")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def init_table():
|
||||||
|
with local_session() as session:
|
||||||
|
for name in ["create", "update", "delete", "load"]:
|
||||||
|
"""
|
||||||
|
* everyone can:
|
||||||
|
- load shouts
|
||||||
|
- load topics
|
||||||
|
- load reactions
|
||||||
|
- create an account to become a READER
|
||||||
|
* readers can:
|
||||||
|
- update and delete their account
|
||||||
|
- load chats
|
||||||
|
- load messages
|
||||||
|
- create reaction of some shout's author allowed kinds
|
||||||
|
- create shout to become an AUTHOR
|
||||||
|
* authors can:
|
||||||
|
- update and delete their shout
|
||||||
|
- invite other authors to edit shout and chat
|
||||||
|
- manage allowed reactions for their shout
|
||||||
|
* pros can:
|
||||||
|
- create/update/delete their community
|
||||||
|
- create/update/delete topics for their community
|
||||||
|
|
||||||
|
"""
|
||||||
|
op = session.query(Operation).filter(Operation.name == name).first()
|
||||||
|
if not op:
|
||||||
|
op = Operation.create(name=name)
|
||||||
|
session.add(op)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
class Resource(Base):
|
||||||
|
__tablename__ = "resource"
|
||||||
|
resourceClass = Column(String, nullable=False, unique=True, comment="Resource class")
|
||||||
|
name = Column(String, nullable=False, unique=True, comment="Resource name")
|
||||||
|
# TODO: community = Column(ForeignKey())
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def init_table():
|
||||||
|
with local_session() as session:
|
||||||
|
for res in [
|
||||||
|
"shout",
|
||||||
|
"topic",
|
||||||
|
"reaction",
|
||||||
|
"chat",
|
||||||
|
"message",
|
||||||
|
"invite",
|
||||||
|
"community",
|
||||||
|
"user",
|
||||||
|
]:
|
||||||
|
r = session.query(Resource).filter(Resource.name == res).first()
|
||||||
|
if not r:
|
||||||
|
r = Resource.create(name=res, resourceClass=res)
|
||||||
|
session.add(r)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
class Permission(Base):
|
||||||
|
__tablename__ = "permission"
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint("role", "operation", "resource"),
|
||||||
|
{"extend_existing": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
role: Column = Column(ForeignKey("role.id", ondelete="CASCADE"), nullable=False, comment="Role")
|
||||||
|
operation: Column = Column(
|
||||||
|
ForeignKey("operation.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
comment="Operation",
|
||||||
|
)
|
||||||
|
resource: Column = Column(
|
||||||
|
ForeignKey("resource.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
comment="Resource",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# if __name__ == "__main__":
|
||||||
|
# Base.metadata.create_all(engine)
|
||||||
|
# ops = [
|
||||||
|
# Permission(role=1, operation=1, resource=1),
|
||||||
|
# Permission(role=1, operation=2, resource=1),
|
||||||
|
# Permission(role=1, operation=3, resource=1),
|
||||||
|
# Permission(role=1, operation=4, resource=1),
|
||||||
|
# Permission(role=2, operation=4, resource=1),
|
||||||
|
# ]
|
||||||
|
# global_session.add_all(ops)
|
||||||
|
# global_session.commit()
|
105
orm/user.py
Normal file
105
orm/user.py
Normal file
|
@ -0,0 +1,105 @@
|
||||||
|
from sqlalchemy import JSON as JSONType
|
||||||
|
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, func
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
from services.db import Base, local_session
|
||||||
|
from orm.rbac import Role
|
||||||
|
|
||||||
|
|
||||||
|
class UserRating(Base):
|
||||||
|
__tablename__ = "user_rating"
|
||||||
|
|
||||||
|
id = None
|
||||||
|
rater: Column = Column(ForeignKey("user.id"), primary_key=True, index=True)
|
||||||
|
user: Column = Column(ForeignKey("user.id"), primary_key=True, index=True)
|
||||||
|
value: Column = Column(Integer)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def init_table():
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class UserRole(Base):
|
||||||
|
__tablename__ = "user_role"
|
||||||
|
|
||||||
|
id = None
|
||||||
|
user = Column(ForeignKey("user.id"), primary_key=True, index=True)
|
||||||
|
role = Column(ForeignKey("role.id"), primary_key=True, index=True)
|
||||||
|
|
||||||
|
|
||||||
|
class AuthorFollower(Base):
|
||||||
|
__tablename__ = "author_follower"
|
||||||
|
|
||||||
|
id = None
|
||||||
|
follower: Column = Column(ForeignKey("user.id"), primary_key=True, index=True)
|
||||||
|
author: Column = Column(ForeignKey("user.id"), primary_key=True, index=True)
|
||||||
|
createdAt = Column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now(), comment="Created at"
|
||||||
|
)
|
||||||
|
auto = Column(Boolean, nullable=False, default=False)
|
||||||
|
|
||||||
|
|
||||||
|
class User(Base):
|
||||||
|
__tablename__ = "user"
|
||||||
|
default_user = None
|
||||||
|
|
||||||
|
email = Column(String, unique=True, nullable=False, comment="Email")
|
||||||
|
username = Column(String, nullable=False, comment="Login")
|
||||||
|
password = Column(String, nullable=True, comment="Password")
|
||||||
|
bio = Column(String, nullable=True, comment="Bio") # status description
|
||||||
|
about = Column(String, nullable=True, comment="About") # long and formatted
|
||||||
|
userpic = Column(String, nullable=True, comment="Userpic")
|
||||||
|
name = Column(String, nullable=True, comment="Display name")
|
||||||
|
slug = Column(String, unique=True, comment="User's slug")
|
||||||
|
muted = Column(Boolean, default=False)
|
||||||
|
emailConfirmed = Column(Boolean, default=False)
|
||||||
|
createdAt = Column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now(), comment="Created at"
|
||||||
|
)
|
||||||
|
lastSeen = Column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now(), comment="Was online at"
|
||||||
|
)
|
||||||
|
deletedAt = Column(DateTime(timezone=True), nullable=True, comment="Deleted at")
|
||||||
|
links = Column(JSONType, nullable=True, comment="Links")
|
||||||
|
oauth = Column(String, nullable=True)
|
||||||
|
ratings = relationship(UserRating, foreign_keys=UserRating.user)
|
||||||
|
roles = relationship(lambda: Role, secondary=UserRole.__tablename__)
|
||||||
|
oid = Column(String, nullable=True)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def init_table():
|
||||||
|
with local_session() as session:
|
||||||
|
default = session.query(User).filter(User.slug == "anonymous").first()
|
||||||
|
if not default:
|
||||||
|
default_dict = {
|
||||||
|
"email": "noreply@discours.io",
|
||||||
|
"username": "noreply@discours.io",
|
||||||
|
"name": "Аноним",
|
||||||
|
"slug": "anonymous",
|
||||||
|
}
|
||||||
|
default = User.create(**default_dict)
|
||||||
|
session.add(default)
|
||||||
|
discours_dict = {
|
||||||
|
"email": "welcome@discours.io",
|
||||||
|
"username": "welcome@discours.io",
|
||||||
|
"name": "Дискурс",
|
||||||
|
"slug": "discours",
|
||||||
|
}
|
||||||
|
discours = User.create(**discours_dict)
|
||||||
|
session.add(discours)
|
||||||
|
session.commit()
|
||||||
|
User.default_user = default
|
||||||
|
|
||||||
|
def get_permission(self):
|
||||||
|
scope = {}
|
||||||
|
for role in self.roles:
|
||||||
|
for p in role.permissions:
|
||||||
|
if p.resource not in scope:
|
||||||
|
scope[p.resource] = set()
|
||||||
|
scope[p.resource].add(p.operation)
|
||||||
|
print(scope)
|
||||||
|
return scope
|
||||||
|
|
||||||
|
|
||||||
|
# if __name__ == "__main__":
|
||||||
|
# print(User.get_permission(user_id=1))
|
|
@ -26,6 +26,8 @@ fakeredis = "^2.25.1"
|
||||||
pydantic = "^2.9.2"
|
pydantic = "^2.9.2"
|
||||||
jwt = "^1.3.1"
|
jwt = "^1.3.1"
|
||||||
authlib = "^1.3.2"
|
authlib = "^1.3.2"
|
||||||
|
passlib = "^1.7.4"
|
||||||
|
bcrypt = "^4.2.1"
|
||||||
|
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
|
@ -34,13 +36,15 @@ isort = "^5.13.2"
|
||||||
pydantic = "^2.9.2"
|
pydantic = "^2.9.2"
|
||||||
pytest = "^8.3.4"
|
pytest = "^8.3.4"
|
||||||
mypy = "^1.15.0"
|
mypy = "^1.15.0"
|
||||||
|
pytest-asyncio = "^0.23.5"
|
||||||
|
pytest-cov = "^4.1.0"
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["poetry-core>=1.0.0"]
|
requires = ["poetry-core>=1.0.0"]
|
||||||
build-backend = "poetry.core.masonry.api"
|
build-backend = "poetry.core.masonry.api"
|
||||||
|
|
||||||
[tool.pyright]
|
[tool.pyright]
|
||||||
venvPath = "."
|
venvPath = "venv"
|
||||||
venv = "venv"
|
venv = "venv"
|
||||||
|
|
||||||
[tool.isort]
|
[tool.isort]
|
||||||
|
@ -49,5 +53,10 @@ include_trailing_comma = true
|
||||||
force_grid_wrap = 0
|
force_grid_wrap = 0
|
||||||
line_length = 120
|
line_length = 120
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
testpaths = ["tests"]
|
||||||
|
pythonpath = ["."]
|
||||||
|
venv = "venv"
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
line-length = 120
|
line-length = 120
|
||||||
|
|
|
@ -1,16 +1,21 @@
|
||||||
import time
|
import time
|
||||||
from importlib import invalidate_caches
|
from orm.topic import Topic
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.sql import and_
|
||||||
|
|
||||||
from cache.cache import invalidate_shout_related_cache, invalidate_shouts_cache
|
from cache.cache import (
|
||||||
|
cache_author, cache_by_id, cache_topic,
|
||||||
|
invalidate_shout_related_cache, invalidate_shouts_cache
|
||||||
|
)
|
||||||
from orm.author import Author
|
from orm.author import Author
|
||||||
from orm.draft import Draft
|
from orm.draft import Draft
|
||||||
from orm.shout import Shout
|
from orm.shout import Shout, ShoutAuthor, ShoutTopic
|
||||||
from services.auth import login_required
|
from services.auth import login_required
|
||||||
from services.db import local_session
|
from services.db import local_session
|
||||||
from services.schema import mutation, query
|
from services.schema import mutation, query
|
||||||
from utils.logger import root_logger as logger
|
from utils.logger import root_logger as logger
|
||||||
|
from services.notify import notify_shout
|
||||||
|
from services.search import search_service
|
||||||
|
|
||||||
|
|
||||||
@query.field("load_drafts")
|
@query.field("load_drafts")
|
||||||
|
@ -119,9 +124,10 @@ async def unpublish_draft(_, info, draft_id: int):
|
||||||
@login_required
|
@login_required
|
||||||
async def publish_shout(_, info, shout_id: int, draft=None):
|
async def publish_shout(_, info, shout_id: int, draft=None):
|
||||||
"""Publish draft as a shout or update existing shout.
|
"""Publish draft as a shout or update existing shout.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session: SQLAlchemy session to use for database operations
|
shout_id: ID существующей публикации или 0 для новой
|
||||||
|
draft: Объект черновика (опционально)
|
||||||
"""
|
"""
|
||||||
user_id = info.context.get("user_id")
|
user_id = info.context.get("user_id")
|
||||||
author_dict = info.context.get("author", {})
|
author_dict = info.context.get("author", {})
|
||||||
|
@ -130,16 +136,25 @@ async def publish_shout(_, info, shout_id: int, draft=None):
|
||||||
return {"error": "User ID and author ID are required"}
|
return {"error": "User ID and author ID are required"}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Use proper SQLAlchemy query
|
|
||||||
with local_session() as session:
|
with local_session() as session:
|
||||||
|
# Находим черновик если не передан
|
||||||
if not draft:
|
if not draft:
|
||||||
find_draft_stmt = select(Draft).where(Draft.shout == shout_id)
|
find_draft_stmt = select(Draft).where(Draft.shout == shout_id)
|
||||||
draft = session.execute(find_draft_stmt).scalar_one_or_none()
|
draft = session.execute(find_draft_stmt).scalar_one_or_none()
|
||||||
|
if not draft:
|
||||||
|
return {"error": "Draft not found"}
|
||||||
|
|
||||||
now = int(time.time())
|
now = int(time.time())
|
||||||
|
|
||||||
|
# Находим существующую публикацию или создаем новую
|
||||||
|
shout = None
|
||||||
|
was_published = False
|
||||||
|
if shout_id:
|
||||||
|
shout = session.query(Shout).filter(Shout.id == shout_id).first()
|
||||||
|
was_published = shout and shout.published_at is not None
|
||||||
|
|
||||||
if not shout:
|
if not shout:
|
||||||
# Create new shout from draft
|
# Создаем новую публикацию
|
||||||
shout = Shout(
|
shout = Shout(
|
||||||
body=draft.body,
|
body=draft.body,
|
||||||
slug=draft.slug,
|
slug=draft.slug,
|
||||||
|
@ -155,15 +170,11 @@ async def publish_shout(_, info, shout_id: int, draft=None):
|
||||||
seo=draft.seo,
|
seo=draft.seo,
|
||||||
created_by=author_id,
|
created_by=author_id,
|
||||||
community=draft.community,
|
community=draft.community,
|
||||||
authors=draft.authors.copy(), # Create copies of relationships
|
|
||||||
topics=draft.topics.copy(),
|
|
||||||
draft=draft.id,
|
draft=draft.id,
|
||||||
deleted_at=None,
|
deleted_at=None,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Update existing shout
|
# Обновляем существующую публикацию
|
||||||
shout.authors = draft.authors.copy()
|
|
||||||
shout.topics = draft.topics.copy()
|
|
||||||
shout.draft = draft.id
|
shout.draft = draft.id
|
||||||
shout.created_by = author_id
|
shout.created_by = author_id
|
||||||
shout.title = draft.title
|
shout.title = draft.title
|
||||||
|
@ -178,24 +189,78 @@ async def publish_shout(_, info, shout_id: int, draft=None):
|
||||||
shout.lang = draft.lang
|
shout.lang = draft.lang
|
||||||
shout.seo = draft.seo
|
shout.seo = draft.seo
|
||||||
|
|
||||||
|
# Обновляем временные метки
|
||||||
shout.updated_at = now
|
shout.updated_at = now
|
||||||
shout.published_at = now
|
|
||||||
|
# Устанавливаем published_at только если это новая публикация
|
||||||
|
# или публикация была ранее снята с публикации
|
||||||
|
if not was_published:
|
||||||
|
shout.published_at = now
|
||||||
|
|
||||||
draft.updated_at = now
|
draft.updated_at = now
|
||||||
draft.published_at = now
|
draft.published_at = now
|
||||||
|
|
||||||
|
# Обрабатываем связи с авторами
|
||||||
|
if not session.query(ShoutAuthor).filter(
|
||||||
|
and_(ShoutAuthor.shout == shout.id, ShoutAuthor.author == author_id)
|
||||||
|
).first():
|
||||||
|
sa = ShoutAuthor(shout=shout.id, author=author_id)
|
||||||
|
session.add(sa)
|
||||||
|
|
||||||
|
# Обрабатываем темы
|
||||||
|
if draft.topics:
|
||||||
|
for topic in draft.topics:
|
||||||
|
st = ShoutTopic(
|
||||||
|
topic=topic.id,
|
||||||
|
shout=shout.id,
|
||||||
|
main=topic.main if hasattr(topic, 'main') else False
|
||||||
|
)
|
||||||
|
session.add(st)
|
||||||
|
|
||||||
session.add(shout)
|
session.add(shout)
|
||||||
session.add(draft)
|
session.add(draft)
|
||||||
|
session.flush()
|
||||||
|
|
||||||
|
# Инвалидируем кэш только если это новая публикация или была снята с публикации
|
||||||
|
if not was_published:
|
||||||
|
cache_keys = [
|
||||||
|
"feed",
|
||||||
|
f"author_{author_id}",
|
||||||
|
"random_top",
|
||||||
|
"unrated"
|
||||||
|
]
|
||||||
|
|
||||||
|
# Добавляем ключи для тем
|
||||||
|
for topic in shout.topics:
|
||||||
|
cache_keys.append(f"topic_{topic.id}")
|
||||||
|
cache_keys.append(f"topic_shouts_{topic.id}")
|
||||||
|
await cache_by_id(Topic, topic.id, cache_topic)
|
||||||
|
|
||||||
|
# Инвалидируем кэш
|
||||||
|
await invalidate_shouts_cache(cache_keys)
|
||||||
|
await invalidate_shout_related_cache(shout, author_id)
|
||||||
|
|
||||||
|
# Обновляем кэш авторов
|
||||||
|
for author in shout.authors:
|
||||||
|
await cache_by_id(Author, author.id, cache_author)
|
||||||
|
|
||||||
|
# Отправляем уведомление о публикации
|
||||||
|
await notify_shout(shout.dict(), "published")
|
||||||
|
|
||||||
|
# Обновляем поисковый индекс
|
||||||
|
search_service.index(shout)
|
||||||
|
else:
|
||||||
|
# Для уже опубликованных материалов просто отправляем уведомление об обновлении
|
||||||
|
await notify_shout(shout.dict(), "update")
|
||||||
|
|
||||||
session.commit()
|
session.commit()
|
||||||
|
return {"shout": shout}
|
||||||
|
|
||||||
invalidate_shout_related_cache(shout)
|
|
||||||
invalidate_shouts_cache()
|
|
||||||
return {"shout": shout}
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
logger.error(f"Failed to publish shout: {e}", exc_info=True)
|
||||||
|
if 'session' in locals():
|
||||||
logger.error(f"Failed to publish shout: {e}")
|
session.rollback()
|
||||||
logger.error(traceback.format_exc())
|
return {"error": f"Failed to publish shout: {str(e)}"}
|
||||||
session.rollback()
|
|
||||||
return {"error": "Failed to publish shout"}
|
|
||||||
|
|
||||||
|
|
||||||
@mutation.field("unpublish_shout")
|
@mutation.field("unpublish_shout")
|
||||||
|
|
|
@ -188,6 +188,8 @@ type Topic {
|
||||||
|
|
||||||
type CommonResult {
|
type CommonResult {
|
||||||
error: String
|
error: String
|
||||||
|
drafts: [Draft]
|
||||||
|
draft: Draft
|
||||||
slugs: [String]
|
slugs: [String]
|
||||||
shout: Shout
|
shout: Shout
|
||||||
shouts: [Shout]
|
shouts: [Shout]
|
||||||
|
|
|
@ -7,8 +7,7 @@ from typing import Any, Callable, Dict, TypeVar
|
||||||
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from sqlalchemy import JSON, Column, Engine, Integer, create_engine, event, exc, func, inspect
|
from sqlalchemy import JSON, Column, Engine, Integer, create_engine, event, exc, func, inspect
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.orm import Session, configure_mappers, declarative_base
|
||||||
from sqlalchemy.orm import Session, configure_mappers
|
|
||||||
from sqlalchemy.sql.schema import Table
|
from sqlalchemy.sql.schema import Table
|
||||||
|
|
||||||
from settings import DB_URL
|
from settings import DB_URL
|
||||||
|
|
179
services/pretopic.py
Normal file
179
services/pretopic.py
Normal file
|
@ -0,0 +1,179 @@
|
||||||
|
import concurrent.futures
|
||||||
|
from typing import Dict, Tuple, List
|
||||||
|
from txtai.embeddings import Embeddings
|
||||||
|
from services.logger import root_logger as logger
|
||||||
|
|
||||||
|
class TopicClassifier:
|
||||||
|
def __init__(self, shouts_by_topic: Dict[str, str], publications: List[Dict[str, str]]):
|
||||||
|
"""
|
||||||
|
Инициализация классификатора тем и поиска публикаций.
|
||||||
|
Args:
|
||||||
|
shouts_by_topic: Словарь {тема: текст_всех_публикаций}
|
||||||
|
publications: Список публикаций с полями 'id', 'title', 'text'
|
||||||
|
"""
|
||||||
|
self.shouts_by_topic = shouts_by_topic
|
||||||
|
self.topics = list(shouts_by_topic.keys())
|
||||||
|
self.publications = publications
|
||||||
|
self.topic_embeddings = None # Для классификации тем
|
||||||
|
self.search_embeddings = None # Для поиска публикаций
|
||||||
|
self._initialization_future = None
|
||||||
|
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
||||||
|
|
||||||
|
def initialize(self) -> None:
|
||||||
|
"""
|
||||||
|
Асинхронная инициализация векторных представлений.
|
||||||
|
"""
|
||||||
|
if self._initialization_future is None:
|
||||||
|
self._initialization_future = self._executor.submit(self._prepare_embeddings)
|
||||||
|
logger.info("Векторизация текстов начата в фоновом режиме...")
|
||||||
|
|
||||||
|
def _prepare_embeddings(self) -> None:
|
||||||
|
"""
|
||||||
|
Подготавливает векторные представления для тем и поиска.
|
||||||
|
"""
|
||||||
|
logger.info("Начинается подготовка векторных представлений...")
|
||||||
|
|
||||||
|
# Модель для русского языка
|
||||||
|
# TODO: model local caching
|
||||||
|
model_path = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
|
||||||
|
|
||||||
|
# Инициализируем embeddings для классификации тем
|
||||||
|
self.topic_embeddings = Embeddings(path=model_path)
|
||||||
|
topic_documents = [
|
||||||
|
(topic, text)
|
||||||
|
for topic, text in self.shouts_by_topic.items()
|
||||||
|
]
|
||||||
|
self.topic_embeddings.index(topic_documents)
|
||||||
|
|
||||||
|
# Инициализируем embeddings для поиска публикаций
|
||||||
|
self.search_embeddings = Embeddings(path=model_path)
|
||||||
|
search_documents = [
|
||||||
|
(str(pub['id']), f"{pub['title']} {pub['text']}")
|
||||||
|
for pub in self.publications
|
||||||
|
]
|
||||||
|
self.search_embeddings.index(search_documents)
|
||||||
|
|
||||||
|
logger.info("Подготовка векторных представлений завершена.")
|
||||||
|
|
||||||
|
def predict_topic(self, text: str) -> Tuple[float, str]:
|
||||||
|
"""
|
||||||
|
Предсказывает тему для заданного текста из известного набора тем.
|
||||||
|
Args:
|
||||||
|
text: Текст для классификации
|
||||||
|
Returns:
|
||||||
|
Tuple[float, str]: (уверенность, тема)
|
||||||
|
"""
|
||||||
|
if not self.is_ready():
|
||||||
|
logger.error("Векторные представления не готовы. Вызовите initialize() и дождитесь завершения.")
|
||||||
|
return 0.0, "unknown"
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Ищем наиболее похожую тему
|
||||||
|
results = self.topic_embeddings.search(text, 1)
|
||||||
|
if not results:
|
||||||
|
return 0.0, "unknown"
|
||||||
|
|
||||||
|
score, topic = results[0]
|
||||||
|
return float(score), topic
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Ошибка при определении темы: {str(e)}")
|
||||||
|
return 0.0, "unknown"
|
||||||
|
|
||||||
|
def search_similar(self, query: str, limit: int = 5) -> List[Dict[str, any]]:
|
||||||
|
"""
|
||||||
|
Ищет публикации похожие на поисковый запрос.
|
||||||
|
Args:
|
||||||
|
query: Поисковый запрос
|
||||||
|
limit: Максимальное количество результатов
|
||||||
|
Returns:
|
||||||
|
List[Dict]: Список найденных публикаций с оценкой релевантности
|
||||||
|
"""
|
||||||
|
if not self.is_ready():
|
||||||
|
logger.error("Векторные представления не готовы. Вызовите initialize() и дождитесь завершения.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Ищем похожие публикации
|
||||||
|
results = self.search_embeddings.search(query, limit)
|
||||||
|
|
||||||
|
# Формируем результаты
|
||||||
|
found_publications = []
|
||||||
|
for score, pub_id in results:
|
||||||
|
# Находим публикацию по id
|
||||||
|
publication = next(
|
||||||
|
(pub for pub in self.publications if str(pub['id']) == pub_id),
|
||||||
|
None
|
||||||
|
)
|
||||||
|
if publication:
|
||||||
|
found_publications.append({
|
||||||
|
**publication,
|
||||||
|
'relevance': float(score)
|
||||||
|
})
|
||||||
|
|
||||||
|
return found_publications
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Ошибка при поиске публикаций: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def is_ready(self) -> bool:
|
||||||
|
"""
|
||||||
|
Проверяет, готовы ли векторные представления.
|
||||||
|
"""
|
||||||
|
return self.topic_embeddings is not None and self.search_embeddings is not None
|
||||||
|
|
||||||
|
def wait_until_ready(self) -> None:
|
||||||
|
"""
|
||||||
|
Ожидает завершения подготовки векторных представлений.
|
||||||
|
"""
|
||||||
|
if self._initialization_future:
|
||||||
|
self._initialization_future.result()
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
"""
|
||||||
|
Очистка ресурсов при удалении объекта.
|
||||||
|
"""
|
||||||
|
if self._executor:
|
||||||
|
self._executor.shutdown(wait=False)
|
||||||
|
|
||||||
|
# Пример использования:
|
||||||
|
"""
|
||||||
|
shouts_by_topic = {
|
||||||
|
"Спорт": "... большой текст со всеми спортивными публикациями ...",
|
||||||
|
"Технологии": "... большой текст со всеми технологическими публикациями ...",
|
||||||
|
"Политика": "... большой текст со всеми политическими публикациями ..."
|
||||||
|
}
|
||||||
|
|
||||||
|
publications = [
|
||||||
|
{
|
||||||
|
'id': 1,
|
||||||
|
'title': 'Новый процессор AMD',
|
||||||
|
'text': 'Компания AMD представила новый процессор...'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'id': 2,
|
||||||
|
'title': 'Футбольный матч',
|
||||||
|
'text': 'Вчера состоялся решающий матч...'
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
# Создание классификатора
|
||||||
|
classifier = TopicClassifier(shouts_by_topic, publications)
|
||||||
|
classifier.initialize()
|
||||||
|
classifier.wait_until_ready()
|
||||||
|
|
||||||
|
# Определение темы текста
|
||||||
|
text = "Новый процессор показал высокую производительность"
|
||||||
|
score, topic = classifier.predict_topic(text)
|
||||||
|
print(f"Тема: {topic} (уверенность: {score:.4f})")
|
||||||
|
|
||||||
|
# Поиск похожих публикаций
|
||||||
|
query = "процессор AMD производительность"
|
||||||
|
similar_publications = classifier.search_similar(query, limit=3)
|
||||||
|
for pub in similar_publications:
|
||||||
|
print(f"\nНайдена публикация (релевантность: {pub['relevance']:.4f}):")
|
||||||
|
print(f"Заголовок: {pub['title']}")
|
||||||
|
print(f"Текст: {pub['text'][:100]}...")
|
||||||
|
"""
|
||||||
|
|
|
@ -5,7 +5,7 @@ PORT = 8000
|
||||||
DB_URL = (
|
DB_URL = (
|
||||||
environ.get("DATABASE_URL", "").replace("postgres://", "postgresql://")
|
environ.get("DATABASE_URL", "").replace("postgres://", "postgresql://")
|
||||||
or environ.get("DB_URL", "").replace("postgres://", "postgresql://")
|
or environ.get("DB_URL", "").replace("postgres://", "postgresql://")
|
||||||
or "sqlite:///discoursio-db.sqlite3"
|
or "sqlite:///discoursio.db"
|
||||||
)
|
)
|
||||||
REDIS_URL = environ.get("REDIS_URL") or "redis://127.0.0.1"
|
REDIS_URL = environ.get("REDIS_URL") or "redis://127.0.0.1"
|
||||||
AUTH_URL = environ.get("AUTH_URL") or ""
|
AUTH_URL = environ.get("AUTH_URL") or ""
|
||||||
|
@ -15,3 +15,9 @@ MODE = "development" if "dev" in sys.argv else "production"
|
||||||
|
|
||||||
ADMIN_SECRET = environ.get("AUTH_SECRET") or "nothing"
|
ADMIN_SECRET = environ.get("AUTH_SECRET") or "nothing"
|
||||||
WEBHOOK_SECRET = environ.get("WEBHOOK_SECRET") or "nothing-else"
|
WEBHOOK_SECRET = environ.get("WEBHOOK_SECRET") or "nothing-else"
|
||||||
|
|
||||||
|
# own auth
|
||||||
|
ONETIME_TOKEN_LIFE_SPAN = 60 * 60 * 24 * 3 # 3 days
|
||||||
|
SESSION_TOKEN_LIFE_SPAN = 60 * 60 * 24 * 30 # 30 days
|
||||||
|
JWT_ALGORITHM = "HS256"
|
||||||
|
JWT_SECRET_KEY = environ.get("JWT_SECRET") or "nothing-else-jwt-secret-matters"
|
55
tests/conftest.py
Normal file
55
tests/conftest.py
Normal file
|
@ -0,0 +1,55 @@
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
|
from main import app
|
||||||
|
from services.db import Base
|
||||||
|
from services.redis import redis
|
||||||
|
from settings import DB_URL
|
||||||
|
|
||||||
|
# Use SQLite for testing
|
||||||
|
TEST_DB_URL = "sqlite:///test.db"
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def event_loop():
|
||||||
|
"""Create an instance of the default event loop for the test session."""
|
||||||
|
loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||||
|
yield loop
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def test_engine():
|
||||||
|
"""Create a test database engine."""
|
||||||
|
engine = create_engine(TEST_DB_URL)
|
||||||
|
Base.metadata.create_all(engine)
|
||||||
|
yield engine
|
||||||
|
Base.metadata.drop_all(engine)
|
||||||
|
os.remove("test.db")
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def db_session(test_engine):
|
||||||
|
"""Create a new database session for a test."""
|
||||||
|
connection = test_engine.connect()
|
||||||
|
transaction = connection.begin()
|
||||||
|
session = Session(bind=connection)
|
||||||
|
|
||||||
|
yield session
|
||||||
|
|
||||||
|
session.close()
|
||||||
|
transaction.rollback()
|
||||||
|
connection.close()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def redis_client():
|
||||||
|
"""Create a test Redis client."""
|
||||||
|
await redis.connect()
|
||||||
|
yield redis
|
||||||
|
await redis.disconnect()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_client():
|
||||||
|
"""Create a TestClient instance."""
|
||||||
|
return TestClient(app)
|
95
tests/test_drafts.py
Normal file
95
tests/test_drafts.py
Normal file
|
@ -0,0 +1,95 @@
|
||||||
|
import pytest
|
||||||
|
from orm.shout import Shout
|
||||||
|
from orm.author import Author
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_author(db_session):
|
||||||
|
"""Create a test author."""
|
||||||
|
author = Author(
|
||||||
|
name="Test Author",
|
||||||
|
slug="test-author",
|
||||||
|
user="test-user-id"
|
||||||
|
)
|
||||||
|
db_session.add(author)
|
||||||
|
db_session.commit()
|
||||||
|
return author
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_shout(db_session):
|
||||||
|
"""Create test shout with required fields."""
|
||||||
|
author = Author(name="Test Author", slug="test-author", user="test-user-id")
|
||||||
|
db_session.add(author)
|
||||||
|
db_session.flush()
|
||||||
|
|
||||||
|
shout = Shout(
|
||||||
|
title="Test Shout",
|
||||||
|
slug="test-shout",
|
||||||
|
created_by=author.id, # Обязательное поле
|
||||||
|
body="Test body",
|
||||||
|
layout="article",
|
||||||
|
lang="ru"
|
||||||
|
)
|
||||||
|
db_session.add(shout)
|
||||||
|
db_session.commit()
|
||||||
|
return shout
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_shout(test_client, db_session, test_author):
|
||||||
|
"""Test creating a new shout."""
|
||||||
|
response = test_client.post(
|
||||||
|
"/",
|
||||||
|
json={
|
||||||
|
"query": """
|
||||||
|
mutation CreateDraft($input: DraftInput!) {
|
||||||
|
create_draft(input: $input) {
|
||||||
|
error
|
||||||
|
draft {
|
||||||
|
id
|
||||||
|
title
|
||||||
|
body
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
""",
|
||||||
|
"variables": {
|
||||||
|
"input": {
|
||||||
|
"title": "Test Shout",
|
||||||
|
"body": "This is a test shout",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "errors" not in data
|
||||||
|
assert data["data"]["create_draft"]["draft"]["title"] == "Test Shout"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_load_drafts(test_client, db_session):
|
||||||
|
"""Test retrieving a shout."""
|
||||||
|
response = test_client.post(
|
||||||
|
"/",
|
||||||
|
json={
|
||||||
|
"query": """
|
||||||
|
query {
|
||||||
|
load_drafts {
|
||||||
|
error
|
||||||
|
drafts {
|
||||||
|
id
|
||||||
|
title
|
||||||
|
body
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
""",
|
||||||
|
"variables": {
|
||||||
|
"slug": "test-shout"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "errors" not in data
|
||||||
|
assert data["data"]["load_drafts"]["drafts"] == []
|
64
tests/test_reactions.py
Normal file
64
tests/test_reactions.py
Normal file
|
@ -0,0 +1,64 @@
|
||||||
|
import pytest
|
||||||
|
from orm.reaction import Reaction, ReactionKind
|
||||||
|
from orm.shout import Shout
|
||||||
|
from orm.author import Author
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_setup(db_session):
|
||||||
|
"""Set up test data."""
|
||||||
|
now = int(datetime.now().timestamp())
|
||||||
|
author = Author(name="Test Author", slug="test-author", user="test-user-id")
|
||||||
|
db_session.add(author)
|
||||||
|
db_session.flush()
|
||||||
|
|
||||||
|
shout = Shout(
|
||||||
|
title="Test Shout",
|
||||||
|
slug="test-shout",
|
||||||
|
created_by=author.id,
|
||||||
|
body="This is a test shout",
|
||||||
|
layout="article",
|
||||||
|
lang="ru",
|
||||||
|
community=1,
|
||||||
|
created_at=now,
|
||||||
|
updated_at=now
|
||||||
|
)
|
||||||
|
db_session.add_all([author, shout])
|
||||||
|
db_session.commit()
|
||||||
|
return {"author": author, "shout": shout}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_reaction(test_client, db_session, test_setup):
|
||||||
|
"""Test creating a reaction on a shout."""
|
||||||
|
response = test_client.post(
|
||||||
|
"/",
|
||||||
|
json={
|
||||||
|
"query": """
|
||||||
|
mutation CreateReaction($reaction: ReactionInput!) {
|
||||||
|
create_reaction(reaction: $reaction) {
|
||||||
|
error
|
||||||
|
reaction {
|
||||||
|
id
|
||||||
|
kind
|
||||||
|
body
|
||||||
|
created_by {
|
||||||
|
name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
""",
|
||||||
|
"variables": {
|
||||||
|
"reaction": {
|
||||||
|
"shout": test_setup["shout"].id,
|
||||||
|
"kind": ReactionKind.LIKE.value,
|
||||||
|
"body": "Great post!"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "error" not in data
|
||||||
|
assert data["data"]["create_reaction"]["reaction"]["kind"] == ReactionKind.LIKE.value
|
83
tests/test_shouts.py
Normal file
83
tests/test_shouts.py
Normal file
|
@ -0,0 +1,83 @@
|
||||||
|
import pytest
|
||||||
|
from orm.author import Author
|
||||||
|
from orm.shout import Shout
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_shout(db_session):
|
||||||
|
"""Create test shout with required fields."""
|
||||||
|
now = int(datetime.now().timestamp())
|
||||||
|
author = Author(name="Test Author", slug="test-author", user="test-user-id")
|
||||||
|
db_session.add(author)
|
||||||
|
db_session.flush()
|
||||||
|
|
||||||
|
now = int(datetime.now().timestamp())
|
||||||
|
|
||||||
|
shout = Shout(
|
||||||
|
title="Test Shout",
|
||||||
|
slug="test-shout",
|
||||||
|
created_by=author.id,
|
||||||
|
body="Test body",
|
||||||
|
layout="article",
|
||||||
|
lang="ru",
|
||||||
|
community=1,
|
||||||
|
created_at=now,
|
||||||
|
updated_at=now
|
||||||
|
)
|
||||||
|
db_session.add(shout)
|
||||||
|
db_session.commit()
|
||||||
|
return shout
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_shout(test_client, db_session):
|
||||||
|
"""Test retrieving a shout."""
|
||||||
|
# Создаем автора
|
||||||
|
author = Author(name="Test Author", slug="test-author", user="test-user-id")
|
||||||
|
db_session.add(author)
|
||||||
|
db_session.flush()
|
||||||
|
now = int(datetime.now().timestamp())
|
||||||
|
|
||||||
|
# Создаем публикацию со всеми обязательными полями
|
||||||
|
shout = Shout(
|
||||||
|
title="Test Shout",
|
||||||
|
body="This is a test shout",
|
||||||
|
slug="test-shout",
|
||||||
|
created_by=author.id,
|
||||||
|
layout="article",
|
||||||
|
lang="ru",
|
||||||
|
community=1,
|
||||||
|
created_at=now,
|
||||||
|
updated_at=now
|
||||||
|
)
|
||||||
|
db_session.add(shout)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
response = test_client.post(
|
||||||
|
"/",
|
||||||
|
json={
|
||||||
|
"query": """
|
||||||
|
query GetShout($slug: String!) {
|
||||||
|
get_shout(slug: $slug) {
|
||||||
|
id
|
||||||
|
title
|
||||||
|
body
|
||||||
|
created_at
|
||||||
|
updated_at
|
||||||
|
created_by {
|
||||||
|
id
|
||||||
|
name
|
||||||
|
slug
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
""",
|
||||||
|
"variables": {
|
||||||
|
"slug": "test-shout"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert "errors" not in data
|
||||||
|
assert data["data"]["get_shout"]["title"] == "Test Shout"
|
101
tests/test_validations.py
Normal file
101
tests/test_validations.py
Normal file
|
@ -0,0 +1,101 @@
|
||||||
|
import pytest
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from auth.validations import (
|
||||||
|
AuthInput,
|
||||||
|
UserRegistrationInput,
|
||||||
|
UserLoginInput,
|
||||||
|
TokenPayload,
|
||||||
|
OAuthInput,
|
||||||
|
AuthResponse
|
||||||
|
)
|
||||||
|
|
||||||
|
class TestAuthValidations:
|
||||||
|
def test_auth_input(self):
|
||||||
|
"""Test basic auth input validation"""
|
||||||
|
# Valid case
|
||||||
|
auth = AuthInput(
|
||||||
|
user_id="123",
|
||||||
|
username="testuser",
|
||||||
|
token="1234567890abcdef1234567890abcdef"
|
||||||
|
)
|
||||||
|
assert auth.user_id == "123"
|
||||||
|
assert auth.username == "testuser"
|
||||||
|
|
||||||
|
# Invalid cases
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
AuthInput(user_id="", username="test", token="x" * 32)
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
AuthInput(user_id="123", username="t", token="x" * 32)
|
||||||
|
|
||||||
|
def test_user_registration(self):
|
||||||
|
"""Test user registration validation"""
|
||||||
|
# Valid case
|
||||||
|
user = UserRegistrationInput(
|
||||||
|
email="test@example.com",
|
||||||
|
password="SecurePass123!",
|
||||||
|
name="Test User"
|
||||||
|
)
|
||||||
|
assert user.email == "test@example.com"
|
||||||
|
assert user.name == "Test User"
|
||||||
|
|
||||||
|
# Test email validation
|
||||||
|
with pytest.raises(ValidationError) as exc:
|
||||||
|
UserRegistrationInput(
|
||||||
|
email="invalid-email",
|
||||||
|
password="SecurePass123!",
|
||||||
|
name="Test"
|
||||||
|
)
|
||||||
|
assert "Invalid email format" in str(exc.value)
|
||||||
|
|
||||||
|
# Test password validation
|
||||||
|
with pytest.raises(ValidationError) as exc:
|
||||||
|
UserRegistrationInput(
|
||||||
|
email="test@example.com",
|
||||||
|
password="weak",
|
||||||
|
name="Test"
|
||||||
|
)
|
||||||
|
assert "String should have at least 8 characters" in str(exc.value)
|
||||||
|
|
||||||
|
def test_token_payload(self):
|
||||||
|
"""Test token payload validation"""
|
||||||
|
now = datetime.utcnow()
|
||||||
|
exp = now + timedelta(hours=1)
|
||||||
|
|
||||||
|
payload = TokenPayload(
|
||||||
|
user_id="123",
|
||||||
|
username="testuser",
|
||||||
|
exp=exp,
|
||||||
|
iat=now
|
||||||
|
)
|
||||||
|
assert payload.user_id == "123"
|
||||||
|
assert payload.username == "testuser"
|
||||||
|
assert payload.scopes == [] # Default empty list
|
||||||
|
|
||||||
|
def test_auth_response(self):
|
||||||
|
"""Test auth response validation"""
|
||||||
|
# Success case
|
||||||
|
success_resp = AuthResponse(
|
||||||
|
success=True,
|
||||||
|
token="valid_token",
|
||||||
|
user={"id": "123", "name": "Test"}
|
||||||
|
)
|
||||||
|
assert success_resp.success is True
|
||||||
|
assert success_resp.token == "valid_token"
|
||||||
|
|
||||||
|
# Error case
|
||||||
|
error_resp = AuthResponse(
|
||||||
|
success=False,
|
||||||
|
error="Invalid credentials"
|
||||||
|
)
|
||||||
|
assert error_resp.success is False
|
||||||
|
assert error_resp.error == "Invalid credentials"
|
||||||
|
|
||||||
|
# Invalid case - отсутствует обязательное поле token при success=True
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
AuthResponse(
|
||||||
|
success=True,
|
||||||
|
user={"id": "123", "name": "Test"}
|
||||||
|
)
|
Loading…
Reference in New Issue
Block a user