initial-draft
This commit is contained in:
64
services/auth.py
Normal file
64
services/auth.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from functools import wraps
|
||||
from httpx import AsyncClient, HTTPError
|
||||
from settings import AUTH_URL
|
||||
|
||||
|
||||
async def check_auth(req):
|
||||
token = req.headers.get("Authorization")
|
||||
headers = {"Authorization": token, "Content-Type": "application/json"} # "Bearer " + removed
|
||||
print(f"[services.auth] checking auth token: {token}")
|
||||
|
||||
query_name = "getSession" if "v2." in AUTH_URL else "session"
|
||||
query_type = "mutation" if "v2." in AUTH_URL else "query"
|
||||
operation = "GetUserId"
|
||||
|
||||
gql = {
|
||||
"query": query_type + " " + operation + " { " + query_name + " { user { id } } " + " }",
|
||||
"operationName": operation,
|
||||
"variables": None,
|
||||
}
|
||||
|
||||
async with AsyncClient(timeout=30.0) as client:
|
||||
response = await client.post(AUTH_URL, headers=headers, json=gql)
|
||||
print(f"[services.auth] {AUTH_URL} response: {response.status_code}")
|
||||
if response.status_code != 200:
|
||||
return False, None
|
||||
r = response.json()
|
||||
if r:
|
||||
user_id = r.get("data", {}).get(query_name, {}).get("user", {}).get("id", None)
|
||||
is_authenticated = user_id is not None
|
||||
return is_authenticated, user_id
|
||||
return False, None
|
||||
|
||||
|
||||
def login_required(f):
|
||||
@wraps(f)
|
||||
async def decorated_function(*args, **kwargs):
|
||||
info = args[1]
|
||||
context = info.context
|
||||
req = context.get("request")
|
||||
is_authenticated, user_id = await check_auth(req)
|
||||
if not is_authenticated:
|
||||
raise Exception("You are not logged in")
|
||||
else:
|
||||
# Добавляем author_id в контекст
|
||||
context["user_id"] = user_id
|
||||
|
||||
# Если пользователь аутентифицирован, выполняем резолвер
|
||||
return await f(*args, **kwargs)
|
||||
|
||||
return decorated_function
|
||||
|
||||
|
||||
def auth_request(f):
|
||||
@wraps(f)
|
||||
async def decorated_function(*args, **kwargs):
|
||||
req = args[0]
|
||||
is_authenticated, user_id = await check_auth(req)
|
||||
if not is_authenticated:
|
||||
raise HTTPError("please, login first")
|
||||
else:
|
||||
req["author_id"] = user_id
|
||||
return await f(*args, **kwargs)
|
||||
|
||||
return decorated_function
|
62
services/core.py
Normal file
62
services/core.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from httpx import AsyncClient
|
||||
from settings import API_BASE
|
||||
from typing import List
|
||||
from models.member import ChatMember
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
|
||||
async def get_all_authors() -> List[ChatMember]:
|
||||
query_name = "authorsAll"
|
||||
query_type = "query"
|
||||
operation = "AuthorsAll"
|
||||
query_fields = "id slug userpic name"
|
||||
|
||||
gql = {
|
||||
"query": query_type + " " + operation + " { " + query_name + " { " + query_fields + " } " + " }",
|
||||
"operationName": operation,
|
||||
"variables": None,
|
||||
}
|
||||
|
||||
async with AsyncClient() as client:
|
||||
try:
|
||||
response = await client.post(API_BASE, headers=headers, json=gql)
|
||||
print(f"[services.core] {query_name}: [{response.status_code}] {len(response.text)} bytes")
|
||||
if response.status_code != 200:
|
||||
return []
|
||||
r = response.json()
|
||||
if r:
|
||||
return r.get("data", {}).get(query_name, [])
|
||||
except Exception:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return []
|
||||
|
||||
|
||||
async def get_my_followings() -> List[ChatMember]:
|
||||
query_name = "loadMySubscriptions"
|
||||
query_type = "query"
|
||||
operation = "LoadMySubscriptions"
|
||||
query_fields = "id slug userpic name"
|
||||
|
||||
gql = {
|
||||
"query": query_type + " " + operation + " { " + query_name + " { authors {" + query_fields + "} } " + " }",
|
||||
"operationName": operation,
|
||||
"variables": None,
|
||||
}
|
||||
|
||||
async with AsyncClient() as client:
|
||||
try:
|
||||
response = await client.post(API_BASE, headers=headers, json=gql)
|
||||
print(f"[services.core] {query_name}: [{response.status_code}] {len(response.text)} bytes")
|
||||
if response.status_code != 200:
|
||||
return []
|
||||
r = response.json()
|
||||
if r:
|
||||
return r.get("data", {}).get(query_name, {}).get("authors", [])
|
||||
except Exception:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return []
|
67
services/db.py
Normal file
67
services/db.py
Normal file
@@ -0,0 +1,67 @@
|
||||
# from contextlib import contextmanager
|
||||
from typing import Any, Callable, Dict, TypeVar
|
||||
# from psycopg2.errors import UniqueViolation
|
||||
from sqlalchemy import Column, Integer, create_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql.schema import Table
|
||||
|
||||
from settings import DB_URL
|
||||
|
||||
engine = create_engine(DB_URL, echo=False, pool_size=10, max_overflow=20)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
REGISTRY: Dict[str, type] = {}
|
||||
|
||||
|
||||
# @contextmanager
|
||||
def local_session(src=""):
|
||||
return Session(bind=engine, expire_on_commit=False)
|
||||
|
||||
# try:
|
||||
# yield session
|
||||
# session.commit()
|
||||
# except Exception as e:
|
||||
# if not (src == "create_shout" and isinstance(e, UniqueViolation)):
|
||||
# import traceback
|
||||
|
||||
# session.rollback()
|
||||
# print(f"[services.db] {src}: {e}")
|
||||
|
||||
# traceback.print_exc()
|
||||
|
||||
# raise Exception("[services.db] exception")
|
||||
|
||||
# finally:
|
||||
# session.close()
|
||||
|
||||
|
||||
class Base(declarative_base()):
|
||||
__table__: Table
|
||||
__tablename__: str
|
||||
__new__: Callable
|
||||
__init__: Callable
|
||||
__allow_unmapped__ = True
|
||||
__abstract__ = True
|
||||
__table_args__ = {"extend_existing": True}
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
REGISTRY[cls.__name__] = cls
|
||||
|
||||
def dict(self) -> Dict[str, Any]:
|
||||
column_names = self.__table__.columns.keys()
|
||||
if "_sa_instance_state" in column_names:
|
||||
column_names.remove("_sa_instance_state")
|
||||
try:
|
||||
return {c: getattr(self, c) for c in column_names}
|
||||
except Exception as e:
|
||||
print(f"[services.db] Error dict: {e}")
|
||||
return {}
|
||||
|
||||
def update(self, values: Dict[str, Any]) -> None:
|
||||
for key, value in values.items():
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
61
services/rediscache.py
Normal file
61
services/rediscache.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import redis.asyncio as aredis
|
||||
from settings import REDIS_URL
|
||||
|
||||
|
||||
class RedisCache:
|
||||
def __init__(self, uri=REDIS_URL):
|
||||
self._uri: str = uri
|
||||
self.pubsub_channels = []
|
||||
self._client = None
|
||||
|
||||
async def connect(self):
|
||||
self._client = aredis.Redis.from_url(self._uri, decode_responses=True)
|
||||
|
||||
async def disconnect(self):
|
||||
if self._client:
|
||||
await self._client.close()
|
||||
|
||||
async def execute(self, command, *args, **kwargs):
|
||||
if self._client:
|
||||
try:
|
||||
print("[redis] " + command + " " + " ".join(args))
|
||||
r = await self._client.execute_command(command, *args, **kwargs)
|
||||
return r
|
||||
except Exception as e:
|
||||
print(f"[redis] error: {e}")
|
||||
return None
|
||||
|
||||
async def subscribe(self, *channels):
|
||||
if self._client:
|
||||
async with self._client.pubsub() as pubsub:
|
||||
for channel in channels:
|
||||
await pubsub.subscribe(channel)
|
||||
self.pubsub_channels.append(channel)
|
||||
|
||||
async def unsubscribe(self, *channels):
|
||||
if not self._client:
|
||||
return
|
||||
async with self._client.pubsub() as pubsub:
|
||||
for channel in channels:
|
||||
await pubsub.unsubscribe(channel)
|
||||
self.pubsub_channels.remove(channel)
|
||||
|
||||
async def publish(self, channel, data):
|
||||
if not self._client:
|
||||
return
|
||||
await self._client.publish(channel, data)
|
||||
|
||||
async def lrange(self, key, start, stop):
|
||||
if self._client:
|
||||
print(f"[redis] LRANGE {key} {start} {stop}")
|
||||
return await self._client.lrange(key, start, stop)
|
||||
|
||||
async def mget(self, key, *keys):
|
||||
if self._client:
|
||||
print(f"[redis] MGET {key} {keys}")
|
||||
return await self._client.mget(key, *keys)
|
||||
|
||||
|
||||
redis = RedisCache()
|
||||
|
||||
__all__ = ["redis"]
|
5
services/schema.py
Normal file
5
services/schema.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from ariadne import QueryType, MutationType
|
||||
|
||||
query = QueryType()
|
||||
mutation = MutationType()
|
||||
resolvers = [query, mutation]
|
Reference in New Issue
Block a user