refresh-token

This commit is contained in:
tonyrewin 2022-11-24 17:31:52 +03:00
parent 786bd20275
commit 84600308ad
6 changed files with 54 additions and 44 deletions

View File

@ -2,48 +2,14 @@ from functools import wraps
from typing import Optional, Tuple
from graphql.type import GraphQLResolveInfo
from jwt import DecodeError, ExpiredSignatureError
from starlette.authentication import AuthenticationBackend
from starlette.requests import HTTPConnection
from auth.credentials import AuthCredentials, AuthUser
from auth.jwtcodec import JWTCodec
from auth.tokenstorage import TokenStorage
from base.exceptions import ExpiredToken, InvalidToken
from services.auth.users import UserStorage
from settings import SESSION_TOKEN_HEADER
class SessionToken:
@classmethod
async def verify(cls, token: str):
"""
Rules for a token to be valid.
1. token format is legal &&
token exists in redis database &&
token is not expired
2. token format is legal &&
token exists in redis database &&
token is expired &&
token is of specified type
"""
try:
print('[auth.authenticate] session token verify')
payload = JWTCodec.decode(token)
except ExpiredSignatureError:
payload = JWTCodec.decode(token, verify_exp=False)
if not await cls.get(payload.user_id, token):
raise ExpiredToken("Token signature has expired, please try again")
except DecodeError as e:
raise InvalidToken("token format error") from e
else:
if not await cls.get(payload.user_id, token):
raise ExpiredToken("Session token has expired, please login again")
return payload
@classmethod
async def get(cls, uid, token):
return await TokenStorage.get(f"{uid}-{token}")
from auth.tokenstorage import SessionToken
from base.exceptions import InvalidToken
class JWTAuthenticate(AuthenticationBackend):
@ -54,10 +20,18 @@ class JWTAuthenticate(AuthenticationBackend):
if SESSION_TOKEN_HEADER not in request.headers:
return AuthCredentials(scopes=[]), AuthUser(user_id=None)
token = request.headers.get(SESSION_TOKEN_HEADER, "")
token = request.headers.get(SESSION_TOKEN_HEADER)
if not token:
print("[auth.authenticate] no token in header %s" % SESSION_TOKEN_HEADER)
return AuthCredentials(scopes=[], error_message=str("no token")), AuthUser(
user_id=None
)
try:
payload = await SessionToken.verify(token)
if len(token.split('.')) > 1:
payload = await SessionToken.verify(token)
else:
InvalidToken("please try again")
except Exception as exc:
print("[auth.authenticate] session token verify error")
print(exc)

View File

@ -20,7 +20,7 @@ class AuthCredentials(BaseModel):
return True
async def permissions(self) -> List[Permission]:
if self.user_id is not None:
if self.user_id is None:
raise OperationNotAllowed("Please login first")
return NotImplemented()

View File

@ -8,12 +8,11 @@ from settings import JWT_ALGORITHM, JWT_SECRET_KEY
class JWTCodec:
@staticmethod
def encode(user: AuthInput, exp: datetime) -> str:
issued = datetime.now(tz=timezone.utc)
payload = {
"user_id": user.id,
"username": user.email or user.phone,
"exp": exp,
"iat": issued,
"iat": datetime.now(tz=timezone.utc),
"iss": "discours"
}
try:

View File

@ -13,9 +13,30 @@ async def save(token_key, life_span, auto_delete=True):
await redis.execute("EXPIREAT", token_key, int(expire_at))
class SessionToken:
@classmethod
async def verify(cls, token: str):
"""
Rules for a token to be valid.
- token format is legal
- token exists in redis database
- token is not expired
"""
try:
return JWTCodec.decode(token)
except Exception as e:
raise e
@classmethod
async def get(cls, uid, token):
return await TokenStorage.get(f"{uid}-{token}")
class TokenStorage:
@staticmethod
async def get(token_key):
print('[tokenstorage.get] ' + token_key)
# 2041-eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoyMDQxLCJ1c2VybmFtZSI6ImFudG9uLnJld2luK3Rlc3QtbG9hZGNoYXRAZ21haWwuY29tIiwiZXhwIjoxNjcxNzgwNjE2LCJpYXQiOjE2NjkxODg2MTYsImlzcyI6ImRpc2NvdXJzIn0.Nml4oV6iMjMmc6xwM7lTKEZJKBXvJFEIZ-Up1C1rITQ
return await redis.execute("GET", token_key)
@staticmethod

View File

@ -24,13 +24,24 @@ from settings import SESSION_TOKEN_HEADER
@mutation.field("refreshSession")
@login_required
async def get_current_user(_, info):
print('[resolvers.auth] get current user %s' % str(info))
user = info.context["request"].user
# print(info.context["request"].headers)
old_token = info.context["request"].headers.get("Authorization")
user.lastSeen = datetime.now(tz=timezone.utc)
with local_session() as session:
session.add(user)
session.commit()
token = await TokenStorage.create_session(user)
print("[resolvers.auth] new session token created")
if old_token:
payload = await TokenStorage.get(str(user.id) + '-' + str(old_token))
if payload:
print("[resolvers.auth] got session from old token: %r" % payload)
return {
"token": token,
"user": user,
"news": await user_subscriptions(user.slug),
}
return {
"token": token,
"user": user,

View File

@ -53,7 +53,6 @@ if __name__ == "__main__":
if len(sys.argv) > 1:
x = sys.argv[1]
if x == "dev":
print("DEV MODE")
if os.path.exists(DEV_SERVER_STATUS_FILE_NAME):
os.remove(DEV_SERVER_STATUS_FILE_NAME)
@ -67,6 +66,12 @@ if __name__ == "__main__":
("Access-Control-Expose-Headers", "Content-Length,Content-Range"),
("Access-Control-Allow-Credentials", "true"),
]
want_reload = False
if "reload" in sys.argv:
print("MODE: DEV + RELOAD")
want_reload = True
else:
print("MODE: DEV")
uvicorn.run(
"main:dev_app",
host="localhost",
@ -75,7 +80,7 @@ if __name__ == "__main__":
# log_config=LOGGING_CONFIG,
log_level=None,
access_log=False,
reload=True
reload=want_reload
) # , ssl_keyfile="discours.key", ssl_certfile="discours.crt")
elif x == "migrate":
from migration import migrate