string-enum-fix

This commit is contained in:
Untone 2023-11-30 11:40:27 +03:00
parent 1362eaa125
commit 919aaa951f
7 changed files with 21 additions and 19 deletions

View File

@ -1,4 +1,4 @@
from sqlalchemy import Column, ForeignKey, Enum from sqlalchemy import Column, ForeignKey, Enum, String
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
from services.db import Base from services.db import Base
from orm.author import Author from orm.author import Author
@ -18,7 +18,7 @@ class Invite(Base):
inviter_id = Column(ForeignKey("author.id"), nullable=False, index=True) inviter_id = Column(ForeignKey("author.id"), nullable=False, index=True)
author_id = Column(ForeignKey("author.id"), nullable=False, index=True) author_id = Column(ForeignKey("author.id"), nullable=False, index=True)
shout_id = Column(ForeignKey("shout.id"), nullable=False, index=True) shout_id = Column(ForeignKey("shout.id"), nullable=False, index=True)
status = Column(Enum(InviteStatus), default=InviteStatus.PENDING.value) status = Column(String, default=InviteStatus.PENDING.value)
inviter = relationship(Author, foreign_keys=[inviter_id]) inviter = relationship(Author, foreign_keys=[inviter_id])
author = relationship(Author, foreign_keys=[author_id]) author = relationship(Author, foreign_keys=[author_id])

View File

@ -38,6 +38,6 @@ class Reaction(Base):
quote = Column(String, nullable=True, comment="Original quoted text") quote = Column(String, nullable=True, comment="Original quoted text")
shout = Column(ForeignKey("shout.id"), nullable=False, index=True) shout = Column(ForeignKey("shout.id"), nullable=False, index=True)
created_by = Column(ForeignKey("author.id"), nullable=False, index=True) created_by = Column(ForeignKey("author.id"), nullable=False, index=True)
kind = Column(Enum(ReactionKind), nullable=False, index=True) kind = Column(String, nullable=False, index=True)
oid = Column(String) oid = Column(String)

View File

@ -80,7 +80,7 @@ class Shout(Base):
communities = relationship(lambda: Community, secondary="shout_community") communities = relationship(lambda: Community, secondary="shout_community")
reactions = relationship(lambda: Reaction) reactions = relationship(lambda: Reaction)
visibility = Column(Enum(ShoutVisibility), default=ShoutVisibility.AUTHORS.value) visibility = Column(String, default=ShoutVisibility.AUTHORS.value)
lang = Column(String, nullable=False, default="ru", comment="Language") lang = Column(String, nullable=False, default="ru", comment="Language")
version_of = Column(ForeignKey("shout.id"), nullable=True) version_of = Column(ForeignKey("shout.id"), nullable=True)

View File

@ -21,9 +21,11 @@ async def accept_invite(_, info, invite_id: int):
# Add the user to the shout authors # Add the user to the shout authors
shout = session.query(Shout).filter(Shout.id == invite.shout_id).first() shout = session.query(Shout).filter(Shout.id == invite.shout_id).first()
if shout: if shout:
shout.authors.append(author) if author not in shout.authors:
session.delete(invite) shout.authors.append(author)
session.commit() session.delete(invite)
session.add(shout)
session.commit()
return {"success": True, "message": "Invite accepted"} return {"success": True, "message": "Invite accepted"}
else: else:
return {"error": "Shout not found"} return {"error": "Shout not found"}

View File

@ -113,5 +113,5 @@ async def get_community(_, _info, slug):
q = select(Community).where(Community.slug == slug) q = select(Community).where(Community.slug == slug)
q = add_community_stat_columns(q) q = add_community_stat_columns(q)
authors = get_communities_from_query(q) communities = get_communities_from_query(q)
return authors[0] return communities[0]

View File

