diff --git a/resolvers/schema.py b/resolvers/schema.py index 92ab26f..6bcd660 100644 --- a/resolvers/schema.py +++ b/resolvers/schema.py @@ -1,6 +1,5 @@ -from typing import List, Any - -from sqlalchemy import and_ +from typing import List +from sqlalchemy import and_, select from sqlalchemy.orm import aliased from sqlalchemy.exc import SQLAlchemyError from orm.author import Author @@ -35,6 +34,35 @@ class NotificationsResult: total: int +def get_notifications(author, session, limit, offset) -> List[Notification]: + NotificationSeenAlias = aliased(NotificationSeen) + query = select( + NotificationMessage, + NotificationSeenAlias.viewer.label("seen") + ).outerjoin( + NotificationSeen, + and_(NotificationSeen.viewer == author.id, NotificationSeen.notification == NotificationMessage.id), + ).group_by(NotificationSeen.notification) + if limit: + query = query.limit(limit) + if offset: + query = query.offset(offset) + + notifications = [] + for n, seen in session.execute(query): + ntf = Notification( + id=n.id, + payload=n.payload, + entity=n.entity, + action=n.action, + created_at=n.created_at, + seen=seen, + ) + if ntf: + notifications.append(ntf) + return notifications + + @strawberry.type class Query: @strawberry.field @@ -44,40 +72,21 @@ class Query: with local_session() as session: try: author = session.query(Author).filter(Author.user == user_id).first() - NotificationSeenAlias = aliased(NotificationSeen) - if author: - nnn = session.query( - NotificationMessage, - NotificationSeenAlias.viewer.label("seen") - ).outerjoin( - NotificationSeen, - and_(NotificationSeen.viewer == author.id, NotificationSeen.notification == NotificationMessage.id), - ).group_by(NotificationSeen.notification).limit(limit).offset(offset).all() - notifications = [] - for n, seen in nnn: - ntf = Notification( - id=n.id, - payload=n.payload, - entity=n.entity, - action=n.action, - created_at=n.created_at, - seen=seen, + notifications = get_notifications(author, session, limit, offset) + if notifications and len(notifications) > 0: + nr = NotificationsResult( + notifications=notifications, + unread=sum(1 for n in notifications if author.id in n.seen), + total=session.query(NotificationMessage).count() ) - if ntf: - notifications.append(ntf) - nr = NotificationsResult( - notifications = notifications, - unread = sum(1 for n in notifications if author.id in n.seen), - total = session.query(NotificationMessage).count() - ) - return nr + return nr except Exception as ex: print(f"[resolvers.schema] {ex}") return NotificationsResult( notifications=[], - total = 0, - unread = 0 + total=0, + unread=0 ) @@ -90,9 +99,10 @@ class Mutation: with local_session() as session: try: author = session.query(Author).filter(Author.user == user_id).first() - ns = NotificationSeen({"notification": notification_id, "viewer": author.id}) - session.add(ns) - session.commit() + if author: + ns = NotificationSeen(notification=notification_id, viewer=author.id) + session.add(ns) + session.commit() except SQLAlchemyError as e: session.rollback() print(f"[mark_notification_as_read] error: {str(e)}") @@ -109,7 +119,12 @@ class Mutation: try: author = session.query(Author).filter(Author.user == user_id).first() if author: - _nslist = session.query(NotificationSeen).filter(NotificationSeen.viewer == author.id).all() + nslist = get_notifications(author, session, None, None) + for n in nslist: + if author.id not in n.seen: + ns = NotificationSeen(viewer=author.id, notification=n.id) + session.add(ns) + session.commit() except SQLAlchemyError as e: session.rollback() print(f"[mark_all_notifications_as_read] error: {str(e)}")