diff --git a/auth/authenticate.py b/auth/authenticate.py index 2637d302..60f9d6ef 100644 --- a/auth/authenticate.py +++ b/auth/authenticate.py @@ -2,7 +2,7 @@ from functools import wraps from typing import Optional, Tuple from graphql import GraphQLResolveInfo -import jwt +from jwt import DecodeError, ExpiredSignatureError from starlette.authentication import AuthenticationBackend from starlette.requests import HTTPConnection @@ -29,14 +29,14 @@ class _Authenticate: """ try: payload = Token.decode(token) - except exceptions.ExpiredSignatureError: + except ExpiredSignatureError: payload = Token.decode(token, verify_exp=False) if not await cls.exists(payload.user_id, token): raise InvalidToken("Login expired, please login again") if payload.device == "mobile": # noqa "we cat set mobile token to be valid forever" return payload - except exceptions.JWTDecodeError as e: + except DecodeError as e: raise InvalidToken("token format error") from e else: if not await cls.exists(payload.user_id, token): @@ -73,5 +73,4 @@ def login_required(func): if not auth.logged_in: raise OperationNotAllowed(auth.error_message or "Please login") return await func(parent, info, *args, **kwargs) - return wrap diff --git a/auth/identity.py b/auth/identity.py index 1eefe773..cfb4093a 100644 --- a/auth/identity.py +++ b/auth/identity.py @@ -4,6 +4,8 @@ from orm import User as OrmUser from orm.base import global_session from auth.validations import User +from sqlalchemy import or_ + class Identity: @staticmethod @@ -12,14 +14,22 @@ class Identity: if not user: raise ObjectNotExist("User does not exist") user = User(**user.dict()) + if user.password is None: + raise InvalidPassword("Wrong user password") if not Password.verify(password, user.password): raise InvalidPassword("Wrong user password") return user @staticmethod - def identity_oauth(oauth_id, input) -> User: - user = global_session.query(OrmUser).filter_by(oauth_id=oauth_id).first() + def identity_oauth(input) -> User: + user = global_session.query(OrmUser).filter( + or_(OrmUser.oauth_id == input["oauth_id"], OrmUser.email == input["email"]) + ).first() if not user: user = OrmUser.create(**input) + if not user.oauth_id: + user.oauth_id = input["oauth_id"] + global_session.commit() + user = User(**user.dict()) return user diff --git a/auth/oauth.py b/auth/oauth.py index 8fc6ade7..8bad3dbc 100644 --- a/auth/oauth.py +++ b/auth/oauth.py @@ -63,6 +63,6 @@ async def oauth_authorize(request): "email" : profile["email"], "username" : profile["name"] } - user = Identity.identity_oauth(oauth_id=oauth_id, input=user_input) + user = Identity.identity_oauth(user_input) token = await Authorize.authorize(user, device="pc", auto_delete=False) return PlainTextResponse(token) diff --git a/resolvers/auth.py b/resolvers/auth.py index cfc7f1a1..3745ddfe 100644 --- a/resolvers/auth.py +++ b/resolvers/auth.py @@ -19,13 +19,17 @@ async def register(*_, input: dict = None) -> User: @query.field("signIn") -async def sign_in(_, info: GraphQLResolveInfo, id: int, password: str): +async def sign_in(_, info: GraphQLResolveInfo, email: str, password: str): + orm_user = global_session.query(User).filter(User.email == email).first() + if orm_user is None: + return {"status" : False, "error" : "invalid email"} + try: device = info.context["request"].headers['device'] except KeyError: device = "pc" auto_delete = False if device == "mobile" else True - user = Identity.identity(user_id=id, password=password) + user = Identity.identity(user_id=orm_user.id, password=password) token = await Authorize.authorize(user, device=device, auto_delete=auto_delete) return {"status" : True, "token" : token} @@ -38,9 +42,16 @@ async def sign_out(_, info: GraphQLResolveInfo): return {"status" : status} -#@query.field("getUser") -#@login_required -async def get_user(*_, id: int): - return global_session.query(User).filter(User.id == id).first() +@query.field("getCurrentUser") +@login_required +async def get_user(_, info): + auth = info.context["request"].auth + user_id = auth.user_id + return global_session.query(User).filter(User.id == user_id).first() + +@query.field("isEmailFree") +async def is_email_free(_, info, email): + user = global_session.query(User).filter(User.email == email).first() + return user is None diff --git a/schema.graphql b/schema.graphql index 73e10313..dc562f4e 100644 --- a/schema.graphql +++ b/schema.graphql @@ -79,11 +79,12 @@ type Mutation { type Query { # auth / user - signIn(id: Int!, password: String!): signInPayload! + signIn(email: String!, password: String!): signInPayload! signOut: ResultPayload! - getCurrentUser: User! - getTokens: [Token!]! - isUsernameFree(username: String!): Boolean! + getCurrentUser: User! + + isEmailFree(email: String!): Boolean! + getOnline: [User!]! getUserById(id: Int!): User! getUserRating(shout: Int): Int!