fmt2
This commit is contained in:
parent
7beddea5b1
commit
59a1f8c902
21
.pre-commit-config.yaml
Normal file
21
.pre-commit-config.yaml
Normal file
|
@ -0,0 +1,21 @@
|
|||
fail_fast: true
|
||||
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.5.0
|
||||
hooks:
|
||||
- id: check-yaml
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
- id: check-added-large-files
|
||||
- id: detect-private-key
|
||||
- id: double-quote-string-fixer
|
||||
- id: check-ast
|
||||
- id: check-merge-conflict
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.1.13
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
- id: ruff-format
|
12
main.py
12
main.py
|
@ -1,4 +1,5 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from os.path import exists
|
||||
|
||||
|
@ -13,11 +14,10 @@ from resolvers.listener import notifications_worker
|
|||
from resolvers.schema import schema
|
||||
from services.rediscache import redis
|
||||
from settings import DEV_SERVER_PID_FILE_NAME, MODE, SENTRY_DSN
|
||||
import logging
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger("\t[main]\t")
|
||||
logger = logging.getLogger('\t[main]\t')
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
|
@ -27,9 +27,9 @@ async def start_up():
|
|||
task = asyncio.create_task(notifications_worker())
|
||||
logger.info(task)
|
||||
|
||||
if MODE == "dev":
|
||||
if MODE == 'dev':
|
||||
if exists(DEV_SERVER_PID_FILE_NAME):
|
||||
with open(DEV_SERVER_PID_FILE_NAME, "w", encoding="utf-8") as f:
|
||||
with open(DEV_SERVER_PID_FILE_NAME, 'w', encoding='utf-8') as f:
|
||||
f.write(str(os.getpid()))
|
||||
else:
|
||||
try:
|
||||
|
@ -46,7 +46,7 @@ async def start_up():
|
|||
],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("sentry init error", e)
|
||||
logger.error('sentry init error', e)
|
||||
|
||||
|
||||
async def shutdown():
|
||||
|
@ -54,4 +54,4 @@ async def shutdown():
|
|||
|
||||
|
||||
app = Starlette(debug=True, on_startup=[start_up], on_shutdown=[shutdown])
|
||||
app.mount("/", GraphQL(schema, debug=True))
|
||||
app.mount('/', GraphQL(schema, debug=True))
|
||||
|
|
|
@ -1,46 +1,45 @@
|
|||
import time
|
||||
|
||||
from sqlalchemy import JSON as JSONType
|
||||
from sqlalchemy import Boolean, Column, ForeignKey, Integer, String
|
||||
from sqlalchemy import JSON, Boolean, Column, ForeignKey, Integer, String
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from services.db import Base
|
||||
|
||||
|
||||
class AuthorRating(Base):
|
||||
__tablename__ = "author_rating"
|
||||
__tablename__ = 'author_rating'
|
||||
|
||||
id = None # type: ignore
|
||||
rater = Column(ForeignKey("author.id"), primary_key=True, index=True)
|
||||
author = Column(ForeignKey("author.id"), primary_key=True, index=True)
|
||||
rater = Column(ForeignKey('author.id'), primary_key=True, index=True)
|
||||
author = Column(ForeignKey('author.id'), primary_key=True, index=True)
|
||||
plus = Column(Boolean)
|
||||
|
||||
|
||||
class AuthorFollower(Base):
|
||||
__tablename__ = "author_follower"
|
||||
__tablename__ = 'author_follower'
|
||||
|
||||
id = None # type: ignore
|
||||
follower = Column(ForeignKey("author.id"), primary_key=True, index=True)
|
||||
author = Column(ForeignKey("author.id"), primary_key=True, index=True)
|
||||
follower = Column(ForeignKey('author.id'), primary_key=True, index=True)
|
||||
author = Column(ForeignKey('author.id'), primary_key=True, index=True)
|
||||
created_at = Column(Integer, nullable=False, default=lambda: int(time.time()))
|
||||
auto = Column(Boolean, nullable=False, default=False)
|
||||
|
||||
|
||||
class Author(Base):
|
||||
__tablename__ = "author"
|
||||
__tablename__ = 'author'
|
||||
|
||||
user = Column(String, unique=True) # unbounded link with authorizer's User type
|
||||
|
||||
name = Column(String, nullable=True, comment="Display name")
|
||||
name = Column(String, nullable=True, comment='Display name')
|
||||
slug = Column(String, unique=True, comment="Author's slug")
|
||||
bio = Column(String, nullable=True, comment="Bio") # status description
|
||||
about = Column(String, nullable=True, comment="About") # long and formatted
|
||||
pic = Column(String, nullable=True, comment="Picture")
|
||||
links = Column(JSONType, nullable=True, comment="Links")
|
||||
bio = Column(String, nullable=True, comment='Bio') # status description
|
||||
about = Column(String, nullable=True, comment='About') # long and formatted
|
||||
pic = Column(String, nullable=True, comment='Picture')
|
||||
links = Column(JSON, nullable=True, comment='Links')
|
||||
|
||||
ratings = relationship(AuthorRating, foreign_keys=AuthorRating.author)
|
||||
|
||||
created_at = Column(Integer, nullable=False, default=lambda: int(time.time()))
|
||||
last_seen = Column(Integer, nullable=False, default=lambda: int(time.time()))
|
||||
updated_at = Column(Integer, nullable=False, default=lambda: int(time.time()))
|
||||
deleted_at = Column(Integer, nullable=True, comment="Deleted at")
|
||||
deleted_at = Column(Integer, nullable=True, comment='Deleted at')
|
||||
|
|
|
@ -1,42 +1,41 @@
|
|||
import time
|
||||
from enum import Enum as Enumeration
|
||||
|
||||
from sqlalchemy import JSON as JSONType, func, cast
|
||||
from sqlalchemy import Column, Enum, ForeignKey, Integer, String
|
||||
from sqlalchemy import JSON, Column, ForeignKey, Integer, String
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.orm.session import engine
|
||||
|
||||
from orm.author import Author
|
||||
from services.db import Base
|
||||
import time
|
||||
|
||||
|
||||
class NotificationEntity(Enumeration):
|
||||
REACTION = "reaction"
|
||||
SHOUT = "shout"
|
||||
FOLLOWER = "follower"
|
||||
REACTION = 'reaction'
|
||||
SHOUT = 'shout'
|
||||
FOLLOWER = 'follower'
|
||||
|
||||
|
||||
class NotificationAction(Enumeration):
|
||||
CREATE = "create"
|
||||
UPDATE = "update"
|
||||
DELETE = "delete"
|
||||
SEEN = "seen"
|
||||
FOLLOW = "follow"
|
||||
UNFOLLOW = "unfollow"
|
||||
CREATE = 'create'
|
||||
UPDATE = 'update'
|
||||
DELETE = 'delete'
|
||||
SEEN = 'seen'
|
||||
FOLLOW = 'follow'
|
||||
UNFOLLOW = 'unfollow'
|
||||
|
||||
|
||||
class NotificationSeen(Base):
|
||||
__tablename__ = "notification_seen"
|
||||
__tablename__ = 'notification_seen'
|
||||
|
||||
viewer = Column(ForeignKey("author.id"))
|
||||
notification = Column(ForeignKey("notification.id"))
|
||||
viewer = Column(ForeignKey('author.id'))
|
||||
notification = Column(ForeignKey('notification.id'))
|
||||
|
||||
|
||||
class Notification(Base):
|
||||
__tablename__ = "notification"
|
||||
__tablename__ = 'notification'
|
||||
|
||||
created_at = Column(Integer, server_default=str(int(time.time())))
|
||||
entity = Column(String, nullable=False)
|
||||
action = Column(String, nullable=False)
|
||||
payload = Column(JSONType, nullable=True)
|
||||
payload = Column(JSON, nullable=True)
|
||||
|
||||
seen = relationship(lambda: Author, secondary="notification_seen")
|
||||
seen = relationship(lambda: Author, secondary='notification_seen')
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
from orm.notification import Notification, NotificationAction, NotificationEntity
|
||||
from resolvers.model import NotificationReaction, NotificationAuthor, NotificationShout
|
||||
from services.db import local_session
|
||||
from services.rediscache import redis
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(f"[listener.listen_task] ")
|
||||
from orm.notification import Notification
|
||||
from resolvers.model import NotificationAuthor, NotificationReaction, NotificationShout
|
||||
from services.db import local_session
|
||||
from services.rediscache import redis
|
||||
|
||||
|
||||
logger = logging.getLogger('[listener.listen_task] ')
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
|
@ -19,8 +21,8 @@ async def handle_notification(n: ServiceMessage, channel: str):
|
|||
"""создаеёт новое хранимое уведомление"""
|
||||
with local_session() as session:
|
||||
try:
|
||||
if channel.startswith("follower:"):
|
||||
author_id = int(channel.split(":")[1])
|
||||
if channel.startswith('follower:'):
|
||||
author_id = int(channel.split(':')[1])
|
||||
if isinstance(n.payload, NotificationAuthor):
|
||||
n.payload.following_id = author_id
|
||||
n = Notification(action=n.action, entity=n.entity, payload=n.payload)
|
||||
|
@ -28,7 +30,7 @@ async def handle_notification(n: ServiceMessage, channel: str):
|
|||
session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"[listener.handle_reaction] error: {str(e)}")
|
||||
logger.error(f'[listener.handle_reaction] error: {str(e)}')
|
||||
|
||||
|
||||
async def listen_task(pattern):
|
||||
|
@ -38,9 +40,9 @@ async def listen_task(pattern):
|
|||
notification_message = ServiceMessage(**message_data)
|
||||
await handle_notification(notification_message, str(channel))
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing notification: {str(e)}")
|
||||
logger.error(f'Error processing notification: {str(e)}')
|
||||
|
||||
|
||||
async def notifications_worker():
|
||||
# Use asyncio.gather to run tasks concurrently
|
||||
await asyncio.gather(listen_task("follower:*"), listen_task("reaction"), listen_task("shout"))
|
||||
await asyncio.gather(listen_task('follower:*'), listen_task('reaction'), listen_task('shout'))
|
||||
|
|
|
@ -1,27 +1,36 @@
|
|||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, List
|
||||
|
||||
import strawberry
|
||||
from sqlalchemy import and_, select
|
||||
from sqlalchemy.orm import aliased
|
||||
from sqlalchemy.sql import not_
|
||||
from services.db import local_session
|
||||
|
||||
from orm.notification import (
|
||||
Notification,
|
||||
NotificationAction,
|
||||
NotificationEntity,
|
||||
NotificationSeen,
|
||||
)
|
||||
from resolvers.model import (
|
||||
NotificationReaction,
|
||||
NotificationGroup,
|
||||
NotificationShout,
|
||||
NotificationAuthor,
|
||||
NotificationGroup,
|
||||
NotificationReaction,
|
||||
NotificationShout,
|
||||
NotificationsResult,
|
||||
)
|
||||
from orm.notification import NotificationAction, NotificationEntity, NotificationSeen, Notification
|
||||
from typing import Dict, List
|
||||
import time, json
|
||||
import strawberry
|
||||
from sqlalchemy.orm import aliased
|
||||
from sqlalchemy.sql.expression import or_
|
||||
from sqlalchemy import select, and_
|
||||
import logging
|
||||
from services.db import local_session
|
||||
|
||||
logger = logging.getLogger("[resolvers.schema] ")
|
||||
|
||||
logger = logging.getLogger('[resolvers.schema] ')
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
async def get_notifications_grouped(author_id: int, after: int = 0, limit: int = 10, offset: int = 0):
|
||||
async def get_notifications_grouped( # noqa: C901
|
||||
author_id: int, after: int = 0, limit: int = 10, offset: int = 0
|
||||
):
|
||||
"""
|
||||
Retrieves notifications for a given author.
|
||||
|
||||
|
@ -47,10 +56,13 @@ async def get_notifications_grouped(author_id: int, after: int = 0, limit: int =
|
|||
authors: List[NotificationAuthor], # List of authors involved in the thread.
|
||||
}
|
||||
"""
|
||||
NotificationSeenAlias = aliased(NotificationSeen)
|
||||
query = select(Notification, NotificationSeenAlias.viewer.label("seen")).outerjoin(
|
||||
seen_alias = aliased(NotificationSeen)
|
||||
query = select(Notification, seen_alias.viewer.label('seen')).outerjoin(
|
||||
NotificationSeen,
|
||||
and_(NotificationSeen.viewer == author_id, NotificationSeen.notification == Notification.id),
|
||||
and_(
|
||||
NotificationSeen.viewer == author_id,
|
||||
NotificationSeen.notification == Notification.id,
|
||||
),
|
||||
)
|
||||
if after:
|
||||
query = query.filter(Notification.created_at > after)
|
||||
|
@ -62,23 +74,36 @@ async def get_notifications_grouped(author_id: int, after: int = 0, limit: int =
|
|||
notifications_by_thread: Dict[str, List[Notification]] = {}
|
||||
groups_by_thread: Dict[str, NotificationGroup] = {}
|
||||
with local_session() as session:
|
||||
total = session.query(Notification).filter(and_(Notification.action == NotificationAction.CREATE.value, Notification.created_at > after)).count()
|
||||
unread = session.query(Notification).filter(
|
||||
total = (
|
||||
session.query(Notification)
|
||||
.filter(
|
||||
and_(
|
||||
Notification.action == NotificationAction.CREATE.value,
|
||||
Notification.created_at > after,
|
||||
not_(Notification.seen)
|
||||
)
|
||||
).count()
|
||||
)
|
||||
.count()
|
||||
)
|
||||
unread = (
|
||||
session.query(Notification)
|
||||
.filter(
|
||||
and_(
|
||||
Notification.action == NotificationAction.CREATE.value,
|
||||
Notification.created_at > after,
|
||||
not_(Notification.seen),
|
||||
)
|
||||
)
|
||||
.count()
|
||||
)
|
||||
notifications_result = session.execute(query)
|
||||
for n, seen in notifications_result:
|
||||
thread_id = ""
|
||||
for n, _seen in notifications_result:
|
||||
thread_id = ''
|
||||
payload = json.loads(n.payload)
|
||||
logger.debug(f"[resolvers.schema] {n.action} {n.entity}: {payload}")
|
||||
if n.entity == "shout" and n.action == "create":
|
||||
logger.debug(f'[resolvers.schema] {n.action} {n.entity}: {payload}')
|
||||
if n.entity == 'shout' and n.action == 'create':
|
||||
shout: NotificationShout = payload
|
||||
thread_id += f"{shout.id}"
|
||||
logger.debug(f"create shout: {shout}")
|
||||
thread_id += f'{shout.id}'
|
||||
logger.debug(f'create shout: {shout}')
|
||||
group = groups_by_thread.get(thread_id) or NotificationGroup(
|
||||
id=thread_id,
|
||||
entity=n.entity,
|
||||
|
@ -86,8 +111,8 @@ async def get_notifications_grouped(author_id: int, after: int = 0, limit: int =
|
|||
authors=shout.authors,
|
||||
updated_at=shout.created_at,
|
||||
reactions=[],
|
||||
action="create",
|
||||
seen=author_id in n.seen
|
||||
action='create',
|
||||
seen=author_id in n.seen,
|
||||
)
|
||||
# store group in result
|
||||
groups_by_thread[thread_id] = group
|
||||
|
@ -99,11 +124,11 @@ async def get_notifications_grouped(author_id: int, after: int = 0, limit: int =
|
|||
elif n.entity == NotificationEntity.REACTION.value and n.action == NotificationAction.CREATE.value:
|
||||
reaction: NotificationReaction = payload
|
||||
shout: NotificationShout = reaction.shout
|
||||
thread_id += f"{reaction.shout}"
|
||||
if reaction.kind == "LIKE" or reaction.kind == "DISLIKE":
|
||||
thread_id += f'{reaction.shout}'
|
||||
if not bool(reaction.reply_to) and (reaction.kind == 'LIKE' or reaction.kind == 'DISLIKE'):
|
||||
# TODO: making published reaction vote announce
|
||||
pass
|
||||
elif reaction.kind == "COMMENT":
|
||||
elif reaction.kind == 'COMMENT':
|
||||
if reaction.reply_to:
|
||||
thread_id += f"{'::' + str(reaction.reply_to)}"
|
||||
group: NotificationGroup | None = groups_by_thread.get(thread_id)
|
||||
|
@ -128,8 +153,9 @@ async def get_notifications_grouped(author_id: int, after: int = 0, limit: int =
|
|||
break
|
||||
else:
|
||||
# init notification group
|
||||
reactions = []
|
||||
reactions.append(reaction.id)
|
||||
reactions = [
|
||||
reaction.id,
|
||||
]
|
||||
group = NotificationGroup(
|
||||
id=thread_id,
|
||||
action=n.action,
|
||||
|
@ -140,7 +166,7 @@ async def get_notifications_grouped(author_id: int, after: int = 0, limit: int =
|
|||
authors=[
|
||||
reaction.created_by,
|
||||
],
|
||||
seen=author_id in n.seen
|
||||
seen=author_id in n.seen,
|
||||
)
|
||||
# store group in result
|
||||
groups_by_thread[thread_id] = group
|
||||
|
@ -149,8 +175,8 @@ async def get_notifications_grouped(author_id: int, after: int = 0, limit: int =
|
|||
notifications.append(n)
|
||||
notifications_by_thread[thread_id] = notifications
|
||||
|
||||
elif n.entity == "follower":
|
||||
thread_id = "followers"
|
||||
elif n.entity == 'follower':
|
||||
thread_id = 'followers'
|
||||
follower: NotificationAuthor = payload
|
||||
group = groups_by_thread.get(thread_id) or NotificationGroup(
|
||||
id=thread_id,
|
||||
|
@ -158,11 +184,13 @@ async def get_notifications_grouped(author_id: int, after: int = 0, limit: int =
|
|||
updated_at=int(time.time()),
|
||||
shout=None,
|
||||
reactions=[],
|
||||
entity="follower",
|
||||
action="follow",
|
||||
seen=author_id in n.seen
|
||||
entity='follower',
|
||||
action='follow',
|
||||
seen=author_id in n.seen,
|
||||
)
|
||||
group.authors = [follower, ]
|
||||
group.authors = [
|
||||
follower,
|
||||
]
|
||||
group.updated_at = int(time.time())
|
||||
# store group in result
|
||||
groups_by_thread[thread_id] = group
|
||||
|
@ -182,7 +210,7 @@ async def get_notifications_grouped(author_id: int, after: int = 0, limit: int =
|
|||
class Query:
|
||||
@strawberry.field
|
||||
async def load_notifications(self, info, after: int, limit: int = 50, offset: int = 0) -> NotificationsResult:
|
||||
author_id = info.context.get("author_id")
|
||||
author_id = info.context.get('author_id')
|
||||
groups: Dict[str, NotificationGroup] = {}
|
||||
if author_id:
|
||||
groups, notifications, total, unread = await get_notifications_grouped(author_id, after, limit, offset)
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
import strawberry
|
||||
from typing import List, Optional
|
||||
|
||||
import strawberry
|
||||
from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyMapper
|
||||
|
||||
from orm.notification import Notification as NotificationMessage
|
||||
|
||||
|
||||
strawberry_sqlalchemy_mapper = StrawberrySQLAlchemyMapper()
|
||||
|
||||
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
|
||||
import strawberry
|
||||
from strawberry.schema.config import StrawberryConfig
|
||||
|
||||
from services.auth import LoginRequiredMiddleware
|
||||
from resolvers.load import Query
|
||||
from resolvers.seen import Mutation
|
||||
from services.auth import LoginRequiredMiddleware
|
||||
from services.db import Base, engine
|
||||
|
||||
|
||||
schema = strawberry.Schema(
|
||||
query=Query, mutation=Mutation, config=StrawberryConfig(auto_camel_case=False), extensions=[LoginRequiredMiddleware]
|
||||
)
|
||||
|
|
|
@ -1,14 +1,14 @@
|
|||
from sqlalchemy import and_
|
||||
from orm.notification import NotificationSeen
|
||||
from services.db import local_session
|
||||
from resolvers.model import Notification, NotificationSeenResult, NotificationReaction
|
||||
import json
|
||||
import logging
|
||||
|
||||
import strawberry
|
||||
import logging
|
||||
import json
|
||||
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from orm.notification import NotificationSeen
|
||||
from resolvers.model import Notification, NotificationReaction, NotificationSeenResult
|
||||
from services.db import local_session
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -17,7 +17,7 @@ logger = logging.getLogger(__name__)
|
|||
class Mutation:
|
||||
@strawberry.mutation
|
||||
async def mark_seen(self, info, notification_id: int) -> NotificationSeenResult:
|
||||
author_id = info.context.get("author_id")
|
||||
author_id = info.context.get('author_id')
|
||||
if author_id:
|
||||
with local_session() as session:
|
||||
try:
|
||||
|
@ -27,9 +27,9 @@ class Mutation:
|
|||
except SQLAlchemyError as e:
|
||||
session.rollback()
|
||||
logger.error(
|
||||
f"[mark_notification_as_read] Ошибка при обновлении статуса прочтения уведомления: {str(e)}"
|
||||
f'[mark_notification_as_read] Ошибка при обновлении статуса прочтения уведомления: {str(e)}'
|
||||
)
|
||||
return NotificationSeenResult(error="cant mark as read")
|
||||
return NotificationSeenResult(error='cant mark as read')
|
||||
return NotificationSeenResult(error=None)
|
||||
|
||||
@strawberry.mutation
|
||||
|
@ -37,7 +37,7 @@ class Mutation:
|
|||
# TODO: use latest loaded notification_id as input offset parameter
|
||||
error = None
|
||||
try:
|
||||
author_id = info.context.get("author_id")
|
||||
author_id = info.context.get('author_id')
|
||||
if author_id:
|
||||
with local_session() as session:
|
||||
nnn = session.query(Notification).filter(and_(Notification.created_at > after)).all()
|
||||
|
@ -46,26 +46,26 @@ class Mutation:
|
|||
ns = NotificationSeen(notification=n.id, viewer=author_id)
|
||||
session.add(ns)
|
||||
session.commit()
|
||||
except SQLAlchemyError as e:
|
||||
except SQLAlchemyError:
|
||||
session.rollback()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
error = "cant mark as read"
|
||||
error = 'cant mark as read'
|
||||
return NotificationSeenResult(error=error)
|
||||
|
||||
@strawberry.mutation
|
||||
async def mark_seen_thread(self, info, thread: str, after: int) -> NotificationSeenResult:
|
||||
error = None
|
||||
author_id = info.context.get("author_id")
|
||||
author_id = info.context.get('author_id')
|
||||
if author_id:
|
||||
[shout_id, reply_to_id] = thread.split("::")
|
||||
[shout_id, reply_to_id] = thread.split('::')
|
||||
with local_session() as session:
|
||||
# TODO: handle new follower and new shout notifications
|
||||
new_reaction_notifications = (
|
||||
session.query(Notification)
|
||||
.filter(
|
||||
Notification.action == "create",
|
||||
Notification.entity == "reaction",
|
||||
Notification.action == 'create',
|
||||
Notification.entity == 'reaction',
|
||||
Notification.created_at > after,
|
||||
)
|
||||
.all()
|
||||
|
@ -73,13 +73,13 @@ class Mutation:
|
|||
removed_reaction_notifications = (
|
||||
session.query(Notification)
|
||||
.filter(
|
||||
Notification.action == "delete",
|
||||
Notification.entity == "reaction",
|
||||
Notification.action == 'delete',
|
||||
Notification.entity == 'reaction',
|
||||
Notification.created_at > after,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
exclude = set([])
|
||||
exclude = set()
|
||||
for nr in removed_reaction_notifications:
|
||||
reaction: NotificationReaction = json.loads(nr.payload)
|
||||
exclude.add(reaction.id)
|
||||
|
@ -97,5 +97,5 @@ class Mutation:
|
|||
except Exception:
|
||||
session.rollback()
|
||||
else:
|
||||
error = "You are not logged in"
|
||||
error = 'You are not logged in'
|
||||
return NotificationSeenResult(error=error)
|
||||
|
|
|
@ -1,57 +1,60 @@
|
|||
import logging
|
||||
|
||||
from aiohttp import ClientSession
|
||||
from strawberry.extensions import Extension
|
||||
|
||||
from settings import AUTH_URL
|
||||
from services.db import local_session
|
||||
from orm.author import Author
|
||||
from services.db import local_session
|
||||
from settings import AUTH_URL
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("\t[services.auth]\t")
|
||||
logger = logging.getLogger('\t[services.auth]\t')
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
async def check_auth(req) -> str | None:
|
||||
token = req.headers.get("Authorization")
|
||||
user_id = ""
|
||||
token = req.headers.get('Authorization')
|
||||
user_id = ''
|
||||
if token:
|
||||
query_name = "validate_jwt_token"
|
||||
operation = "ValidateToken"
|
||||
query_name = 'validate_jwt_token'
|
||||
operation = 'ValidateToken'
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
|
||||
variables = {
|
||||
"params": {
|
||||
"token_type": "access_token",
|
||||
"token": token,
|
||||
'params': {
|
||||
'token_type': 'access_token',
|
||||
'token': token,
|
||||
}
|
||||
}
|
||||
|
||||
gql = {
|
||||
"query": f"query {operation}($params: ValidateJWTTokenInput!) {{ {query_name}(params: $params) {{ is_valid claims }} }}",
|
||||
"variables": variables,
|
||||
"operationName": operation,
|
||||
'query': f'query {operation}($params: ValidateJWTTokenInput!) {{ {query_name}(params: $params) {{ is_valid claims }} }}',
|
||||
'variables': variables,
|
||||
'operationName': operation,
|
||||
}
|
||||
try:
|
||||
# Asynchronous HTTP request to the authentication server
|
||||
async with ClientSession() as session:
|
||||
async with session.post(AUTH_URL, json=gql, headers=headers) as response:
|
||||
print(f"[services.auth] HTTP Response {response.status} {await response.text()}")
|
||||
print(f'[services.auth] HTTP Response {response.status} {await response.text()}')
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
errors = data.get("errors")
|
||||
errors = data.get('errors')
|
||||
if errors:
|
||||
print(f"[services.auth] errors: {errors}")
|
||||
print(f'[services.auth] errors: {errors}')
|
||||
else:
|
||||
user_id = data.get("data", {}).get(query_name, {}).get("claims", {}).get("sub")
|
||||
user_id = data.get('data', {}).get(query_name, {}).get('claims', {}).get('sub')
|
||||
if user_id:
|
||||
print(f"[services.auth] got user_id: {user_id}")
|
||||
print(f'[services.auth] got user_id: {user_id}')
|
||||
return user_id
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
# Handling and logging exceptions during authentication check
|
||||
print(f"[services.auth] Error {e}")
|
||||
print(f'[services.auth] Error {e}')
|
||||
|
||||
return None
|
||||
|
||||
|
@ -59,14 +62,14 @@ async def check_auth(req) -> str | None:
|
|||
class LoginRequiredMiddleware(Extension):
|
||||
async def on_request_start(self):
|
||||
context = self.execution_context.context
|
||||
req = context.get("request")
|
||||
req = context.get('request')
|
||||
user_id = await check_auth(req)
|
||||
if user_id:
|
||||
context["user_id"] = user_id.strip()
|
||||
context['user_id'] = user_id.strip()
|
||||
with local_session() as session:
|
||||
author = session.query(Author).filter(Author.user == user_id).first()
|
||||
if author:
|
||||
context["author_id"] = author.id
|
||||
context["user_id"] = user_id or None
|
||||
context['author_id'] = author.id
|
||||
context['user_id'] = user_id or None
|
||||
|
||||
self.execution_context.context = context
|
||||
|
|
|
@ -4,47 +4,49 @@ import aiohttp
|
|||
|
||||
from settings import API_BASE
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
headers = {'Content-Type': 'application/json'}
|
||||
|
||||
|
||||
# TODO: rewrite to orm usage?
|
||||
|
||||
|
||||
async def _request_endpoint(query_name, body) -> Any:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(API_BASE, headers=headers, json=body) as response:
|
||||
print(f"[services.core] {query_name} HTTP Response {response.status} {await response.text()}")
|
||||
print(f'[services.core] {query_name} HTTP Response {response.status} {await response.text()}')
|
||||
if response.status == 200:
|
||||
r = await response.json()
|
||||
if r:
|
||||
return r.get("data", {}).get(query_name, {})
|
||||
return r.get('data', {}).get(query_name, {})
|
||||
return []
|
||||
|
||||
|
||||
async def get_followed_shouts(author_id: int):
|
||||
query_name = "load_shouts_followed"
|
||||
operation = "GetFollowedShouts"
|
||||
query_name = 'load_shouts_followed'
|
||||
operation = 'GetFollowedShouts'
|
||||
|
||||
query = f"""query {operation}($author_id: Int!, limit: Int, offset: Int) {{
|
||||
{query_name}(author_id: $author_id, limit: $limit, offset: $offset) {{ id slug title }}
|
||||
}}"""
|
||||
|
||||
gql = {
|
||||
"query": query,
|
||||
"operationName": operation,
|
||||
"variables": {"author_id": author_id, "limit": 1000, "offset": 0}, # FIXME: too big limit
|
||||
'query': query,
|
||||
'operationName': operation,
|
||||
'variables': {'author_id': author_id, 'limit': 1000, 'offset': 0}, # FIXME: too big limit
|
||||
}
|
||||
|
||||
return await _request_endpoint(query_name, gql)
|
||||
|
||||
|
||||
async def get_shout(shout_id):
|
||||
query_name = "get_shout"
|
||||
operation = "GetShout"
|
||||
query_name = 'get_shout'
|
||||
operation = 'GetShout'
|
||||
|
||||
query = f"""query {operation}($slug: String, $shout_id: Int) {{
|
||||
{query_name}(slug: $slug, shout_id: $shout_id) {{ id slug title authors {{ id slug name pic }} }}
|
||||
}}"""
|
||||
|
||||
gql = {"query": query, "operationName": operation, "variables": {"slug": None, "shout_id": shout_id}}
|
||||
gql = {'query': query, 'operationName': operation, 'variables': {'slug': None, 'shout_id': shout_id}}
|
||||
|
||||
return await _request_endpoint(query_name, gql)
|
||||
|
|
|
@ -9,15 +9,16 @@ from sqlalchemy.sql.schema import Table
|
|||
|
||||
from settings import DB_URL
|
||||
|
||||
|
||||
engine = create_engine(DB_URL, echo=False, pool_size=10, max_overflow=20)
|
||||
|
||||
T = TypeVar("T")
|
||||
T = TypeVar('T')
|
||||
|
||||
REGISTRY: Dict[str, type] = {}
|
||||
|
||||
|
||||
# @contextmanager
|
||||
def local_session(src=""):
|
||||
def local_session(src=''):
|
||||
return Session(bind=engine, expire_on_commit=False)
|
||||
|
||||
# try:
|
||||
|
@ -45,7 +46,7 @@ class Base(declarative_base()):
|
|||
__init__: Callable
|
||||
__allow_unmapped__ = True
|
||||
__abstract__ = True
|
||||
__table_args__ = {"extend_existing": True}
|
||||
__table_args__ = {'extend_existing': True}
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
|
||||
|
@ -54,12 +55,12 @@ class Base(declarative_base()):
|
|||
|
||||
def dict(self) -> Dict[str, Any]:
|
||||
column_names = self.__table__.columns.keys()
|
||||
if "_sa_instance_state" in column_names:
|
||||
column_names.remove("_sa_instance_state")
|
||||
if '_sa_instance_state' in column_names:
|
||||
column_names.remove('_sa_instance_state')
|
||||
try:
|
||||
return {c: getattr(self, c) for c in column_names}
|
||||
except Exception as e:
|
||||
print(f"[services.db] Error dict: {e}")
|
||||
print(f'[services.db] Error dict: {e}')
|
||||
return {}
|
||||
|
||||
def update(self, values: Dict[str, Any]) -> None:
|
||||
|
|
|
@ -1,14 +1,16 @@
|
|||
import json
|
||||
|
||||
import redis.asyncio as aredis
|
||||
import asyncio
|
||||
from settings import REDIS_URL
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("\t[services.redis]\t")
|
||||
import redis.asyncio as aredis
|
||||
|
||||
from settings import REDIS_URL
|
||||
|
||||
|
||||
logger = logging.getLogger('\t[services.redis]\t')
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
class RedisCache:
|
||||
def __init__(self, uri=REDIS_URL):
|
||||
self._uri: str = uri
|
||||
|
@ -25,11 +27,11 @@ class RedisCache:
|
|||
async def execute(self, command, *args, **kwargs):
|
||||
if self._client:
|
||||
try:
|
||||
logger.debug(command + " " + " ".join(args))
|
||||
logger.debug(command + ' ' + ' '.join(args))
|
||||
r = await self._client.execute_command(command, *args, **kwargs)
|
||||
return r
|
||||
except Exception as e:
|
||||
logger.error(f"{e}")
|
||||
logger.error(f'{e}')
|
||||
return None
|
||||
|
||||
async def subscribe(self, *channels):
|
||||
|
@ -59,15 +61,15 @@ class RedisCache:
|
|||
|
||||
while True:
|
||||
message = await pubsub.get_message()
|
||||
if message and isinstance(message["data"], (str, bytes, bytearray)):
|
||||
logger.debug("pubsub got msg")
|
||||
if message and isinstance(message['data'], (str, bytes, bytearray)):
|
||||
logger.debug('pubsub got msg')
|
||||
try:
|
||||
yield json.loads(message["data"]), message.get("channel")
|
||||
yield json.loads(message['data']), message.get('channel')
|
||||
except Exception as e:
|
||||
logger.error(f"{e}")
|
||||
logger.error(f'{e}')
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
redis = RedisCache()
|
||||
|
||||
__all__ = ["redis"]
|
||||
__all__ = ['redis']
|
||||
|
|
17
settings.py
17
settings.py
|
@ -1,13 +1,14 @@
|
|||
from os import environ
|
||||
|
||||
|
||||
PORT = 80
|
||||
DB_URL = (
|
||||
environ.get("DATABASE_URL", environ.get("DB_URL", "")).replace("postgres://", "postgresql://")
|
||||
or "postgresql://postgres@localhost:5432/discoursio"
|
||||
environ.get('DATABASE_URL', environ.get('DB_URL', '')).replace('postgres://', 'postgresql://')
|
||||
or 'postgresql://postgres@localhost:5432/discoursio'
|
||||
)
|
||||
REDIS_URL = environ.get("REDIS_URL") or "redis://127.0.0.1"
|
||||
API_BASE = environ.get("API_BASE") or "https://core.discours.io"
|
||||
AUTH_URL = environ.get("AUTH_URL") or "https://auth.discours.io"
|
||||
MODE = environ.get("MODE") or "production"
|
||||
SENTRY_DSN = environ.get("SENTRY_DSN")
|
||||
DEV_SERVER_PID_FILE_NAME = "dev-server.pid"
|
||||
REDIS_URL = environ.get('REDIS_URL') or 'redis://127.0.0.1'
|
||||
API_BASE = environ.get('API_BASE') or 'https://core.discours.io'
|
||||
AUTH_URL = environ.get('AUTH_URL') or 'https://auth.discours.io'
|
||||
MODE = environ.get('MODE') or 'production'
|
||||
SENTRY_DSN = environ.get('SENTRY_DSN')
|
||||
DEV_SERVER_PID_FILE_NAME = 'dev-server.pid'
|
||||
|
|
Loading…
Reference in New Issue
Block a user