diff --git a/.gitignore b/.gitignore index dac81c4..ba134b3 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ __pycache__ .vscode poetry.lock .venv +.ruff_cache diff --git a/resolvers/load.py b/resolvers/load.py index c10d5eb..4f76060 100644 --- a/resolvers/load.py +++ b/resolvers/load.py @@ -8,7 +8,7 @@ from resolvers.model import ( NotificationsResult, ) from orm.notification import NotificationSeen -from typing import Dict +from typing import Dict, List import time, json import strawberry from sqlalchemy.orm import aliased @@ -17,7 +17,7 @@ from sqlalchemy import select, and_ async def get_notifications_grouped( author_id: int, after: int = 0, limit: int = 10, offset: int = 0, mark_as_read=False -) -> Dict[str, NotificationGroup]: +): """ Retrieves notifications for a given author. @@ -53,8 +53,13 @@ async def get_notifications_grouped( notifications: Dict[str, NotificationGroup] = {} counter = 0 + unread = 0 + total = 0 with local_session() as session: - for n, seen in session.execute(query): + notifications_result = session.execute(query) + for n, seen in notifications_result: + total += 1 + unread += 1 if author_id in n.seen else 0 thread_id = "" payload = json.loads(n.payload) print(f"[resolvers.schema] {n.action} {n.entity}: {payload}") @@ -142,7 +147,7 @@ async def get_notifications_grouped( if counter > limit: break - return notifications + return notifications, unread, total @strawberry.type @@ -152,8 +157,6 @@ class Query: author_id = info.context.get("author_id") notification_groups: Dict[str, NotificationGroup] = {} if author_id: - # TODO: add total counter calculation - # TODO: add unread counter calculation - notification_groups = await get_notifications_grouped(author_id, after, limit, offset) + notification_groups, total, unread = await get_notifications_grouped(author_id, after, limit, offset) notifications = sorted(notification_groups.values(), key=lambda group: group.updated_at, reverse=True) return NotificationsResult(notifications=notifications, total=0, unread=0, error=None) diff --git a/resolvers/seen.py b/resolvers/seen.py index ff4aaca..4457245 100644 --- a/resolvers/seen.py +++ b/resolvers/seen.py @@ -1,9 +1,12 @@ +from sqlalchemy import and_ from orm.notification import NotificationSeen from services.db import local_session -from resolvers.model import NotificationSeenResult +from resolvers.model import Notification, NotificationSeenResult, NotificationReaction from resolvers.load import get_notifications_grouped +from typing import List import strawberry import logging +import json from sqlalchemy.exc import SQLAlchemyError @@ -14,7 +17,7 @@ logger = logging.getLogger(__name__) @strawberry.type class Mutation: @strawberry.mutation - async def mark_notification_as_read(self, info, notification_id: int) -> NotificationSeenResult: + async def mark_seen(self, info, notification_id: int) -> NotificationSeenResult: author_id = info.context.get("author_id") if author_id: with local_session() as session: @@ -31,15 +34,49 @@ class Mutation: return NotificationSeenResult(error=None) @strawberry.mutation - async def mark_all_notifications_as_read(self, info, limit: int = 10, offset: int = 0) -> NotificationSeenResult: + async def mark_seen_after(self, info, after: int) -> NotificationSeenResult: # TODO: use latest loaded notification_id as input offset parameter - ngroups = {} error = None try: author_id = info.context.get("author_id") if author_id: - ngroups = get_notifications_grouped(author_id, limit, offset, mark_as_read=True) + with local_session() as session: + nnn = session.query(Notification).filter(and_(Notification.created_at > after)).all() + for n in nnn: + try: + ns = NotificationSeen(notification=n.id, viewer=author_id) + session.add(ns) + session.commit() + except SQLAlchemyError as e: + session.rollback() except Exception as e: print(e) error = "cant mark as read" return NotificationSeenResult(error=error) + + @strawberry.mutation + async def mark_seen_thread(self, info, thread: str, after: int) -> NotificationSeenResult: + error = None + author_id = info.context.get("author_id") + if author_id: + [shout_id, reply_to_id] = thread.split("::") + with local_session() as session: + # TODO: handle new follower and new shout notifications + new_reaction_notifications = session.query(Notification).filter(Notification.action == "create", Notification.entity == "reaction", Notification.created_at > after).all() + removed_reaction_notifications = session.query(Notification).filter(Notification.action == "delete", Notification.entity == "reaction", Notification.created_at > after).all() + exclude = set([]) + for nr in removed_reaction_notifications: + reaction: NotificationReaction = json.loads(nr.payload) + exclude.add(reaction.id) + for n in new_reaction_notifications: + reaction: NotificationReaction = json.loads(n.payload) + if reaction.id not in exclude and str(reaction.shout) == str(shout_id) and str(reaction.reply_to) == str(reply_to_id): + try: + ns = NotificationSeen(notification=n.id, viewer=author_id) + session.add(ns) + session.commit() + except Exception: + session.rollback() + else: + error = "You are not logged in" + return NotificationSeenResult(error=error) diff --git a/services/auth.py b/services/auth.py index 1950811..79f1823 100644 --- a/services/auth.py +++ b/services/auth.py @@ -1,5 +1,4 @@ from aiohttp import ClientSession -from starlette.exceptions import HTTPException from strawberry.extensions import Extension from settings import AUTH_URL