isolated
This commit is contained in:
79
services/auth.py
Normal file
79
services/auth.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from functools import wraps
|
||||
from gql.transport import aiohttp
|
||||
import aiohttp
|
||||
import json
|
||||
from services.db import local_session
|
||||
from settings import AUTH_URL
|
||||
from orm.author import Author
|
||||
from graphql.error import GraphQLError
|
||||
|
||||
|
||||
class BaseHttpException(GraphQLError):
|
||||
code = 500
|
||||
message = "500 Server error"
|
||||
|
||||
|
||||
class Unauthorized(BaseHttpException):
|
||||
code = 401
|
||||
message = "401 Unauthorized"
|
||||
|
||||
|
||||
async def check_auth(req):
|
||||
token = req.headers.get("Authorization")
|
||||
gql = (
|
||||
{"mutation": "{ getSession { user { id } } }"}
|
||||
if "v2" in AUTH_URL
|
||||
else {"query": "{ session { user { id } } }"}
|
||||
)
|
||||
headers = {"Authorization": token, "Content-Type": "application/json"}
|
||||
async with aiohttp.ClientSession(headers=headers) as session:
|
||||
async with session.post(AUTH_URL, data=json.dumps(gql)) as response:
|
||||
if response.status != 200:
|
||||
return False, None
|
||||
r = await response.json()
|
||||
user_id = (
|
||||
r.get("data", {}).get("session", {}).get("user", {}).get("id", None)
|
||||
)
|
||||
is_authenticated = user_id is not None
|
||||
return is_authenticated, user_id
|
||||
|
||||
|
||||
def author_id_by_user_id(user_id):
|
||||
async with local_session() as session:
|
||||
author = session(Author).where(Author.user == user_id).first()
|
||||
return author.id
|
||||
|
||||
|
||||
def login_required(f):
|
||||
@wraps(f)
|
||||
async def decorated_function(*args, **kwargs):
|
||||
info = args[1]
|
||||
context = info.context
|
||||
req = context.get("request")
|
||||
is_authenticated, user_id = await check_auth(req)
|
||||
if not is_authenticated:
|
||||
raise Exception("You are not logged in")
|
||||
else:
|
||||
# Добавляем author_id в контекст
|
||||
author_id = await author_id_by_user_id(user_id)
|
||||
context["author_id"] = author_id
|
||||
|
||||
# Если пользователь аутентифицирован, выполняем резолвер
|
||||
return await f(*args, **kwargs)
|
||||
|
||||
return decorated_function
|
||||
|
||||
|
||||
def auth_request(f):
|
||||
@wraps(f)
|
||||
async def decorated_function(*args, **kwargs):
|
||||
req = args[0]
|
||||
is_authenticated, user_id = await check_auth(req)
|
||||
if not is_authenticated:
|
||||
raise Unauthorized("You are not logged in")
|
||||
else:
|
||||
author_id = await author_id_by_user_id(user_id)
|
||||
req["author_id"] = author_id
|
||||
return await f(*args, **kwargs)
|
||||
|
||||
return decorated_function
|
54
services/db.py
Normal file
54
services/db.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from typing import TypeVar, Any, Dict, Generic, Callable
|
||||
|
||||
from sqlalchemy import create_engine, Column, Integer
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql.schema import Table
|
||||
|
||||
from settings import DB_URL
|
||||
|
||||
engine = create_engine(DB_URL, echo=False, pool_size=10, max_overflow=20)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
REGISTRY: Dict[str, type] = {}
|
||||
|
||||
|
||||
def local_session():
|
||||
return Session(bind=engine, expire_on_commit=False)
|
||||
|
||||
|
||||
class Base(declarative_base()):
|
||||
__table__: Table
|
||||
__tablename__: str
|
||||
__new__: Callable
|
||||
__init__: Callable
|
||||
__allow_unmapped__ = True
|
||||
__abstract__ = True
|
||||
__table_args__ = {"extend_existing": True}
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
REGISTRY[cls.__name__] = cls
|
||||
|
||||
@classmethod
|
||||
def create(cls: Generic[T], **kwargs) -> Generic[T]:
|
||||
instance = cls(**kwargs)
|
||||
return instance.save()
|
||||
|
||||
def save(self) -> Generic[T]:
|
||||
with local_session() as session:
|
||||
session.add(self)
|
||||
session.commit()
|
||||
return self
|
||||
|
||||
def update(self, input):
|
||||
column_names = self.__table__.columns.keys()
|
||||
for name, value in input.items():
|
||||
if name in column_names:
|
||||
setattr(self, name, value)
|
||||
|
||||
def dict(self) -> Dict[str, Any]:
|
||||
column_names = self.__table__.columns.keys()
|
||||
return {c: getattr(self, c) for c in column_names}
|
56
services/redis.py
Normal file
56
services/redis.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import asyncio
|
||||
import aioredis
|
||||
from settings import REDIS_URL
|
||||
|
||||
|
||||
class RedisCache:
|
||||
def __init__(self, uri=REDIS_URL):
|
||||
self._uri: str = uri
|
||||
self.pubsub_channels = []
|
||||
self._redis = None
|
||||
|
||||
async def connect(self):
|
||||
pool = aioredis.ConnectionPool.from_url(
|
||||
self._uri, encoding="utf-8", max_connections=10
|
||||
)
|
||||
self._redis = aioredis.Redis(connection_pool=pool)
|
||||
|
||||
async def disconnect(self):
|
||||
await self._redis.wait_closed()
|
||||
self._redis = None
|
||||
|
||||
async def execute(self, command, *args, **kwargs):
|
||||
while not self._redis:
|
||||
await asyncio.sleep(1)
|
||||
try:
|
||||
print("[redis] " + command + " " + " ".join(args))
|
||||
return await self._redis.execute_command(command, *args, **kwargs)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def subscribe(self, *channels):
|
||||
if not self._redis:
|
||||
await self.connect()
|
||||
for channel in channels:
|
||||
await self._redis.execute_pubsub("SUBSCRIBE", channel)
|
||||
self.pubsub_channels.append(channel)
|
||||
|
||||
async def unsubscribe(self, *channels):
|
||||
if not self._redis:
|
||||
return
|
||||
for channel in channels:
|
||||
await self._redis.execute_pubsub("UNSUBSCRIBE", channel)
|
||||
self.pubsub_channels.remove(channel)
|
||||
|
||||
async def lrange(self, key, start, stop):
|
||||
print(f"[redis] LRANGE {key} {start} {stop}")
|
||||
return await self._redis.lrange(key, start, stop)
|
||||
|
||||
async def mget(self, key, *keys):
|
||||
print(f"[redis] MGET {key} {keys}")
|
||||
return await self._redis.mget(key, *keys)
|
||||
|
||||
|
||||
redis = RedisCache()
|
||||
|
||||
__all__ = ["redis"]
|
Reference in New Issue
Block a user