diff --git a/services/auth.py b/services/auth.py index 479c79de..8b71ca0b 100644 --- a/services/auth.py +++ b/services/auth.py @@ -1,9 +1,11 @@ -from cachetools import TTLCache, cached +from functools import wraps import logging import time -from starlette.exceptions import HTTPException from aiohttp import ClientSession +from starlette.exceptions import HTTPException +from cachetools import TTLCache + from settings import AUTH_URL, AUTH_SECRET @@ -11,8 +13,8 @@ logging.basicConfig() logger = logging.getLogger("\t[services.auth]\t") logger.setLevel(logging.DEBUG) -# Define a TTLCache with a time-to-live of 100 seconds -token_cache = TTLCache(maxsize=99999, ttl=1799) +# Define a TTLCache with a time-to-live of 1800 seconds +token_cache = TTLCache(maxsize=99999, ttl=1800) async def request_data(gql, headers={"Content-Type": "application/json"}): try: @@ -29,7 +31,6 @@ async def request_data(gql, headers={"Content-Type": "application/json"}): logger.error(f"[services.auth] request_data error: {e}") return None -@cached(cache=token_cache) async def user_id_from_token(token): logger.error(f"[services.auth] checking auth token: {token}") query_name = "validate_jwt_token" @@ -60,11 +61,21 @@ async def user_id_from_token(token): async def check_auth(req) -> str | None: token = req.headers.get("Authorization") - cached_result = await user_id_from_token(token) + # Manually manage cache using a dictionary + cached_result = token_cache.get(token) if cached_result: user_id, expires_at = cached_result if expires_at > time.time(): return user_id + + # If not in cache, fetch from user_id_from_token and update cache + result = await user_id_from_token(token) + if result: + user_id, expires_at = result + token_cache[token] = (user_id, expires_at) + if expires_at > time.time(): + return user_id + raise HTTPException(status_code=401, detail="Unauthorized") async def add_user_role(user_id): @@ -84,6 +95,7 @@ async def add_user_role(user_id): return user_id def login_required(f): + @wraps(f) async def decorated_function(*args, **kwargs): info = args[1] context = info.context @@ -95,7 +107,9 @@ def login_required(f): return decorated_function + def auth_request(f): + @wraps(f) async def decorated_function(*args, **kwargs): req = args[0] user_id = await check_auth(req)