From 919aaa951fdfa28c467756c0778e926713193a80 Mon Sep 17 00:00:00 2001 From: Untone Date: Thu, 30 Nov 2023 11:40:27 +0300 Subject: [PATCH] string-enum-fix --- orm/invite.py | 4 ++-- orm/reaction.py | 2 +- orm/shout.py | 2 +- resolvers/collab.py | 8 +++++--- resolvers/community.py | 4 ++-- resolvers/editor.py | 18 ++++++++---------- resolvers/reader.py | 2 ++ 7 files changed, 21 insertions(+), 19 deletions(-) diff --git a/orm/invite.py b/orm/invite.py index 62dd4bb1..2b46b8fd 100644 --- a/orm/invite.py +++ b/orm/invite.py @@ -1,4 +1,4 @@ -from sqlalchemy import Column, ForeignKey, Enum +from sqlalchemy import Column, ForeignKey, Enum, String from sqlalchemy.orm import relationship from services.db import Base from orm.author import Author @@ -18,7 +18,7 @@ class Invite(Base): inviter_id = Column(ForeignKey("author.id"), nullable=False, index=True) author_id = Column(ForeignKey("author.id"), nullable=False, index=True) shout_id = Column(ForeignKey("shout.id"), nullable=False, index=True) - status = Column(Enum(InviteStatus), default=InviteStatus.PENDING.value) + status = Column(String, default=InviteStatus.PENDING.value) inviter = relationship(Author, foreign_keys=[inviter_id]) author = relationship(Author, foreign_keys=[author_id]) diff --git a/orm/reaction.py b/orm/reaction.py index cd9332b3..aff4ba5a 100644 --- a/orm/reaction.py +++ b/orm/reaction.py @@ -38,6 +38,6 @@ class Reaction(Base): quote = Column(String, nullable=True, comment="Original quoted text") shout = Column(ForeignKey("shout.id"), nullable=False, index=True) created_by = Column(ForeignKey("author.id"), nullable=False, index=True) - kind = Column(Enum(ReactionKind), nullable=False, index=True) + kind = Column(String, nullable=False, index=True) oid = Column(String) diff --git a/orm/shout.py b/orm/shout.py index 2ef8ddb3..968b7dbb 100644 --- a/orm/shout.py +++ b/orm/shout.py @@ -80,7 +80,7 @@ class Shout(Base): communities = relationship(lambda: Community, secondary="shout_community") reactions = relationship(lambda: Reaction) - visibility = Column(Enum(ShoutVisibility), default=ShoutVisibility.AUTHORS.value) + visibility = Column(String, default=ShoutVisibility.AUTHORS.value) lang = Column(String, nullable=False, default="ru", comment="Language") version_of = Column(ForeignKey("shout.id"), nullable=True) diff --git a/resolvers/collab.py b/resolvers/collab.py index e8a619fd..90aec563 100644 --- a/resolvers/collab.py +++ b/resolvers/collab.py @@ -21,9 +21,11 @@ async def accept_invite(_, info, invite_id: int): # Add the user to the shout authors shout = session.query(Shout).filter(Shout.id == invite.shout_id).first() if shout: - shout.authors.append(author) - session.delete(invite) - session.commit() + if author not in shout.authors: + shout.authors.append(author) + session.delete(invite) + session.add(shout) + session.commit() return {"success": True, "message": "Invite accepted"} else: return {"error": "Shout not found"} diff --git a/resolvers/community.py b/resolvers/community.py index 2036a1f3..57149e4b 100644 --- a/resolvers/community.py +++ b/resolvers/community.py @@ -113,5 +113,5 @@ async def get_community(_, _info, slug): q = select(Community).where(Community.slug == slug) q = add_community_stat_columns(q) - authors = get_communities_from_query(q) - return authors[0] + communities = get_communities_from_query(q) + return communities[0] diff --git a/resolvers/editor.py b/resolvers/editor.py index 732162a1..cfe46fbb 100644 --- a/resolvers/editor.py +++ b/resolvers/editor.py @@ -22,6 +22,7 @@ async def get_shouts_drafts(_, info): q = ( select(Shout) .options( + joinedload(Shout.created_by, Author.id == Shout.created_by), joinedload(Shout.authors), joinedload(Shout.topics), ) @@ -43,9 +44,6 @@ async def create_shout(_, info, inp): shout_dict = None if author: topics = session.query(Topic).filter(Topic.slug.in_(inp.get("topics", []))).all() - authors = inp.get("authors", []) - if author.id not in authors: - authors.insert(0, author.id) current_time = int(time.time()) new_shout = Shout( **{ @@ -55,7 +53,8 @@ async def create_shout(_, info, inp): "description": inp.get("description"), "body": inp.get("body", ""), "layout": inp.get("layout"), - "authors": authors, + "created_by": author.id, + "authors": [], "slug": inp.get("slug") or f"draft-{time.time()}", "topics": inp.get("topics"), "visibility": ShoutVisibility.AUTHORS.value, @@ -65,7 +64,7 @@ async def create_shout(_, info, inp): for topic in topics: t = ShoutTopic(topic=topic.id, shout=new_shout.id) session.add(t) - # NOTE: shout made by one first author + # NOTE: shout made by one author sa = ShoutAuthor(shout=new_shout.id, author=author.id) session.add(sa) shout_dict = new_shout.dict() @@ -89,6 +88,7 @@ async def update_shout(_, info, shout_id, shout_input=None, publish=False): shout = ( session.query(Shout) .options( + joinedload(Shout.created_by, Author.id == Shout.created_by), joinedload(Shout.authors), joinedload(Shout.topics), ) @@ -97,7 +97,7 @@ async def update_shout(_, info, shout_id, shout_input=None, publish=False): ) if not shout: return {"error": "shout not found"} - if shout.created_by != author.id: + if shout.created_by != author.id and author.id not in shout.authors: return {"error": "access denied"} if shout_input is not None: topics_input = shout_input["topics"] @@ -136,9 +136,7 @@ async def update_shout(_, info, shout_id, shout_input=None, publish=False): ) for shout_topic_to_remove in shout_topics_to_remove: session.delete(shout_topic_to_remove) - shout_input["mainTopic"] = shout_input["mainTopic"]["slug"] - if shout_input["mainTopic"] == "": - del shout_input["mainTopic"] + # Replace datetime with Unix timestamp shout_input["updated_at"] = current_time # Set updated_at as Unix timestamp Shout.update(shout, shout_input) @@ -168,7 +166,7 @@ async def delete_shout(_, info, shout_id): if not shout: return {"error": "invalid shout id"} if author: - if author.id not in shout.authors: + if shout.created_by != author.id and author.id not in shout.authors: return {"error": "access denied"} for author_id in shout.authors: reactions_unfollow(author_id, shout_id) diff --git a/resolvers/reader.py b/resolvers/reader.py index b77126dd..f35bfc8e 100644 --- a/resolvers/reader.py +++ b/resolvers/reader.py @@ -72,6 +72,7 @@ def apply_filters(q, filters, author_id=None): async def get_shout(_, _info, slug=None, shout_id=None): with local_session() as session: q = select(Shout).options( + joinedload(Shout.created_by), joinedload(Shout.authors), joinedload(Shout.topics), ) @@ -141,6 +142,7 @@ async def load_shouts_by(_, info, options): q = ( select(Shout) .options( + joinedload(Shout.created_by, Author.id == Shout.created_by), joinedload(Shout.authors), joinedload(Shout.topics), )