This commit is contained in:
tonyrewin 2022-09-14 07:42:31 +03:00
parent 4446255f64
commit 80a1aeb767

View File

@ -1,141 +1,141 @@
from functools import wraps from functools import wraps
from typing import Optional, Tuple from typing import Optional, Tuple
from datetime import datetime, timedelta from datetime import datetime, timedelta
from graphql import GraphQLResolveInfo from graphql import GraphQLResolveInfo
from jwt import DecodeError, ExpiredSignatureError from jwt import DecodeError, ExpiredSignatureError
from starlette.authentication import AuthenticationBackend from starlette.authentication import AuthenticationBackend
from starlette.requests import HTTPConnection from starlette.requests import HTTPConnection
from auth.credentials import AuthCredentials, AuthUser from auth.credentials import AuthCredentials, AuthUser
from auth.jwtcodec import JWTCodec from auth.jwtcodec import JWTCodec
from auth.authorize import Authorize, TokenStorage from auth.authorize import Authorize, TokenStorage
from base.exceptions import InvalidToken from base.exceptions import InvalidToken
from orm.user import User from orm.user import User
from services.auth.users import UserStorage from services.auth.users import UserStorage
from base.orm import local_session from base.orm import local_session
from settings import JWT_AUTH_HEADER, EMAIL_TOKEN_LIFE_SPAN from settings import JWT_AUTH_HEADER, EMAIL_TOKEN_LIFE_SPAN
class _Authenticate: class _Authenticate:
@classmethod @classmethod
async def verify(cls, token: str): async def verify(cls, token: str):
""" """
Rules for a token to be valid. Rules for a token to be valid.
1. token format is legal && 1. token format is legal &&
token exists in redis database && token exists in redis database &&
token is not expired token is not expired
2. token format is legal && 2. token format is legal &&
token exists in redis database && token exists in redis database &&
token is expired && token is expired &&
token is of specified type token is of specified type
""" """
try: try:
payload = JWTCodec.decode(token) payload = JWTCodec.decode(token)
except ExpiredSignatureError: except ExpiredSignatureError:
payload = JWTCodec.decode(token, verify_exp=False) payload = JWTCodec.decode(token, verify_exp=False)
if not await cls.exists(payload.user_id, token): if not await cls.exists(payload.user_id, token):
raise InvalidToken("Login expired, please login again") raise InvalidToken("Login expired, please login again")
if payload.device == "mobile": # noqa if payload.device == "mobile": # noqa
"we cat set mobile token to be valid forever" "we cat set mobile token to be valid forever"
return payload return payload
except DecodeError as e: except DecodeError as e:
raise InvalidToken("token format error") from e raise InvalidToken("token format error") from e
else: else:
if not await cls.exists(payload.user_id, token): if not await cls.exists(payload.user_id, token):
raise InvalidToken("Login expired, please login again") raise InvalidToken("Login expired, please login again")
return payload return payload
@classmethod @classmethod
async def exists(cls, user_id, token): async def exists(cls, user_id, token):
return await TokenStorage.exist(f"{user_id}-{token}") return await TokenStorage.exist(f"{user_id}-{token}")
class JWTAuthenticate(AuthenticationBackend): class JWTAuthenticate(AuthenticationBackend):
async def authenticate( async def authenticate(
self, request: HTTPConnection self, request: HTTPConnection
) -> Optional[Tuple[AuthCredentials, AuthUser]]: ) -> Optional[Tuple[AuthCredentials, AuthUser]]:
if JWT_AUTH_HEADER not in request.headers: if JWT_AUTH_HEADER not in request.headers:
return AuthCredentials(scopes=[]), AuthUser(user_id=None) return AuthCredentials(scopes=[]), AuthUser(user_id=None)
token = request.headers[JWT_AUTH_HEADER] token = request.headers[JWT_AUTH_HEADER]
try: try:
payload = await _Authenticate.verify(token) payload = await _Authenticate.verify(token)
except Exception as exc: except Exception as exc:
return AuthCredentials(scopes=[], error_message=str(exc)), AuthUser( return AuthCredentials(scopes=[], error_message=str(exc)), AuthUser(
user_id=None user_id=None
) )
if payload is None: if payload is None:
return AuthCredentials(scopes=[]), AuthUser(user_id=None) return AuthCredentials(scopes=[]), AuthUser(user_id=None)
if not payload.device in ("pc", "mobile"): if payload.device not in ("pc", "mobile"):
return AuthCredentials(scopes=[]), AuthUser(user_id=None) return AuthCredentials(scopes=[]), AuthUser(user_id=None)
user = await UserStorage.get_user(payload.user_id) user = await UserStorage.get_user(payload.user_id)
if not user: if not user:
return AuthCredentials(scopes=[]), AuthUser(user_id=None) return AuthCredentials(scopes=[]), AuthUser(user_id=None)
scopes = await user.get_permission() scopes = await user.get_permission()
return ( return (
AuthCredentials(user_id=payload.user_id, scopes=scopes, logged_in=True), AuthCredentials(user_id=payload.user_id, scopes=scopes, logged_in=True),
user, user,
) )
class EmailAuthenticate: class EmailAuthenticate:
@staticmethod @staticmethod
async def get_email_token(user): async def get_email_token(user):
token = await Authorize.authorize( token = await Authorize.authorize(
user, device="email", life_span=EMAIL_TOKEN_LIFE_SPAN user, device="email", life_span=EMAIL_TOKEN_LIFE_SPAN
) )
return token return token
@staticmethod @staticmethod
async def authenticate(token): async def authenticate(token):
payload = await _Authenticate.verify(token) payload = await _Authenticate.verify(token)
if payload is None: if payload is None:
raise InvalidToken("invalid token") raise InvalidToken("invalid token")
if payload.device != "email": if payload.device != "email":
raise InvalidToken("invalid token") raise InvalidToken("invalid token")
with local_session() as session: with local_session() as session:
user = session.query(User).filter_by(id=payload.user_id).first() user = session.query(User).filter_by(id=payload.user_id).first()
if not user: if not user:
raise Exception("user not exist") raise Exception("user not exist")
if not user.emailConfirmed: if not user.emailConfirmed:
user.emailConfirmed = True user.emailConfirmed = True
session.commit() session.commit()
auth_token = await Authorize.authorize(user) auth_token = await Authorize.authorize(user)
return (auth_token, user) return (auth_token, user)
class ResetPassword: class ResetPassword:
@staticmethod @staticmethod
async def get_reset_token(user): async def get_reset_token(user):
exp = datetime.utcnow() + timedelta(seconds=EMAIL_TOKEN_LIFE_SPAN) exp = datetime.utcnow() + timedelta(seconds=EMAIL_TOKEN_LIFE_SPAN)
token = JWTCodec.encode(user, exp=exp, device="pc") token = JWTCodec.encode(user, exp=exp, device="pc")
await TokenStorage.save(f"{user.id}-reset-{token}", EMAIL_TOKEN_LIFE_SPAN, True) await TokenStorage.save(f"{user.id}-reset-{token}", EMAIL_TOKEN_LIFE_SPAN, True)
return token return token
@staticmethod @staticmethod
async def verify(token): async def verify(token):
try: try:
payload = JWTCodec.decode(token) payload = JWTCodec.decode(token)
except ExpiredSignatureError: except ExpiredSignatureError:
raise InvalidToken("Login expired, please login again") raise InvalidToken("Login expired, please login again")
except DecodeError as e: except DecodeError as e:
raise InvalidToken("token format error") from e raise InvalidToken("token format error") from e
else: else:
if not await TokenStorage.exist(f"{payload.user_id}-reset-{token}"): if not await TokenStorage.exist(f"{payload.user_id}-reset-{token}"):
raise InvalidToken("Login expired, please login again") raise InvalidToken("Login expired, please login again")
return payload.user_id return payload.user_id
def login_required(func): def login_required(func):
@wraps(func) @wraps(func)
async def wrap(parent, info: GraphQLResolveInfo, *args, **kwargs): async def wrap(parent, info: GraphQLResolveInfo, *args, **kwargs):
auth: AuthCredentials = info.context["request"].auth auth: AuthCredentials = info.context["request"].auth
if not auth.logged_in: if not auth.logged_in:
return {"error": auth.error_message or "Please login"} return {"error": auth.error_message or "Please login"}
return await func(parent, info, *args, **kwargs) return await func(parent, info, *args, **kwargs)
return wrap return wrap