diff --git a/main.py b/main.py index 9950c13..b332880 100644 --- a/main.py +++ b/main.py @@ -3,9 +3,8 @@ from importlib import import_module from os.path import exists from ariadne import load_schema_from_path, make_executable_schema -from starlette.applications import Starlette from ariadne.asgi import GraphQL - +from starlette.applications import Starlette from starlette.routing import Route from services.logger import root_logger as logger @@ -26,6 +25,7 @@ async def start(): f.write(str(os.getpid())) logger.info(f"process started in {MODE} mode") + # main starlette app object with ariadne mounted in root app = Starlette( on_startup=[ diff --git a/resolvers/__init__.py b/resolvers/__init__.py index 9f4b2b9..cdd4679 100644 --- a/resolvers/__init__.py +++ b/resolvers/__init__.py @@ -1,11 +1,7 @@ from resolvers.chats import create_chat, delete_chat, update_chat from resolvers.load import load_chats, load_messages_by -from resolvers.messages import ( - create_message, - delete_message, - mark_as_read, - update_message, -) +from resolvers.messages import (create_message, delete_message, mark_as_read, + update_message) from resolvers.search import search_messages, search_recipients __all__ = [ diff --git a/services/auth.py b/services/auth.py index 30f66a7..a205679 100644 --- a/services/auth.py +++ b/services/auth.py @@ -1,83 +1,94 @@ -import logging from functools import wraps -from aiohttp import ClientSession +import httpx from starlette.exceptions import HTTPException from services.core import get_author_by_user from services.logger import root_logger as logger from settings import AUTH_URL -logger.setLevel(logging.DEBUG) + +async def request_data(gql, headers=None): + if headers is None: + headers = {"Content-Type": "application/json"} + try: + async with httpx.AsyncClient() as client: + response = await client.post(AUTH_URL, json=gql, headers=headers) + if response.status_code == 200: + data = response.json() + errors = data.get("errors") + if errors: + logger.error(f"HTTP Errors: {errors}") + else: + return data + except Exception as e: + # Handling and logging exceptions during authentication check + logger.error(f"request_data error: {e}") + return None async def check_auth(req): - logger.debug("checking auth...") + token = req.headers.get("Authorization") user_id = "" - try: - token = req.headers.get("Authorization") - if token: - # Logging the authentication token - query_name = "validate_jwt_token" - operation = "ValidateToken" - headers = { - "Content-Type": "application/json", - } + user_roles = [] + if token: + # Logging the authentication token + logger.debug(f"{token}") + query_name = "validate_jwt_token" + operation = "ValidateToken" + variables = {"params": {"token_type": "access_token", "token": token}} - variables = { - "params": { - "token_type": "access_token", - "token": token, - } - } - - gql = { - "query": f"query {operation}($params: ValidateJWTTokenInput!) {{ {query_name}(params: $params) {{ is_valid claims }} }}", - "variables": variables, - "operationName": operation, - } - # Asynchronous HTTP request to the authentication server - async with ClientSession() as session: - async with session.post( - AUTH_URL, json=gql, headers=headers - ) as response: - if response.status == 200: - data = await response.json() - errors = data.get("errors") - if errors: - logger.error(f"{errors}") - else: - user_id = ( - data.get("data", {}) - .get(query_name, {}) - .get("claims", {}) - .get("sub") - ) - logger.info(f"got user_id: {user_id}") - return user_id - except Exception as e: - # Handling and logging exceptions during authentication check - logger.error(e) - - if not user_id: - raise HTTPException(status_code=401, detail="Unauthorized") + gql = { + "query": f"query {operation}($params: ValidateJWTTokenInput!) {{" + + f"{query_name}(params: $params) {{ is_valid claims }} " + + "}", + "variables": variables, + "operationName": operation, + } + data = await request_data(gql) + if data: + logger.debug(data) + user_data = data.get("data", {}).get(query_name, {}).get("claims", {}) + user_id = user_data.get("sub", "") + user_roles = user_data.get("allowed_roles", []) + return user_id, user_roles def login_required(f): @wraps(f) async def decorated_function(*args, **kwargs): info = args[1] - context = info.context - req = context.get("request") - user_id = await check_auth(req) - if user_id: - context["user_id"] = user_id.strip() - author = get_author_by_user(user_id) - if author and "id" in author: - context["author_id"] = author["id"] - else: - logger.debug(author) - HTTPException(status_code=401, detail="Unauthorized") + req = info.context.get("request") + authorized = await check_auth(req) + if authorized: + logger.info(authorized) + user_id, user_roles = authorized + if user_id and user_roles: + logger.info(f" got {user_id} roles: {user_roles}") + info.context["user_id"] = user_id.strip() return await f(*args, **kwargs) return decorated_function + + +def auth_request(f): + @wraps(f) + async def decorated_function(*args, **kwargs): + req = args[0] + authorized = await check_auth(req) + if authorized: + user_id, user_roles = authorized + if user_id and user_roles: + logger.info(f" got {user_id} roles: {user_roles}") + req["user_id"] = user_id.strip() + author = get_author_by_user(user_id) + if author and "id" in author: + req["author_id"] = author["id"] + else: + logger.debug(author) + HTTPException(status_code=404, detail="Cannot find author profile") + return await f(*args, **kwargs) + else: + raise HTTPException(status_code=401, detail="Unauthorized") + + return decorated_function