commit 59a1f8c902 (parent 7beddea5b1)
Author: Untone
Date: 2024-01-26 03:40:49 +03:00

14 changed files with 249 additions and 188 deletions

.pre-commit-config.yaml (new file, 21 lines)

@@ -0,0 +1,21 @@
+fail_fast: true
+
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.5.0
+    hooks:
+      - id: check-yaml
+      - id: end-of-file-fixer
+      - id: trailing-whitespace
+      - id: check-added-large-files
+      - id: detect-private-key
+      - id: double-quote-string-fixer
+      - id: check-ast
+      - id: check-merge-conflict
+
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.1.13
+    hooks:
+      - id: ruff
+        args: [--fix]
+      - id: ruff-format
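The quote churn in every Python file below follows from this config: double-quote-string-fixer rewrites double-quoted literals to single quotes whenever the body contains no single quote. Note that ruff-format defaults to double quotes, so the two hooks only agree if ruff is configured with quote-style = "single" (that configuration is not part of this commit). A minimal sketch of the normalization rule, not the hook's actual implementation:

# Illustrative sketch of the single-quote preference; the real hook works
# on tokenized source, not raw string literals.
def prefer_single_quotes(literal: str) -> str:
    body = literal[1:-1]
    if literal.startswith('"') and literal.endswith('"') and "'" not in body:
        return f"'{body}'"
    return literal  # strings like "Author's slug" are left untouched

assert prefer_single_quotes('"Bio"') == "'Bio'"
assert prefer_single_quotes('"Author\'s slug"') == '"Author\'s slug"'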

main.py (12 changed lines)

@@ -1,4 +1,5 @@
 import asyncio
+import logging
 import os
 from os.path import exists
@@ -13,11 +14,10 @@ from resolvers.listener import notifications_worker
 from resolvers.schema import schema
 from services.rediscache import redis
 from settings import DEV_SERVER_PID_FILE_NAME, MODE, SENTRY_DSN
-import logging

 logging.basicConfig(level=logging.DEBUG)
-logger = logging.getLogger("\t[main]\t")
+logger = logging.getLogger('\t[main]\t')
 logger.setLevel(logging.DEBUG)
@@ -27,9 +27,9 @@ async def start_up():
     task = asyncio.create_task(notifications_worker())
     logger.info(task)
-    if MODE == "dev":
+    if MODE == 'dev':
         if exists(DEV_SERVER_PID_FILE_NAME):
-            with open(DEV_SERVER_PID_FILE_NAME, "w", encoding="utf-8") as f:
+            with open(DEV_SERVER_PID_FILE_NAME, 'w', encoding='utf-8') as f:
                 f.write(str(os.getpid()))
     else:
         try:
@@ -46,7 +46,7 @@ async def start_up():
                 ],
             )
         except Exception as e:
-            logger.error("sentry init error", e)
+            logger.error('sentry init error', e)


 async def shutdown():
@@ -54,4 +54,4 @@ async def shutdown():
 app = Starlette(debug=True, on_startup=[start_up], on_shutdown=[shutdown])
-app.mount("/", GraphQL(schema, debug=True))
+app.mount('/', GraphQL(schema, debug=True))
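For reference, a minimal way to serve the app assembled above, assuming uvicorn as the ASGI server (the actual deployment target is not shown in this commit):

# Hypothetical local runner; any ASGI server can host the Starlette app.
import uvicorn

from main import app

if __name__ == '__main__':
    uvicorn.run(app, host='0.0.0.0', port=80)  # PORT = 80 in settings.py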

orm/author.py

@@ -1,46 +1,45 @@
 import time

-from sqlalchemy import JSON as JSONType
-from sqlalchemy import Boolean, Column, ForeignKey, Integer, String
+from sqlalchemy import JSON, Boolean, Column, ForeignKey, Integer, String
 from sqlalchemy.orm import relationship

 from services.db import Base


 class AuthorRating(Base):
-    __tablename__ = "author_rating"
+    __tablename__ = 'author_rating'

     id = None  # type: ignore
-    rater = Column(ForeignKey("author.id"), primary_key=True, index=True)
-    author = Column(ForeignKey("author.id"), primary_key=True, index=True)
+    rater = Column(ForeignKey('author.id'), primary_key=True, index=True)
+    author = Column(ForeignKey('author.id'), primary_key=True, index=True)
     plus = Column(Boolean)


 class AuthorFollower(Base):
-    __tablename__ = "author_follower"
+    __tablename__ = 'author_follower'

     id = None  # type: ignore
-    follower = Column(ForeignKey("author.id"), primary_key=True, index=True)
-    author = Column(ForeignKey("author.id"), primary_key=True, index=True)
+    follower = Column(ForeignKey('author.id'), primary_key=True, index=True)
+    author = Column(ForeignKey('author.id'), primary_key=True, index=True)
     created_at = Column(Integer, nullable=False, default=lambda: int(time.time()))
     auto = Column(Boolean, nullable=False, default=False)


 class Author(Base):
-    __tablename__ = "author"
+    __tablename__ = 'author'

     user = Column(String, unique=True)  # unbounded link with authorizer's User type
-    name = Column(String, nullable=True, comment="Display name")
+    name = Column(String, nullable=True, comment='Display name')
     slug = Column(String, unique=True, comment="Author's slug")
-    bio = Column(String, nullable=True, comment="Bio")  # status description
-    about = Column(String, nullable=True, comment="About")  # long and formatted
-    pic = Column(String, nullable=True, comment="Picture")
-    links = Column(JSONType, nullable=True, comment="Links")
+    bio = Column(String, nullable=True, comment='Bio')  # status description
+    about = Column(String, nullable=True, comment='About')  # long and formatted
+    pic = Column(String, nullable=True, comment='Picture')
+    links = Column(JSON, nullable=True, comment='Links')
     ratings = relationship(AuthorRating, foreign_keys=AuthorRating.author)
     created_at = Column(Integer, nullable=False, default=lambda: int(time.time()))
     last_seen = Column(Integer, nullable=False, default=lambda: int(time.time()))
     updated_at = Column(Integer, nullable=False, default=lambda: int(time.time()))
-    deleted_at = Column(Integer, nullable=True, comment="Deleted at")
+    deleted_at = Column(Integer, nullable=True, comment='Deleted at')
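The timestamp columns use default=lambda: int(time.time()), so the lambda runs on every INSERT and each row gets a fresh unix timestamp. A usage sketch, assuming a reachable database and hypothetical values:

from orm.author import Author
from services.db import local_session

with local_session() as session:
    author = Author(user='some-auth-user-id', slug='jane', name='Jane')  # hypothetical
    session.add(author)
    session.commit()
    print(author.id, author.created_at)  # created_at was filled at INSERT time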

orm/notification.py

@@ -1,42 +1,41 @@
+import time
 from enum import Enum as Enumeration

-from sqlalchemy import JSON as JSONType, func, cast
-from sqlalchemy import Column, Enum, ForeignKey, Integer, String
+from sqlalchemy import JSON, Column, ForeignKey, Integer, String
 from sqlalchemy.orm import relationship
-from sqlalchemy.orm.session import engine

 from orm.author import Author
 from services.db import Base
-import time


 class NotificationEntity(Enumeration):
-    REACTION = "reaction"
-    SHOUT = "shout"
-    FOLLOWER = "follower"
+    REACTION = 'reaction'
+    SHOUT = 'shout'
+    FOLLOWER = 'follower'


 class NotificationAction(Enumeration):
-    CREATE = "create"
-    UPDATE = "update"
-    DELETE = "delete"
-    SEEN = "seen"
-    FOLLOW = "follow"
-    UNFOLLOW = "unfollow"
+    CREATE = 'create'
+    UPDATE = 'update'
+    DELETE = 'delete'
+    SEEN = 'seen'
+    FOLLOW = 'follow'
+    UNFOLLOW = 'unfollow'


 class NotificationSeen(Base):
-    __tablename__ = "notification_seen"
+    __tablename__ = 'notification_seen'

-    viewer = Column(ForeignKey("author.id"))
-    notification = Column(ForeignKey("notification.id"))
+    viewer = Column(ForeignKey('author.id'))
+    notification = Column(ForeignKey('notification.id'))


 class Notification(Base):
-    __tablename__ = "notification"
+    __tablename__ = 'notification'

     created_at = Column(Integer, server_default=str(int(time.time())))
     entity = Column(String, nullable=False)
     action = Column(String, nullable=False)
-    payload = Column(JSONType, nullable=True)
+    payload = Column(JSON, nullable=True)

-    seen = relationship(lambda: Author, secondary="notification_seen")
+    seen = relationship(lambda: Author, secondary='notification_seen')
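One behavioral difference between the two ORM files is worth flagging: Author.created_at uses default=lambda: int(time.time()), which SQLAlchemy evaluates on every INSERT, while Notification.created_at uses server_default=str(int(time.time())), where str(...) runs once at import time, so the same constant is baked into the table DDL for every row that omits the column:

import time
from sqlalchemy import Column, Integer

# Evaluated per INSERT: each new row gets the current timestamp.
created_at = Column(Integer, nullable=False, default=lambda: int(time.time()))
# Evaluated once, when this module is imported: the DDL default is frozen.
created_at_frozen = Column(Integer, server_default=str(int(time.time())))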

resolvers/listener.py

@@ -1,11 +1,13 @@
-from orm.notification import Notification, NotificationAction, NotificationEntity
-from resolvers.model import NotificationReaction, NotificationAuthor, NotificationShout
-from services.db import local_session
-from services.rediscache import redis
 import asyncio
 import logging

-logger = logging.getLogger(f"[listener.listen_task] ")
+from orm.notification import Notification
+from resolvers.model import NotificationAuthor, NotificationReaction, NotificationShout
+from services.db import local_session
+from services.rediscache import redis
+
+logger = logging.getLogger('[listener.listen_task] ')
 logger.setLevel(logging.DEBUG)
@@ -19,8 +21,8 @@ async def handle_notification(n: ServiceMessage, channel: str):
     """creates a new stored notification"""
     with local_session() as session:
         try:
-            if channel.startswith("follower:"):
-                author_id = int(channel.split(":")[1])
+            if channel.startswith('follower:'):
+                author_id = int(channel.split(':')[1])
                 if isinstance(n.payload, NotificationAuthor):
                     n.payload.following_id = author_id
             n = Notification(action=n.action, entity=n.entity, payload=n.payload)
@@ -28,7 +30,7 @@ async def handle_notification(n: ServiceMessage, channel: str):
             session.commit()
         except Exception as e:
             session.rollback()
-            logger.error(f"[listener.handle_reaction] error: {str(e)}")
+            logger.error(f'[listener.handle_reaction] error: {str(e)}')


 async def listen_task(pattern):
@@ -38,9 +40,9 @@ async def listen_task(pattern):
             notification_message = ServiceMessage(**message_data)
             await handle_notification(notification_message, str(channel))
         except Exception as e:
-            logger.error(f"Error processing notification: {str(e)}")
+            logger.error(f'Error processing notification: {str(e)}')


 async def notifications_worker():
     # Use asyncio.gather to run tasks concurrently
-    await asyncio.gather(listen_task("follower:*"), listen_task("reaction"), listen_task("shout"))
+    await asyncio.gather(listen_task('follower:*'), listen_task('reaction'), listen_task('shout'))
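The worker subscribes to the 'follower:*' pattern plus the 'reaction' and 'shout' channels and expects JSON bodies it can unpack into ServiceMessage. A sketch of the producing side, assuming those channel names; the message field names here are illustrative, not taken from this commit:

import json

from services.rediscache import redis

async def announce_reaction(reaction_payload: dict):
    # Publish to the 'reaction' channel the worker above listens on.
    message = {'entity': 'reaction', 'action': 'create', 'payload': reaction_payload}
    await redis.execute('PUBLISH', 'reaction', json.dumps(message))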

resolvers/load.py

@@ -1,27 +1,36 @@
+import json
+import logging
+import time
+from typing import Dict, List
+
+import strawberry
+from sqlalchemy import and_, select
+from sqlalchemy.orm import aliased
 from sqlalchemy.sql import not_
-from services.db import local_session
+
+from orm.notification import (
+    Notification,
+    NotificationAction,
+    NotificationEntity,
+    NotificationSeen,
+)
 from resolvers.model import (
-    NotificationReaction,
-    NotificationGroup,
-    NotificationShout,
     NotificationAuthor,
+    NotificationGroup,
+    NotificationReaction,
+    NotificationShout,
     NotificationsResult,
 )
-from orm.notification import NotificationAction, NotificationEntity, NotificationSeen, Notification
-from typing import Dict, List
-import time, json
-import strawberry
-from sqlalchemy.orm import aliased
-from sqlalchemy.sql.expression import or_
-from sqlalchemy import select, and_
-import logging
+from services.db import local_session

-logger = logging.getLogger("[resolvers.schema] ")
+logger = logging.getLogger('[resolvers.schema] ')
 logger.setLevel(logging.DEBUG)


-async def get_notifications_grouped(author_id: int, after: int = 0, limit: int = 10, offset: int = 0):
+async def get_notifications_grouped(  # noqa: C901
+    author_id: int, after: int = 0, limit: int = 10, offset: int = 0
+):
     """
     Retrieves notifications for a given author.
@@ -47,10 +56,13 @@ async def get_notifications_grouped(author_id: int, after: int = 0, limit: int =
         authors: List[NotificationAuthor],  # List of authors involved in the thread.
     }
     """
-    NotificationSeenAlias = aliased(NotificationSeen)
-    query = select(Notification, NotificationSeenAlias.viewer.label("seen")).outerjoin(
+    seen_alias = aliased(NotificationSeen)
+    query = select(Notification, seen_alias.viewer.label('seen')).outerjoin(
         NotificationSeen,
-        and_(NotificationSeen.viewer == author_id, NotificationSeen.notification == Notification.id),
+        and_(
+            NotificationSeen.viewer == author_id,
+            NotificationSeen.notification == Notification.id,
+        ),
     )
     if after:
         query = query.filter(Notification.created_at > after)
@@ -62,23 +74,36 @@ async def get_notifications_grouped(author_id: int, after: int = 0, limit: int =
     notifications_by_thread: Dict[str, List[Notification]] = {}
     groups_by_thread: Dict[str, NotificationGroup] = {}
     with local_session() as session:
-        total = session.query(Notification).filter(and_(Notification.action == NotificationAction.CREATE.value, Notification.created_at > after)).count()
-        unread = session.query(Notification).filter(
-            and_(
-                Notification.action == NotificationAction.CREATE.value,
-                Notification.created_at > after,
-                not_(Notification.seen)
-            )
-        ).count()
+        total = (
+            session.query(Notification)
+            .filter(
+                and_(
+                    Notification.action == NotificationAction.CREATE.value,
+                    Notification.created_at > after,
+                )
+            )
+            .count()
+        )
+        unread = (
+            session.query(Notification)
+            .filter(
+                and_(
+                    Notification.action == NotificationAction.CREATE.value,
+                    Notification.created_at > after,
+                    not_(Notification.seen),
+                )
+            )
+            .count()
+        )
         notifications_result = session.execute(query)
-        for n, seen in notifications_result:
-            thread_id = ""
+        for n, _seen in notifications_result:
+            thread_id = ''
             payload = json.loads(n.payload)
-            logger.debug(f"[resolvers.schema] {n.action} {n.entity}: {payload}")
-            if n.entity == "shout" and n.action == "create":
+            logger.debug(f'[resolvers.schema] {n.action} {n.entity}: {payload}')
+            if n.entity == 'shout' and n.action == 'create':
                 shout: NotificationShout = payload
-                thread_id += f"{shout.id}"
-                logger.debug(f"create shout: {shout}")
+                thread_id += f'{shout.id}'
+                logger.debug(f'create shout: {shout}')
                 group = groups_by_thread.get(thread_id) or NotificationGroup(
                     id=thread_id,
                     entity=n.entity,
@@ -86,8 +111,8 @@ async def get_notifications_grouped(author_id: int, after: int = 0, limit: int =
                     authors=shout.authors,
                     updated_at=shout.created_at,
                     reactions=[],
-                    action="create",
-                    seen=author_id in n.seen
+                    action='create',
+                    seen=author_id in n.seen,
                 )
                 # store group in result
                 groups_by_thread[thread_id] = group
@@ -99,11 +124,11 @@ async def get_notifications_grouped(author_id: int, after: int = 0, limit: int =
             elif n.entity == NotificationEntity.REACTION.value and n.action == NotificationAction.CREATE.value:
                 reaction: NotificationReaction = payload
                 shout: NotificationShout = reaction.shout
-                thread_id += f"{reaction.shout}"
-                if reaction.kind == "LIKE" or reaction.kind == "DISLIKE":
+                thread_id += f'{reaction.shout}'
+                if not bool(reaction.reply_to) and (reaction.kind == 'LIKE' or reaction.kind == 'DISLIKE'):
                     # TODO: making published reaction vote announce
                     pass
-                elif reaction.kind == "COMMENT":
+                elif reaction.kind == 'COMMENT':
                     if reaction.reply_to:
                         thread_id += f"{'::' + str(reaction.reply_to)}"
                     group: NotificationGroup | None = groups_by_thread.get(thread_id)
@@ -128,8 +153,9 @@ async def get_notifications_grouped(author_id: int, after: int = 0, limit: int =
                                break
                    else:
                        # init notification group
-                        reactions = []
-                        reactions.append(reaction.id)
+                        reactions = [
+                            reaction.id,
+                        ]
                        group = NotificationGroup(
                            id=thread_id,
                            action=n.action,
@@ -140,7 +166,7 @@ async def get_notifications_grouped(author_id: int, after: int = 0, limit: int =
                            authors=[
                                reaction.created_by,
                            ],
-                            seen=author_id in n.seen
+                            seen=author_id in n.seen,
                        )
                        # store group in result
                        groups_by_thread[thread_id] = group
@@ -149,20 +175,22 @@ async def get_notifications_grouped(author_id: int, after: int = 0, limit: int =
                    notifications.append(n)
                    notifications_by_thread[thread_id] = notifications

-            elif n.entity == "follower":
-                thread_id = "followers"
+            elif n.entity == 'follower':
+                thread_id = 'followers'
                 follower: NotificationAuthor = payload
                 group = groups_by_thread.get(thread_id) or NotificationGroup(
                     id=thread_id,
                     authors=[follower],
                     updated_at=int(time.time()),
                     shout=None,
                     reactions=[],
-                    entity="follower",
-                    action="follow",
-                    seen=author_id in n.seen
+                    entity='follower',
+                    action='follow',
+                    seen=author_id in n.seen,
                 )
-                group.authors = [follower, ]
+                group.authors = [
+                    follower,
+                ]
                 group.updated_at = int(time.time())
                 # store group in result
                 groups_by_thread[thread_id] = group
@@ -182,7 +210,7 @@ async def get_notifications_grouped(author_id: int, after: int = 0, limit: int =
 class Query:
     @strawberry.field
     async def load_notifications(self, info, after: int, limit: int = 50, offset: int = 0) -> NotificationsResult:
-        author_id = info.context.get("author_id")
+        author_id = info.context.get('author_id')
         groups: Dict[str, NotificationGroup] = {}
         if author_id:
             groups, notifications, total, unread = await get_notifications_grouped(author_id, after, limit, offset)
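Thread keys built above: a shout's thread is just the shout id, a comment reply extends it with '::<reply_to>', and every follower notification collapses into one shared 'followers' group. For illustration (values are examples):

shout_thread = f'{42}'            # shout created  -> '42'
reply_thread = '42' + '::' + '7'  # comment reply  -> '42::7'
follower_thread = 'followers'     # all follow events share one group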

resolvers/model.py

@@ -1,8 +1,11 @@
-import strawberry
 from typing import List, Optional
+
+import strawberry
 from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyMapper
+
 from orm.notification import Notification as NotificationMessage

 strawberry_sqlalchemy_mapper = StrawberrySQLAlchemyMapper()

resolvers/schema.py

@@ -1,12 +1,12 @@
 import strawberry
 from strawberry.schema.config import StrawberryConfig
-from services.auth import LoginRequiredMiddleware
+
 from resolvers.load import Query
 from resolvers.seen import Mutation
+from services.auth import LoginRequiredMiddleware
 from services.db import Base, engine

 schema = strawberry.Schema(
     query=Query, mutation=Mutation, config=StrawberryConfig(auto_camel_case=False), extensions=[LoginRequiredMiddleware]
 )

resolvers/seen.py

@@ -1,14 +1,14 @@
-from sqlalchemy import and_
-from orm.notification import NotificationSeen
-from services.db import local_session
-from resolvers.model import Notification, NotificationSeenResult, NotificationReaction
+import json
+import logging
+
 import strawberry
-import logging
-import json
+from sqlalchemy import and_
 from sqlalchemy.exc import SQLAlchemyError
+
+from orm.notification import NotificationSeen
+from resolvers.model import Notification, NotificationReaction, NotificationSeenResult
+from services.db import local_session

 logger = logging.getLogger(__name__)
@@ -17,7 +17,7 @@ logger = logging.getLogger(__name__)
 class Mutation:
     @strawberry.mutation
     async def mark_seen(self, info, notification_id: int) -> NotificationSeenResult:
-        author_id = info.context.get("author_id")
+        author_id = info.context.get('author_id')
         if author_id:
             with local_session() as session:
                 try:
@@ -27,9 +27,9 @@ class Mutation:
                 except SQLAlchemyError as e:
                     session.rollback()
                     logger.error(
-                        f"[mark_notification_as_read] error updating notification read status: {str(e)}"
+                        f'[mark_notification_as_read] error updating notification read status: {str(e)}'
                     )
-                    return NotificationSeenResult(error="cant mark as read")
+                    return NotificationSeenResult(error='cant mark as read')
         return NotificationSeenResult(error=None)

     @strawberry.mutation
@@ -37,7 +37,7 @@ class Mutation:
         # TODO: use latest loaded notification_id as input offset parameter
         error = None
         try:
-            author_id = info.context.get("author_id")
+            author_id = info.context.get('author_id')
             if author_id:
                 with local_session() as session:
                     nnn = session.query(Notification).filter(and_(Notification.created_at > after)).all()
@@ -46,26 +46,26 @@ class Mutation:
                         ns = NotificationSeen(notification=n.id, viewer=author_id)
                         session.add(ns)
                         session.commit()
-                    except SQLAlchemyError as e:
+                    except SQLAlchemyError:
                         session.rollback()
         except Exception as e:
             print(e)
-            error = "cant mark as read"
+            error = 'cant mark as read'
         return NotificationSeenResult(error=error)

     @strawberry.mutation
     async def mark_seen_thread(self, info, thread: str, after: int) -> NotificationSeenResult:
         error = None
-        author_id = info.context.get("author_id")
+        author_id = info.context.get('author_id')
         if author_id:
-            [shout_id, reply_to_id] = thread.split("::")
+            [shout_id, reply_to_id] = thread.split('::')
             with local_session() as session:
                 # TODO: handle new follower and new shout notifications
                 new_reaction_notifications = (
                     session.query(Notification)
                     .filter(
-                        Notification.action == "create",
-                        Notification.entity == "reaction",
+                        Notification.action == 'create',
+                        Notification.entity == 'reaction',
                         Notification.created_at > after,
                     )
                     .all()
@@ -73,13 +73,13 @@ class Mutation:
                 removed_reaction_notifications = (
                     session.query(Notification)
                     .filter(
-                        Notification.action == "delete",
-                        Notification.entity == "reaction",
+                        Notification.action == 'delete',
+                        Notification.entity == 'reaction',
                         Notification.created_at > after,
                     )
                     .all()
                 )
-                exclude = set([])
+                exclude = set()
                 for nr in removed_reaction_notifications:
                     reaction: NotificationReaction = json.loads(nr.payload)
                     exclude.add(reaction.id)
@@ -97,5 +97,5 @@ class Mutation:
                 except Exception:
                     session.rollback()
         else:
-            error = "You are not logged in"
+            error = 'You are not logged in'
         return NotificationSeenResult(error=error)
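Since auto_camel_case is disabled in resolvers/schema.py, these mutations keep their snake_case names on the wire. A sketch of exercising mark_seen directly against the schema, e.g. from a test; the context values mimic what LoginRequiredMiddleware would set, and the ids are hypothetical:

import asyncio

from resolvers.schema import schema

result = asyncio.run(
    schema.execute(
        'mutation { mark_seen(notification_id: 1) { error } }',
        context_value={'author_id': 1, 'user_id': 'some-user-id'},  # hypothetical
    )
)
print(result.errors, result.data)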

services/auth.py

@@ -1,57 +1,60 @@
+import logging
+
 from aiohttp import ClientSession
 from strawberry.extensions import Extension
-from settings import AUTH_URL
-from services.db import local_session
+
 from orm.author import Author
-import logging
+from services.db import local_session
+from settings import AUTH_URL

-logger = logging.getLogger("\t[services.auth]\t")
+logger = logging.getLogger('\t[services.auth]\t')
 logger.setLevel(logging.DEBUG)


 async def check_auth(req) -> str | None:
-    token = req.headers.get("Authorization")
-    user_id = ""
+    token = req.headers.get('Authorization')
+    user_id = ''
     if token:
-        query_name = "validate_jwt_token"
-        operation = "ValidateToken"
+        query_name = 'validate_jwt_token'
+        operation = 'ValidateToken'
         headers = {
-            "Content-Type": "application/json",
+            'Content-Type': 'application/json',
         }
         variables = {
-            "params": {
-                "token_type": "access_token",
-                "token": token,
+            'params': {
+                'token_type': 'access_token',
+                'token': token,
             }
         }
         gql = {
-            "query": f"query {operation}($params: ValidateJWTTokenInput!) {{ {query_name}(params: $params) {{ is_valid claims }} }}",
-            "variables": variables,
-            "operationName": operation,
+            'query': f'query {operation}($params: ValidateJWTTokenInput!) {{ {query_name}(params: $params) {{ is_valid claims }} }}',
+            'variables': variables,
+            'operationName': operation,
         }
         try:
             # Asynchronous HTTP request to the authentication server
             async with ClientSession() as session:
                 async with session.post(AUTH_URL, json=gql, headers=headers) as response:
-                    print(f"[services.auth] HTTP Response {response.status} {await response.text()}")
+                    print(f'[services.auth] HTTP Response {response.status} {await response.text()}')
                     if response.status == 200:
                         data = await response.json()
-                        errors = data.get("errors")
+                        errors = data.get('errors')
                         if errors:
-                            print(f"[services.auth] errors: {errors}")
+                            print(f'[services.auth] errors: {errors}')
                         else:
-                            user_id = data.get("data", {}).get(query_name, {}).get("claims", {}).get("sub")
+                            user_id = data.get('data', {}).get(query_name, {}).get('claims', {}).get('sub')
                             if user_id:
-                                print(f"[services.auth] got user_id: {user_id}")
+                                print(f'[services.auth] got user_id: {user_id}')
                                 return user_id
         except Exception as e:
             import traceback

             traceback.print_exc()
             # Handling and logging exceptions during authentication check
-            print(f"[services.auth] Error {e}")
+            print(f'[services.auth] Error {e}')
     return None
@@ -59,14 +62,14 @@ async def check_auth(req) -> str | None:
 class LoginRequiredMiddleware(Extension):
     async def on_request_start(self):
         context = self.execution_context.context
-        req = context.get("request")
+        req = context.get('request')
         user_id = await check_auth(req)
         if user_id:
-            context["user_id"] = user_id.strip()
+            context['user_id'] = user_id.strip()
             with local_session() as session:
                 author = session.query(Author).filter(Author.user == user_id).first()
                 if author:
-                    context["author_id"] = author.id
-                    context["user_id"] = user_id or None
+                    context['author_id'] = author.id
+                    context['user_id'] = user_id or None
         self.execution_context.context = context
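check_auth only reads a few keys from the auth server's reply. Based on the fields accessed above, a successful validate_jwt_token response is expected to look roughly like this (abridged, values hypothetical):

expected_response = {
    'data': {
        'validate_jwt_token': {
            'is_valid': True,
            'claims': {'sub': 'user-uuid'},  # 'sub' becomes user_id
        }
    }
}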

services/core.py

@@ -4,47 +4,49 @@ import aiohttp

 from settings import API_BASE

-headers = {"Content-Type": "application/json"}
+headers = {'Content-Type': 'application/json'}


 # TODO: rewrite to orm usage?


 async def _request_endpoint(query_name, body) -> Any:
     async with aiohttp.ClientSession() as session:
         async with session.post(API_BASE, headers=headers, json=body) as response:
-            print(f"[services.core] {query_name} HTTP Response {response.status} {await response.text()}")
+            print(f'[services.core] {query_name} HTTP Response {response.status} {await response.text()}')
             if response.status == 200:
                 r = await response.json()
                 if r:
-                    return r.get("data", {}).get(query_name, {})
+                    return r.get('data', {}).get(query_name, {})
             return []


 async def get_followed_shouts(author_id: int):
-    query_name = "load_shouts_followed"
-    operation = "GetFollowedShouts"
+    query_name = 'load_shouts_followed'
+    operation = 'GetFollowedShouts'
     query = f"""query {operation}($author_id: Int!, limit: Int, offset: Int) {{
         {query_name}(author_id: $author_id, limit: $limit, offset: $offset) {{ id slug title }}
     }}"""

     gql = {
-        "query": query,
-        "operationName": operation,
-        "variables": {"author_id": author_id, "limit": 1000, "offset": 0},  # FIXME: too big limit
+        'query': query,
+        'operationName': operation,
+        'variables': {'author_id': author_id, 'limit': 1000, 'offset': 0},  # FIXME: too big limit
     }

     return await _request_endpoint(query_name, gql)


 async def get_shout(shout_id):
-    query_name = "get_shout"
-    operation = "GetShout"
+    query_name = 'get_shout'
+    operation = 'GetShout'
     query = f"""query {operation}($slug: String, $shout_id: Int) {{
         {query_name}(slug: $slug, shout_id: $shout_id) {{ id slug title authors {{ id slug name pic }} }}
     }}"""

-    gql = {"query": query, "operationName": operation, "variables": {"slug": None, "shout_id": shout_id}}
+    gql = {'query': query, 'operationName': operation, 'variables': {'slug': None, 'shout_id': shout_id}}

     return await _request_endpoint(query_name, gql)
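One thing this hunk leaves untouched: in GetFollowedShouts the variable definitions read ($author_id: Int!, limit: Int, offset: Int), but GraphQL requires the $ prefix on every variable definition, so limit and offset as written would be rejected by a conforming server. A corrected header would look like:

operation = 'GetFollowedShouts'
query_name = 'load_shouts_followed'
query = f"""query {operation}($author_id: Int!, $limit: Int, $offset: Int) {{
    {query_name}(author_id: $author_id, limit: $limit, offset: $offset) {{ id slug title }}
}}"""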

services/db.py

@@ -9,15 +9,16 @@ from sqlalchemy.sql.schema import Table

 from settings import DB_URL

 engine = create_engine(DB_URL, echo=False, pool_size=10, max_overflow=20)
-T = TypeVar("T")
+T = TypeVar('T')
 REGISTRY: Dict[str, type] = {}


 # @contextmanager
-def local_session(src=""):
+def local_session(src=''):
     return Session(bind=engine, expire_on_commit=False)

     # try:
@@ -45,7 +46,7 @@ class Base(declarative_base()):
     __init__: Callable
     __allow_unmapped__ = True
     __abstract__ = True
-    __table_args__ = {"extend_existing": True}
+    __table_args__ = {'extend_existing': True}

     id = Column(Integer, primary_key=True)
@@ -54,12 +55,12 @@ class Base(declarative_base()):
     def dict(self) -> Dict[str, Any]:
         column_names = self.__table__.columns.keys()
-        if "_sa_instance_state" in column_names:
-            column_names.remove("_sa_instance_state")
+        if '_sa_instance_state' in column_names:
+            column_names.remove('_sa_instance_state')
         try:
             return {c: getattr(self, c) for c in column_names}
         except Exception as e:
-            print(f"[services.db] Error dict: {e}")
+            print(f'[services.db] Error dict: {e}')
             return {}

     def update(self, values: Dict[str, Any]) -> None:
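Base.dict() gives every model a plain-dict serializer, which is what lets rows be JSON-encoded into redis payloads elsewhere in this service. A usage sketch, assuming a reachable database:

import json

from orm.author import Author
from services.db import local_session

with local_session() as session:
    author = session.query(Author).first()
    if author:
        print(json.dumps(author.dict()))  # {'id': ..., 'slug': ..., ...}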

services/rediscache.py

@@ -1,14 +1,16 @@
-import json
-import redis.asyncio as aredis
 import asyncio
-from settings import REDIS_URL
+import json
 import logging

-logger = logging.getLogger("\t[services.redis]\t")
+import redis.asyncio as aredis
+
+from settings import REDIS_URL
+
+logger = logging.getLogger('\t[services.redis]\t')
 logger.setLevel(logging.DEBUG)


 class RedisCache:
     def __init__(self, uri=REDIS_URL):
         self._uri: str = uri
@@ -25,11 +27,11 @@ class RedisCache:
     async def execute(self, command, *args, **kwargs):
         if self._client:
             try:
-                logger.debug(command + " " + " ".join(args))
+                logger.debug(command + ' ' + ' '.join(args))
                 r = await self._client.execute_command(command, *args, **kwargs)
                 return r
             except Exception as e:
-                logger.error(f"{e}")
+                logger.error(f'{e}')
         return None

     async def subscribe(self, *channels):
@@ -59,15 +61,15 @@ class RedisCache:
         while True:
             message = await pubsub.get_message()
-            if message and isinstance(message["data"], (str, bytes, bytearray)):
-                logger.debug("pubsub got msg")
+            if message and isinstance(message['data'], (str, bytes, bytearray)):
+                logger.debug('pubsub got msg')
                 try:
-                    yield json.loads(message["data"]), message.get("channel")
+                    yield json.loads(message['data']), message.get('channel')
                 except Exception as e:
-                    logger.error(f"{e}")
+                    logger.error(f'{e}')
             await asyncio.sleep(1)


 redis = RedisCache()

-__all__ = ["redis"]
+__all__ = ['redis']
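The module exports a single RedisCache instance, and execute() forwards any redis command verbatim, so the same helper covers SET/GET as well as the PUBLISH used for notifications. A usage sketch; the connection step is assumed here, since this hunk does not show how self._client gets established:

import asyncio

from services.rediscache import redis

async def demo():
    # assumes redis._client was connected at startup (not shown above);
    # execute() quietly returns None when no client is attached
    await redis.execute('SET', 'greeting', 'hello')
    print(await redis.execute('GET', 'greeting'))

asyncio.run(demo())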

settings.py

@@ -1,13 +1,14 @@
 from os import environ

 PORT = 80
 DB_URL = (
-    environ.get("DATABASE_URL", environ.get("DB_URL", "")).replace("postgres://", "postgresql://")
-    or "postgresql://postgres@localhost:5432/discoursio"
+    environ.get('DATABASE_URL', environ.get('DB_URL', '')).replace('postgres://', 'postgresql://')
+    or 'postgresql://postgres@localhost:5432/discoursio'
 )
-REDIS_URL = environ.get("REDIS_URL") or "redis://127.0.0.1"
-API_BASE = environ.get("API_BASE") or "https://core.discours.io"
-AUTH_URL = environ.get("AUTH_URL") or "https://auth.discours.io"
-MODE = environ.get("MODE") or "production"
-SENTRY_DSN = environ.get("SENTRY_DSN")
-DEV_SERVER_PID_FILE_NAME = "dev-server.pid"
+REDIS_URL = environ.get('REDIS_URL') or 'redis://127.0.0.1'
+API_BASE = environ.get('API_BASE') or 'https://core.discours.io'
+AUTH_URL = environ.get('AUTH_URL') or 'https://auth.discours.io'
+MODE = environ.get('MODE') or 'production'
+SENTRY_DSN = environ.get('SENTRY_DSN')
+DEV_SERVER_PID_FILE_NAME = 'dev-server.pid'
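The replace() in DB_URL is there because SQLAlchemy 1.4+ dropped the legacy postgres:// scheme name, while some hosting platforms still export DATABASE_URL in that form:

# What the rewrite does to a platform-provided URL:
url = 'postgres://user@host:5432/db'
print(url.replace('postgres://', 'postgresql://'))  # postgresql://user@host:5432/db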