diff --git a/main.py b/main.py index 2283d84..0c4b5ee 100644 --- a/main.py +++ b/main.py @@ -9,6 +9,7 @@ from sentry_sdk.integrations.strawberry import StrawberryIntegration from strawberry.asgi import GraphQL from starlette.applications import Starlette +from services.auth import TokenMiddleware from services.rediscache import redis from resolvers.listener import reactions_worker from resolvers.schema import schema @@ -49,4 +50,5 @@ async def shutdown(): app = Starlette(debug=True, on_startup=[start_up], on_shutdown=[shutdown]) +app.add_middleware(TokenMiddleware) app.mount("/", GraphQL(schema, debug=True)) diff --git a/orm/notification.py b/orm/notification.py index dedd5b4..3b829e8 100644 --- a/orm/notification.py +++ b/orm/notification.py @@ -1,8 +1,8 @@ from enum import Enum as Enumeration import time -from sqlalchemy import Boolean, Column, Enum, Integer, ForeignKey, JSON as JSONType +from sqlalchemy import Column, Enum, Integer, ForeignKey, JSON as JSONType from sqlalchemy.orm import relationship -# from sqlalchemy.dialects.postgresql import JSONB + from orm.author import Author from services.db import Base diff --git a/pyproject.toml b/pyproject.toml index 6386af2..9db6f3f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,8 @@ aiohttp = "^3.9.1" [tool.poetry.dev-dependencies] pytest = "^7.4.2" -black = { version = "^23.9.1", python = ">=3.12" } +black = { version = "^23.12.0", python = ">=3.12" } +ruff = { version = "^0.1.8", python = ">=3.12" } mypy = { version = "^1.7", python = ">=3.12" } setuptools = "^69.0.2" @@ -74,3 +75,12 @@ executionEnvironments = [] python_version = "3.12" warn_unused_configs = true plugins = ["mypy_sqlalchemy.plugin", "strawberry.ext.mypy_plugin"] + +[tool.ruff] +# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. +# Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or +# McCabe complexity (`C901`) by default. +select = ["E4", "E7", "E9", "F"] +ignore = [] +line-length = 120 +target-version = "py312" diff --git a/resolvers/schema.py b/resolvers/schema.py index 7290cab..ea6aa6f 100644 --- a/resolvers/schema.py +++ b/resolvers/schema.py @@ -2,12 +2,16 @@ from typing import List from sqlalchemy import and_, select from sqlalchemy.orm import aliased from sqlalchemy.exc import SQLAlchemyError + +from orm.author import Author from orm.notification import Notification as NotificationMessage, NotificationSeen -from services.auth import login_required +from services.auth import check_auth +from aiohttp.web import HTTPUnauthorized from services.db import local_session -from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyMapper import strawberry +from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyMapper from strawberry.schema.config import StrawberryConfig +from strawberry.extensions import Extension import logging strawberry_sqlalchemy_mapper = StrawberrySQLAlchemyMapper() @@ -70,7 +74,6 @@ def get_notifications(author_id: int, session, limit: int, offset: int) -> List[ @strawberry.type class Query: - @login_required @strawberry.field async def load_notifications(self, info, limit: int = 50, offset: int = 0) -> NotificationsResult: author_id = info.context.get("author_id") @@ -95,7 +98,6 @@ class Query: @strawberry.type class Mutation: - @login_required @strawberry.mutation async def mark_notification_as_read(self, info, notification_id: int) -> NotificationSeenResult: author_id = info.context.get("author_id") @@ -113,7 +115,6 @@ class Mutation: return NotificationSeenResult(error="cant mark as read") return NotificationSeenResult() - @login_required @strawberry.mutation async def mark_all_notifications_as_read(self, info) -> NotificationSeenResult: author_id = info.context.get("author_id") @@ -135,4 +136,23 @@ class Mutation: return NotificationSeenResult() -schema = strawberry.Schema(query=Query, mutation=Mutation, config=StrawberryConfig(auto_camel_case=False)) +class LoginRequiredMiddleware(Extension): + async def on_request_start(self): + context = self.execution_context.context + req = context.get("request") + is_authenticated, user_id = await check_auth(req) + if not is_authenticated: + raise HTTPUnauthorized(text="Please, login first") + else: + with local_session() as session: + author = session.query(Author).filter(Author.user == user_id).first() + if author: + context["author_id"] = author.id + if user_id: + context["user_id"] = user_id + context["user_id"] = user_id + + +schema = strawberry.Schema( + query=Query, mutation=Mutation, config=StrawberryConfig(auto_camel_case=False), extensions=[LoginRequiredMiddleware] +) diff --git a/services/db.py b/services/db.py index fa60760..e9b15d6 100644 --- a/services/db.py +++ b/services/db.py @@ -1,5 +1,6 @@ # from contextlib import contextmanager from typing import Any, Callable, Dict, TypeVar + # from psycopg2.errors import UniqueViolation from sqlalchemy import Column, Integer, create_engine from sqlalchemy.ext.declarative import declarative_base @@ -18,7 +19,7 @@ REGISTRY: Dict[str, type] = {} # @contextmanager def local_session(src=""): return Session(bind=engine, expire_on_commit=False) - + # try: # yield session # session.commit() @@ -60,8 +61,8 @@ class Base(declarative_base()): except Exception as e: print(f"[services.db] Error dict: {e}") return {} - + def update(self, values: Dict[str, Any]) -> None: - for key, value in values.items(): - if hasattr(self, key): - setattr(self, key, value) \ No newline at end of file + for key, value in values.items(): + if hasattr(self, key): + setattr(self, key, value) diff --git a/services/rediscache.py b/services/rediscache.py index 23ad6bb..dbe0519 100644 --- a/services/rediscache.py +++ b/services/rediscache.py @@ -55,9 +55,9 @@ class RedisCache: while True: message = await pubsub.get_message() - if message and isinstance(message['data'], (str, bytes, bytearray)): + if message and isinstance(message["data"], (str, bytes, bytearray)): try: - yield json.loads(message['data']) + yield json.loads(message["data"]) except json.JSONDecodeError as e: print(f"Error decoding JSON: {e}") await asyncio.sleep(0.1)