initial-draft

This commit is contained in:
2023-11-24 01:58:55 +03:00
parent 7304735041
commit ec20a4ebcd
22 changed files with 996 additions and 0 deletions

64
services/auth.py Normal file
View File

@@ -0,0 +1,64 @@
from functools import wraps
from httpx import AsyncClient, HTTPError
from settings import AUTH_URL
async def check_auth(req):
token = req.headers.get("Authorization")
headers = {"Authorization": token, "Content-Type": "application/json"} # "Bearer " + removed
print(f"[services.auth] checking auth token: {token}")
query_name = "getSession" if "v2." in AUTH_URL else "session"
query_type = "mutation" if "v2." in AUTH_URL else "query"
operation = "GetUserId"
gql = {
"query": query_type + " " + operation + " { " + query_name + " { user { id } } " + " }",
"operationName": operation,
"variables": None,
}
async with AsyncClient(timeout=30.0) as client:
response = await client.post(AUTH_URL, headers=headers, json=gql)
print(f"[services.auth] {AUTH_URL} response: {response.status_code}")
if response.status_code != 200:
return False, None
r = response.json()
if r:
user_id = r.get("data", {}).get(query_name, {}).get("user", {}).get("id", None)
is_authenticated = user_id is not None
return is_authenticated, user_id
return False, None
def login_required(f):
@wraps(f)
async def decorated_function(*args, **kwargs):
info = args[1]
context = info.context
req = context.get("request")
is_authenticated, user_id = await check_auth(req)
if not is_authenticated:
raise Exception("You are not logged in")
else:
# Добавляем author_id в контекст
context["user_id"] = user_id
# Если пользователь аутентифицирован, выполняем резолвер
return await f(*args, **kwargs)
return decorated_function
def auth_request(f):
@wraps(f)
async def decorated_function(*args, **kwargs):
req = args[0]
is_authenticated, user_id = await check_auth(req)
if not is_authenticated:
raise HTTPError("please, login first")
else:
req["author_id"] = user_id
return await f(*args, **kwargs)
return decorated_function

62
services/core.py Normal file
View File

@@ -0,0 +1,62 @@
from httpx import AsyncClient
from settings import API_BASE
from typing import List
from models.member import ChatMember
headers = {"Content-Type": "application/json"}
async def get_all_authors() -> List[ChatMember]:
query_name = "authorsAll"
query_type = "query"
operation = "AuthorsAll"
query_fields = "id slug userpic name"
gql = {
"query": query_type + " " + operation + " { " + query_name + " { " + query_fields + " } " + " }",
"operationName": operation,
"variables": None,
}
async with AsyncClient() as client:
try:
response = await client.post(API_BASE, headers=headers, json=gql)
print(f"[services.core] {query_name}: [{response.status_code}] {len(response.text)} bytes")
if response.status_code != 200:
return []
r = response.json()
if r:
return r.get("data", {}).get(query_name, [])
except Exception:
import traceback
traceback.print_exc()
return []
async def get_my_followings() -> List[ChatMember]:
query_name = "loadMySubscriptions"
query_type = "query"
operation = "LoadMySubscriptions"
query_fields = "id slug userpic name"
gql = {
"query": query_type + " " + operation + " { " + query_name + " { authors {" + query_fields + "} } " + " }",
"operationName": operation,
"variables": None,
}
async with AsyncClient() as client:
try:
response = await client.post(API_BASE, headers=headers, json=gql)
print(f"[services.core] {query_name}: [{response.status_code}] {len(response.text)} bytes")
if response.status_code != 200:
return []
r = response.json()
if r:
return r.get("data", {}).get(query_name, {}).get("authors", [])
except Exception:
import traceback
traceback.print_exc()
return []

67
services/db.py Normal file
View File

@@ -0,0 +1,67 @@
# from contextlib import contextmanager
from typing import Any, Callable, Dict, TypeVar
# from psycopg2.errors import UniqueViolation
from sqlalchemy import Column, Integer, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session
from sqlalchemy.sql.schema import Table
from settings import DB_URL
engine = create_engine(DB_URL, echo=False, pool_size=10, max_overflow=20)
T = TypeVar("T")
REGISTRY: Dict[str, type] = {}
# @contextmanager
def local_session(src=""):
return Session(bind=engine, expire_on_commit=False)
# try:
# yield session
# session.commit()
# except Exception as e:
# if not (src == "create_shout" and isinstance(e, UniqueViolation)):
# import traceback
# session.rollback()
# print(f"[services.db] {src}: {e}")
# traceback.print_exc()
# raise Exception("[services.db] exception")
# finally:
# session.close()
class Base(declarative_base()):
__table__: Table
__tablename__: str
__new__: Callable
__init__: Callable
__allow_unmapped__ = True
__abstract__ = True
__table_args__ = {"extend_existing": True}
id = Column(Integer, primary_key=True)
def __init_subclass__(cls, **kwargs):
REGISTRY[cls.__name__] = cls
def dict(self) -> Dict[str, Any]:
column_names = self.__table__.columns.keys()
if "_sa_instance_state" in column_names:
column_names.remove("_sa_instance_state")
try:
return {c: getattr(self, c) for c in column_names}
except Exception as e:
print(f"[services.db] Error dict: {e}")
return {}
def update(self, values: Dict[str, Any]) -> None:
for key, value in values.items():
if hasattr(self, key):
setattr(self, key, value)

61
services/rediscache.py Normal file
View File

@@ -0,0 +1,61 @@
import redis.asyncio as aredis
from settings import REDIS_URL
class RedisCache:
def __init__(self, uri=REDIS_URL):
self._uri: str = uri
self.pubsub_channels = []
self._client = None
async def connect(self):
self._client = aredis.Redis.from_url(self._uri, decode_responses=True)
async def disconnect(self):
if self._client:
await self._client.close()
async def execute(self, command, *args, **kwargs):
if self._client:
try:
print("[redis] " + command + " " + " ".join(args))
r = await self._client.execute_command(command, *args, **kwargs)
return r
except Exception as e:
print(f"[redis] error: {e}")
return None
async def subscribe(self, *channels):
if self._client:
async with self._client.pubsub() as pubsub:
for channel in channels:
await pubsub.subscribe(channel)
self.pubsub_channels.append(channel)
async def unsubscribe(self, *channels):
if not self._client:
return
async with self._client.pubsub() as pubsub:
for channel in channels:
await pubsub.unsubscribe(channel)
self.pubsub_channels.remove(channel)
async def publish(self, channel, data):
if not self._client:
return
await self._client.publish(channel, data)
async def lrange(self, key, start, stop):
if self._client:
print(f"[redis] LRANGE {key} {start} {stop}")
return await self._client.lrange(key, start, stop)
async def mget(self, key, *keys):
if self._client:
print(f"[redis] MGET {key} {keys}")
return await self._client.mget(key, *keys)
redis = RedisCache()
__all__ = ["redis"]

5
services/schema.py Normal file
View File

@@ -0,0 +1,5 @@
from ariadne import QueryType, MutationType
query = QueryType()
mutation = MutationType()
resolvers = [query, mutation]