diff --git a/auth/authenticate.py b/auth/authenticate.py index bf09ddbf..847730e5 100644 --- a/auth/authenticate.py +++ b/auth/authenticate.py @@ -71,6 +71,9 @@ class JWTAuthenticate(AuthenticationBackend): return AuthCredentials(scopes=[]), AuthUser(user_id=None) user = await UserStorage.get_user(payload.user_id) + if not user: + return AuthCredentials(scopes=[]), AuthUser(user_id=None) + scopes = await user.get_permission() return AuthCredentials(user_id=payload.user_id, scopes=scopes, logged_in=True), user diff --git a/orm/__init__.py b/orm/__init__.py index da7965d2..1a14abbd 100644 --- a/orm/__init__.py +++ b/orm/__init__.py @@ -14,6 +14,8 @@ __all__ = ["User", "Role", "Community", "Operation", "Permission", "Message", "S Base.metadata.create_all(engine) Operation.init_table() Resource.init_table() +Community.init_table() +Role.init_table() with local_session() as session: ShoutRatingStorage.init(session) diff --git a/orm/base.py b/orm/base.py index 0904d1f6..d653b6ce 100644 --- a/orm/base.py +++ b/orm/base.py @@ -33,7 +33,7 @@ class Base(declarative_base()): @classmethod def create(cls: Generic[T], **kwargs) -> Generic[T]: instance = cls(**kwargs) - return instance.save() + return instance.save(session) def save(self) -> Generic[T]: with local_session() as session: diff --git a/orm/community.py b/orm/community.py index a7efce3a..8f45ee51 100644 --- a/orm/community.py +++ b/orm/community.py @@ -1,15 +1,31 @@ from datetime import datetime from sqlalchemy import Column, Integer, String, ForeignKey, DateTime from sqlalchemy.orm import relationship, backref -from orm.base import Base +from orm.base import Base, local_session class Community(Base): __tablename__ = 'community' - # id is auto number + name: str = Column(String, nullable=False, comment="Name") slug: str = Column(String, unique = True, nullable = False) desc: str = Column(String, nullable=False, default='') pic: str = Column(String, nullable=False, default='') createdAt: str = Column(DateTime, nullable=False, default = datetime.now, comment="Created at") createdBy: str = Column(ForeignKey("user.id"), nullable=False, comment="Creator") + + @staticmethod + def init_table(): + with local_session() as session: + default = session.query(Community).filter(Community.slug == "default").first() + if default: + Community.default_community = default + return + + default = Community.create( + name = "default", + slug = "default", + createdBy = 0 + ) + + Community.default_community = default diff --git a/orm/rbac.py b/orm/rbac.py index 24600113..b72bfe37 100644 --- a/orm/rbac.py +++ b/orm/rbac.py @@ -7,6 +7,7 @@ from sqlalchemy import String, Integer, Column, ForeignKey, UniqueConstraint, Ty from sqlalchemy.orm import relationship, selectinload from orm.base import Base, REGISTRY, engine, local_session +from orm.community import Community class ClassType(TypeDecorator): @@ -31,13 +32,27 @@ class ClassType(TypeDecorator): class Role(Base): __tablename__ = 'role' - # id is auto field - name: str = Column(String, nullable=False, comment="Role Name") desc: str = Column(String, nullable=True, comment="Role Description") community: int = Column(ForeignKey("community.id", ondelete="CASCADE"), nullable=False, comment="Community") permissions = relationship(lambda: Permission) + @staticmethod + def init_table(): + with local_session() as session: + default = session.query(Role).filter(Role.name == "author").first() + if default: + Role.default_role = default + return + + default = Role.create( + name = "author", + desc = "Role for author", + community = Community.default_community.id + ) + + Role.default_role = default + class Operation(Base): __tablename__ = 'operation' name: str = Column(String, nullable=False, unique=True, comment="Operation Name") diff --git a/orm/user.py b/orm/user.py index 50d45372..33e4401a 100644 --- a/orm/user.py +++ b/orm/user.py @@ -94,7 +94,7 @@ class UserStorage: async def add_user(user): self = UserStorage async with self.lock: - self.users[id] = user + self.users[user.id] = user @staticmethod async def del_user(id): diff --git a/resolvers/auth.py b/resolvers/auth.py index e6b8142e..3d44f48b 100644 --- a/resolvers/auth.py +++ b/resolvers/auth.py @@ -8,7 +8,7 @@ from auth.authorize import Authorize from auth.identity import Identity from auth.password import Password from auth.email import send_confirm_email, send_auth_email -from orm import User, UserStorage +from orm import User, UserStorage, Role, UserRole from orm.base import local_session from resolvers.base import mutation, query from exceptions import InvalidPassword @@ -37,16 +37,21 @@ async def register(*_, email: str, password: str = ""): username = email.split('@')[0] user_dict["username"] = username user_dict["slug"] = quote_plus(translit(username, 'ru', reversed=True).replace('.', '-').lower()) + if password: + user_dict["password"] = Password.encode(password) + user = User(**user_dict) + user.roles.append(Role.default_role) + with local_session() as session: + session.add(user) + session.commit() + + await UserStorage.add_user(user) + if not password: - user = User.create(**user_dict) await send_confirm_email(user) - await UserStorage.add_user(user) return { "user": user } - user_dict["password"] = Password.encode(password) - user = User.create(**user_dict) token = await Authorize.authorize(user) - await UserStorage.add_user(user) return {"user": user, "token": token } diff --git a/resolvers/profile.py b/resolvers/profile.py index e0bede31..dcfabad6 100644 --- a/resolvers/profile.py +++ b/resolvers/profile.py @@ -16,7 +16,7 @@ async def get_user_by_slug(_, info, slug): group_by(User.id).\ first() user = row.User - user.rating = row.rating + user["rating"] = row.rating return { "user": user } # TODO: remove some fields for public @query.field("getCurrentUser") @@ -46,3 +46,16 @@ async def user_roles(_, info): where(UserRole.user_id == user_id).all() return roles + +@mutation.field("updateProfile") +@login_required +async def update_profile(_, info, profile): + auth = info.context["request"].auth + user_id = auth.user_id + + with local_session() as session: + user = session.query(User).filter(User.id == user_id).first() + user.update(profile) + session.commit() + + return {} diff --git a/schema.graphql b/schema.graphql index 5a78640d..f1aae417 100644 --- a/schema.graphql +++ b/schema.graphql @@ -36,9 +36,10 @@ input ShoutInput { } input ProfileInput { - email: String - username: String + name: String userpic: String + links: [String] + bio: String } input CommunityInput {