@ -22,6 +22,7 @@ async def get_shouts_drafts(_, info):
q = ( q = (
select(Shout) select(Shout)
.options( .options(
joinedload(Shout.created_by, Author.id == Shout.created_by),
joinedload(Shout.authors), joinedload(Shout.authors),
joinedload(Shout.topics), joinedload(Shout.topics),
) )
@ -43,9 +44,6 @@ async def create_shout(_, info, inp):
shout_dict = None shout_dict = None
if author: if author:
topics = session.query(Topic).filter(Topic.slug.in_(inp.get("topics", []))).all() topics = session.query(Topic).filter(Topic.slug.in_(inp.get("topics", []))).all()
authors = inp.get("authors", [])
if author.id not in authors:
authors.insert(0, author.id)
current_time = int(time.time()) current_time = int(time.time())
new_shout = Shout( new_shout = Shout(
**{ **{
@ -55,7 +53,8 @@ async def create_shout(_, info, inp):
"description": inp.get("description"), "description": inp.get("description"),
"body": inp.get("body", ""), "body": inp.get("body", ""),
"layout": inp.get("layout"), "layout": inp.get("layout"),
"authors": authors, "created_by": author.id,
"authors": [],
"slug": inp.get("slug") or f"draft-{time.time()}", "slug": inp.get("slug") or f"draft-{time.time()}",
"topics": inp.get("topics"), "topics": inp.get("topics"),
"visibility": ShoutVisibility.AUTHORS.value, "visibility": ShoutVisibility.AUTHORS.value,
@ -65,7 +64,7 @@ async def create_shout(_, info, inp):
for topic in topics: for topic in topics:
t = ShoutTopic(topic=topic.id, shout=new_shout.id) t = ShoutTopic(topic=topic.id, shout=new_shout.id)
session.add(t) session.add(t)
# NOTE: shout made by one first author # NOTE: shout made by one author
sa = ShoutAuthor(shout=new_shout.id, author=author.id) sa = ShoutAuthor(shout=new_shout.id, author=author.id)
session.add(sa) session.add(sa)
shout_dict = new_shout.dict() shout_dict = new_shout.dict()
@ -89,6 +88,7 @@ async def update_shout(_, info, shout_id, shout_input=None, publish=False):
shout = ( shout = (
session.query(Shout) session.query(Shout)
.options( .options(
joinedload(Shout.created_by, Author.id == Shout.created_by),
joinedload(Shout.authors), joinedload(Shout.authors),
joinedload(Shout.topics), joinedload(Shout.topics),
) )
@ -97,7 +97,7 @@ async def update_shout(_, info, shout_id, shout_input=None, publish=False):
) )
if not shout: if not shout:
return {"error": "shout not found"} return {"error": "shout not found"}
if shout.created_by != author.id: if shout.created_by != author.id and author.id not in shout.authors:
return {"error": "access denied"} return {"error": "access denied"}
if shout_input is not None: if shout_input is not None:
topics_input = shout_input["topics"] topics_input = shout_input["topics"]
@ -136,9 +136,7 @@ async def update_shout(_, info, shout_id, shout_input=None, publish=False):
) )
for shout_topic_to_remove in shout_topics_to_remove: for shout_topic_to_remove in shout_topics_to_remove:
session.delete(shout_topic_to_remove) session.delete(shout_topic_to_remove)
shout_input["mainTopic"] = shout_input["mainTopic"]["slug"]
if shout_input["mainTopic"] == "":
del shout_input["mainTopic"]
# Replace datetime with Unix timestamp # Replace datetime with Unix timestamp
shout_input["updated_at"] = current_time # Set updated_at as Unix timestamp shout_input["updated_at"] = current_time # Set updated_at as Unix timestamp
Shout.update(shout, shout_input) Shout.update(shout, shout_input)
@ -168,7 +166,7 @@ async def delete_shout(_, info, shout_id):
if not shout: if not shout:
return {"error": "invalid shout id"} return {"error": "invalid shout id"}
if author: if author:
if author.id not in shout.authors: if shout.created_by != author.id and author.id not in shout.authors:
return {"error": "access denied"} return {"error": "access denied"}
for author_id in shout.authors: for author_id in shout.authors:
reactions_unfollow(author_id, shout_id) reactions_unfollow(author_id, shout_id)

View File

@ -72,6 +72,7 @@ def apply_filters(q, filters, author_id=None):
async def get_shout(_, _info, slug=None, shout_id=None): async def get_shout(_, _info, slug=None, shout_id=None):
with local_session() as session: with local_session() as session:
q = select(Shout).options( q = select(Shout).options(
joinedload(Shout.created_by),
joinedload(Shout.authors), joinedload(Shout.authors),
joinedload(Shout.topics), joinedload(Shout.topics),
) )
@ -141,6 +142,7 @@ async def load_shouts_by(_, info, options):
q = ( q = (
select(Shout) select(Shout)
.options( .options(
joinedload(Shout.created_by, Author.id == Shout.created_by),
joinedload(Shout.authors), joinedload(Shout.authors),
joinedload(Shout.topics), joinedload(Shout.topics),
) )