Merge remote-tracking branch 'origin/main'
This commit is contained in:
commit
5bb4553360
|
@ -15,62 +15,65 @@ from settings import JWT_AUTH_HEADER
|
|||
|
||||
|
||||
class _Authenticate:
|
||||
@classmethod
|
||||
async def verify(cls, token: str):
|
||||
"""
|
||||
Rules for a token to be valid.
|
||||
1. token format is legal &&
|
||||
token exists in redis database &&
|
||||
token is not expired
|
||||
2. token format is legal &&
|
||||
token exists in redis database &&
|
||||
token is expired &&
|
||||
token is of specified type
|
||||
"""
|
||||
try:
|
||||
payload = Token.decode(token)
|
||||
except ExpiredSignatureError:
|
||||
payload = Token.decode(token, verify_exp=False)
|
||||
if not await cls.exists(payload.user_id, token):
|
||||
raise InvalidToken("Login expired, please login again")
|
||||
if payload.device == "mobile": # noqa
|
||||
"we cat set mobile token to be valid forever"
|
||||
return payload
|
||||
except DecodeError as e:
|
||||
raise InvalidToken("token format error") from e
|
||||
else:
|
||||
if not await cls.exists(payload.user_id, token):
|
||||
raise InvalidToken("Login expired, please login again")
|
||||
return payload
|
||||
@classmethod
|
||||
async def verify(cls, token: str):
|
||||
"""
|
||||
Rules for a token to be valid.
|
||||
1. token format is legal &&
|
||||
token exists in redis database &&
|
||||
token is not expired
|
||||
2. token format is legal &&
|
||||
token exists in redis database &&
|
||||
token is expired &&
|
||||
token is of specified type
|
||||
"""
|
||||
try:
|
||||
payload = Token.decode(token)
|
||||
except ExpiredSignatureError:
|
||||
payload = Token.decode(token, verify_exp=False)
|
||||
if not await cls.exists(payload.user_id, token):
|
||||
raise InvalidToken("Login expired, please login again")
|
||||
if payload.device == "mobile": # noqa
|
||||
"we cat set mobile token to be valid forever"
|
||||
return payload
|
||||
except DecodeError as e:
|
||||
raise InvalidToken("token format error") from e
|
||||
else:
|
||||
if not await cls.exists(payload.user_id, token):
|
||||
raise InvalidToken("Login expired, please login again")
|
||||
return payload
|
||||
|
||||
@classmethod
|
||||
async def exists(cls, user_id, token):
|
||||
token = await redis.execute("GET", f"{user_id}-{token}")
|
||||
return token is not None
|
||||
@classmethod
|
||||
async def exists(cls, user_id, token):
|
||||
token = await redis.execute("GET", f"{user_id}-{token}")
|
||||
return token is not None
|
||||
|
||||
|
||||
class JWTAuthenticate(AuthenticationBackend):
|
||||
async def authenticate(
|
||||
self, request: HTTPConnection
|
||||
) -> Optional[Tuple[AuthCredentials, AuthUser]]:
|
||||
if JWT_AUTH_HEADER not in request.headers:
|
||||
return AuthCredentials(scopes=[]), AuthUser(user_id=None)
|
||||
async def authenticate(
|
||||
self, request: HTTPConnection
|
||||
) -> Optional[Tuple[AuthCredentials, AuthUser]]:
|
||||
if JWT_AUTH_HEADER not in request.headers:
|
||||
return AuthCredentials(scopes=[]), AuthUser(user_id=None)
|
||||
|
||||
token = request.headers[JWT_AUTH_HEADER]
|
||||
try:
|
||||
payload = await _Authenticate.verify(token)
|
||||
except Exception as exc:
|
||||
return AuthCredentials(scopes=[], error_message=str(exc)), AuthUser(user_id=None)
|
||||
token = request.headers[JWT_AUTH_HEADER]
|
||||
try:
|
||||
payload = await _Authenticate.verify(token)
|
||||
except Exception as exc:
|
||||
return AuthCredentials(scopes=[], error_message=str(exc)), AuthUser(user_id=None)
|
||||
|
||||
if payload is None:
|
||||
return AuthCredentials(scopes=[]), AuthUser(user_id=None)
|
||||
|
||||
scopes = User.get_permission(user_id=payload.user_id)
|
||||
return AuthCredentials(user_id=payload.user_id, scopes=scopes, logged_in=True), AuthUser(user_id=payload.user_id)
|
||||
scopes = User.get_permission(user_id=payload.user_id)
|
||||
return AuthCredentials(user_id=payload.user_id, scopes=scopes, logged_in=True), AuthUser(user_id=payload.user_id)
|
||||
|
||||
|
||||
def login_required(func):
|
||||
@wraps(func)
|
||||
async def wrap(parent, info: GraphQLResolveInfo, *args, **kwargs):
|
||||
auth: AuthCredentials = info.context["request"].auth
|
||||
if not auth.logged_in:
|
||||
return {"error" : auth.error_message or "Please login"}
|
||||
return await func(parent, info, *args, **kwargs)
|
||||
return wrap
|
||||
@wraps(func)
|
||||
async def wrap(parent, info: GraphQLResolveInfo, *args, **kwargs):
|
||||
auth: AuthCredentials = info.context["request"].auth
|
||||
if not auth.logged_in:
|
||||
return {"error" : auth.error_message or "Please login"}
|
||||
return await func(parent, info, *args, **kwargs)
|
||||
return wrap
|
||||
|
|
|
@ -9,7 +9,7 @@ class Permission(BaseModel):
|
|||
|
||||
class AuthCredentials(BaseModel):
|
||||
user_id: Optional[int] = None
|
||||
scopes: Optional[set] = {}
|
||||
scopes: Optional[dict] = {}
|
||||
logged_in: bool = False
|
||||
error_message: str = ""
|
||||
|
||||
|
|
20
create_crt.sh
Executable file → Normal file
20
create_crt.sh
Executable file → Normal file
|
@ -1,10 +1,10 @@
|
|||
#!/bin/bash
|
||||
|
||||
openssl req -newkey rsa:4096 \
|
||||
-x509 \
|
||||
-sha256 \
|
||||
-days 3650 \
|
||||
-nodes \
|
||||
-out discours.crt \
|
||||
-keyout discours.key \
|
||||
-subj "/C=RU/ST=Moscow/L=Moscow/O=Discours/OU=Site/CN=test-api.discours.io"
|
||||
#!/bin/bash
|
||||
|
||||
openssl req -newkey rsa:4096 \
|
||||
-x509 \
|
||||
-sha256 \
|
||||
-days 3650 \
|
||||
-nodes \
|
||||
-out discours.crt \
|
||||
-keyout discours.key \
|
||||
-subj "/C=RU/ST=Moscow/L=Moscow/O=Discours/OU=Site/CN=test-api.discours.io"
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from orm.rbac import Operation, Permission, Role
|
||||
from orm.rbac import Organization, Operation, Resource, Permission, Role
|
||||
from orm.user import User
|
||||
from orm.message import Message
|
||||
from orm.shout import Shout
|
||||
|
@ -7,3 +7,5 @@ from orm.base import Base, engine
|
|||
__all__ = ["User", "Role", "Operation", "Permission", "Message", "Shout"]
|
||||
|
||||
Base.metadata.create_all(engine)
|
||||
Operation.init_table()
|
||||
Resource.init_table()
|
||||
|
|
94
orm/rbac.py
94
orm/rbac.py
|
@ -3,63 +3,85 @@ import warnings
|
|||
from typing import Type
|
||||
|
||||
from sqlalchemy import String, Column, ForeignKey, types, UniqueConstraint
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from orm.base import Base, REGISTRY, engine
|
||||
from orm.base import Base, REGISTRY, engine, local_session
|
||||
|
||||
|
||||
class ClassType(types.TypeDecorator):
|
||||
impl = types.String
|
||||
impl = types.String
|
||||
|
||||
@property
|
||||
def python_type(self):
|
||||
return NotImplemented
|
||||
@property
|
||||
def python_type(self):
|
||||
return NotImplemented
|
||||
|
||||
def process_literal_param(self, value, dialect):
|
||||
return NotImplemented
|
||||
def process_literal_param(self, value, dialect):
|
||||
return NotImplemented
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
return value.__name__ if isinstance(value, type) else str(value)
|
||||
def process_bind_param(self, value, dialect):
|
||||
return value.__name__ if isinstance(value, type) else str(value)
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
class_ = REGISTRY.get(value)
|
||||
if class_ is None:
|
||||
warnings.warn(f"Can't find class <{value}>,find it yourself 😊", stacklevel=2)
|
||||
return class_
|
||||
def process_result_value(self, value, dialect):
|
||||
class_ = REGISTRY.get(value)
|
||||
if class_ is None:
|
||||
warnings.warn(f"Can't find class <{value}>,find it yourself 😊", stacklevel=2)
|
||||
return class_
|
||||
|
||||
class Organization(Base):
|
||||
__tablename__ = 'organization'
|
||||
name: str = Column(String, nullable=False, unique=True, comment="Organization Name")
|
||||
|
||||
class Role(Base):
|
||||
__tablename__ = 'role'
|
||||
name: str = Column(String, nullable=False, unique=True, comment="Role Name")
|
||||
__tablename__ = 'role'
|
||||
name: str = Column(String, nullable=False, unique=True, comment="Role Name")
|
||||
org_id: int = Column(ForeignKey("organization.id", ondelete="CASCADE"), nullable=False, comment="Organization")
|
||||
|
||||
permissions = relationship("Permission")
|
||||
|
||||
class Operation(Base):
|
||||
__tablename__ = 'operation'
|
||||
name: str = Column(String, nullable=False, unique=True, comment="Operation Name")
|
||||
__tablename__ = 'operation'
|
||||
name: str = Column(String, nullable=False, unique=True, comment="Operation Name")
|
||||
|
||||
@staticmethod
|
||||
def init_table():
|
||||
with local_session() as session:
|
||||
edit_op = session.query(Operation).filter(Operation.name == "edit").first()
|
||||
if not edit_op:
|
||||
edit_op = Operation.create(name = "edit")
|
||||
Operation.edit_id = edit_op.id
|
||||
|
||||
|
||||
class Resource(Base):
|
||||
__tablename__ = "resource"
|
||||
resource_class: Type[Base] = Column(ClassType, nullable=False, unique=True, comment="Resource class")
|
||||
name: str = Column(String, nullable=False, unique=True, comment="Resource name")
|
||||
__tablename__ = "resource"
|
||||
resource_class: Type[Base] = Column(ClassType, nullable=False, unique=True, comment="Resource class")
|
||||
name: str = Column(String, nullable=False, unique=True, comment="Resource name")
|
||||
|
||||
@staticmethod
|
||||
def init_table():
|
||||
with local_session() as session:
|
||||
shout_res = session.query(Resource).filter(Resource.name == "shout").first()
|
||||
if not shout_res:
|
||||
shout_res = Resource.create(name = "shout", resource_class = "shout")
|
||||
Resource.shout_id = shout_res.id
|
||||
|
||||
|
||||
class Permission(Base):
|
||||
__tablename__ = "permission"
|
||||
__table_args__ = (UniqueConstraint("role_id", "operation_id", "resource_id"), {"extend_existing": True})
|
||||
__tablename__ = "permission"
|
||||
__table_args__ = (UniqueConstraint("role_id", "operation_id", "resource_id"), {"extend_existing": True})
|
||||
|
||||
role_id: int = Column(ForeignKey("role.id", ondelete="CASCADE"), nullable=False, comment="Role")
|
||||
operation_id: int = Column(ForeignKey("operation.id", ondelete="CASCADE"), nullable=False, comment="Operation")
|
||||
resource_id: int = Column(ForeignKey("operation.id", ondelete="CASCADE"), nullable=False, comment="Resource")
|
||||
role_id: int = Column(ForeignKey("role.id", ondelete="CASCADE"), nullable=False, comment="Role")
|
||||
operation_id: int = Column(ForeignKey("operation.id", ondelete="CASCADE"), nullable=False, comment="Operation")
|
||||
resource_id: int = Column(ForeignKey("operation.id", ondelete="CASCADE"), nullable=False, comment="Resource")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
Base.metadata.create_all(engine)
|
||||
ops = [
|
||||
Permission(role_id=1, operation_id=1, resource_id=1),
|
||||
Permission(role_id=1, operation_id=2, resource_id=1),
|
||||
Permission(role_id=1, operation_id=3, resource_id=1),
|
||||
Permission(role_id=1, operation_id=4, resource_id=1),
|
||||
Permission(role_id=2, operation_id=4, resource_id=1)
|
||||
]
|
||||
global_session.add_all(ops)
|
||||
global_session.commit()
|
||||
Base.metadata.create_all(engine)
|
||||
ops = [
|
||||
Permission(role_id=1, operation_id=1, resource_id=1),
|
||||
Permission(role_id=1, operation_id=2, resource_id=1),
|
||||
Permission(role_id=1, operation_id=3, resource_id=1),
|
||||
Permission(role_id=1, operation_id=4, resource_id=1),
|
||||
Permission(role_id=2, operation_id=4, resource_id=1)
|
||||
]
|
||||
global_session.add_all(ops)
|
||||
global_session.commit()
|
||||
|
|
|
@ -12,7 +12,7 @@ class Shout(Base):
|
|||
id = None
|
||||
|
||||
slug: str = Column(String, primary_key=True)
|
||||
org: str = Column(String, nullable=False)
|
||||
org_id: str = Column(ForeignKey("organization.id"), nullable=False)
|
||||
author_id: str = Column(ForeignKey("user.id"), nullable=False, comment="Author")
|
||||
body: str = Column(String, nullable=False, comment="Body")
|
||||
createdAt: str = Column(DateTime, nullable=False, default = datetime.now, comment="Created at")
|
||||
|
|
27
orm/user.py
27
orm/user.py
|
@ -1,28 +1,41 @@
|
|||
from typing import List
|
||||
|
||||
from sqlalchemy import Column, Integer, String, ForeignKey #, relationship
|
||||
from sqlalchemy import Column, Integer, String, ForeignKey
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from orm import Permission
|
||||
from orm.base import Base, local_session
|
||||
|
||||
|
||||
class UserRole(Base):
|
||||
__tablename__ = 'user_role'
|
||||
|
||||
id = None
|
||||
user_id: int = Column(ForeignKey("user.id"), primary_key = True)
|
||||
role_id: int = Column(ForeignKey("role.id"), primary_key = True)
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
|
||||
email: str = Column(String, nullable=False)
|
||||
email: str = Column(String, unique=True, nullable=False)
|
||||
username: str = Column(String, nullable=False, comment="Name")
|
||||
password: str = Column(String, nullable=True, comment="Password")
|
||||
|
||||
role_id: list = Column(ForeignKey("role.id"), nullable=True, comment="Role")
|
||||
# roles = relationship("Role") TODO: one to many, see schema.graphql
|
||||
oauth_id: str = Column(String, nullable=True)
|
||||
|
||||
roles = relationship("Role", secondary=UserRole.__table__)
|
||||
|
||||
@classmethod
|
||||
def get_permission(cls, user_id):
|
||||
scope = {}
|
||||
with local_session() as session:
|
||||
perms: List[Permission] = session.query(Permission).join(User, User.role_id == Permission.role_id).filter(
|
||||
User.id == user_id).all()
|
||||
return {f"{p.operation_id}-{p.resource_id}" for p in perms}
|
||||
user = session.query(User).filter(User.id == user_id).first()
|
||||
for role in user.roles:
|
||||
for p in role.permissions:
|
||||
if not p.resource_id in scope:
|
||||
scope[p.resource_id] = set()
|
||||
scope[p.resource_id].add(p.operation_id)
|
||||
return scope
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -1,50 +1,49 @@
|
|||
from typing import Optional
|
||||
|
||||
import aioredis
|
||||
# from aioredis import ConnectionsPool
|
||||
|
||||
from settings import REDIS_URL
|
||||
|
||||
|
||||
class Redis:
|
||||
def __init__(self, uri=REDIS_URL):
|
||||
self._uri: str = uri
|
||||
self._instance = None
|
||||
|
||||
async def connect(self):
|
||||
if self._instance is not None:
|
||||
return
|
||||
self._instance = await aioredis.from_url(self._uri)# .create_pool(self._uri)
|
||||
|
||||
async def disconnect(self):
|
||||
if self._instance is None:
|
||||
return
|
||||
self._instance.close()
|
||||
await self._instance.wait_closed()
|
||||
self._instance = None
|
||||
|
||||
async def execute(self, command, *args, **kwargs):
|
||||
return await self._instance.execute(command, *args, **kwargs, encoding="UTF-8")
|
||||
|
||||
|
||||
async def test():
|
||||
redis = Redis()
|
||||
from datetime import datetime
|
||||
|
||||
await redis.connect()
|
||||
await redis.execute("SET", "1-KEY1", 1)
|
||||
await redis.execute("SET", "1-KEY2", 1)
|
||||
await redis.execute("SET", "1-KEY3", 1)
|
||||
await redis.execute("SET", "1-KEY4", 1)
|
||||
await redis.execute("EXPIREAT", "1-KEY4", int(datetime.utcnow().timestamp()))
|
||||
v = await redis.execute("KEYS", "1-*")
|
||||
print(v)
|
||||
await redis.execute("DEL", *v)
|
||||
v = await redis.execute("KEYS", "1-*")
|
||||
print(v)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import asyncio
|
||||
|
||||
asyncio.run(test())
|
||||
from typing import Optional
|
||||
|
||||
import aioredis
|
||||
|
||||
from settings import REDIS_URL
|
||||
|
||||
|
||||
class Redis:
|
||||
def __init__(self, uri=REDIS_URL):
|
||||
self._uri: str = uri
|
||||
self._instance = None
|
||||
|
||||
async def connect(self):
|
||||
if self._instance is not None:
|
||||
return
|
||||
self._instance = aioredis.from_url(self._uri, encoding="utf-8")
|
||||
|
||||
async def disconnect(self):
|
||||
if self._instance is None:
|
||||
return
|
||||
self._instance.close()
|
||||
await self._instance.wait_closed()
|
||||
self._instance = None
|
||||
|
||||
async def execute(self, command, *args, **kwargs):
|
||||
return await self._instance.execute_command(command, *args, **kwargs)
|
||||
|
||||
|
||||
async def test():
|
||||
redis = Redis()
|
||||
from datetime import datetime
|
||||
|
||||
await redis.connect()
|
||||
await redis.execute("SET", "1-KEY1", 1)
|
||||
await redis.execute("SET", "1-KEY2", 1)
|
||||
await redis.execute("SET", "1-KEY3", 1)
|
||||
await redis.execute("SET", "1-KEY4", 1)
|
||||
await redis.execute("EXPIREAT", "1-KEY4", int(datetime.utcnow().timestamp()))
|
||||
v = await redis.execute("KEYS", "1-*")
|
||||
print(v)
|
||||
await redis.execute("DEL", *v)
|
||||
v = await redis.execute("KEYS", "1-*")
|
||||
print(v)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import asyncio
|
||||
|
||||
asyncio.run(test())
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from orm import Shout, User
|
||||
from orm import Shout, User, Organization, Resource
|
||||
from orm.base import local_session
|
||||
|
||||
from resolvers.base import mutation, query
|
||||
|
@ -15,10 +15,10 @@ class GitTask:
|
|||
|
||||
queue = asyncio.Queue()
|
||||
|
||||
def __init__(self, input, username, user_email, comment):
|
||||
def __init__(self, input, org, username, user_email, comment):
|
||||
self.slug = input["slug"];
|
||||
self.org = input["org"];
|
||||
self.shout_body = input["body"];
|
||||
self.org = org;
|
||||
self.username = username;
|
||||
self.user_email = user_email;
|
||||
self.comment = comment;
|
||||
|
@ -84,12 +84,19 @@ async def create_shout(_, info, input):
|
|||
auth = info.context["request"].auth
|
||||
user_id = auth.user_id
|
||||
|
||||
org_id = org = input["org_id"]
|
||||
with local_session() as session:
|
||||
user = session.query(User).filter(User.id == user_id).first()
|
||||
org = session.query(Organization).filter(Organization.id == org_id).first()
|
||||
|
||||
if not org:
|
||||
return {
|
||||
"error" : "invalid organization"
|
||||
}
|
||||
|
||||
new_shout = Shout.create(
|
||||
slug = input["slug"],
|
||||
org = input["org"],
|
||||
org_id = org_id,
|
||||
author_id = user_id,
|
||||
body = input["body"],
|
||||
replyTo = input.get("replyTo"),
|
||||
|
@ -100,6 +107,7 @@ async def create_shout(_, info, input):
|
|||
|
||||
task = GitTask(
|
||||
input,
|
||||
org.name,
|
||||
user.username,
|
||||
user.email,
|
||||
"new shout %s" % (new_shout.slug)
|
||||
|
@ -109,5 +117,51 @@ async def create_shout(_, info, input):
|
|||
"shout" : new_shout
|
||||
}
|
||||
|
||||
@mutation.field("updateShout")
|
||||
@login_required
|
||||
async def update_shout(_, info, input):
|
||||
auth = info.context["request"].auth
|
||||
user_id = auth.user_id
|
||||
|
||||
slug = input["slug"]
|
||||
org_id = org = input["org_id"]
|
||||
with local_session() as session:
|
||||
user = session.query(User).filter(User.id == user_id).first()
|
||||
shout = session.query(Shout).filter(Shout.slug == slug).first()
|
||||
org = session.query(Organization).filter(Organization.id == org_id).first()
|
||||
|
||||
if not shout:
|
||||
return {
|
||||
"error" : "shout not found"
|
||||
}
|
||||
|
||||
if shout.author_id != user_id:
|
||||
scopes = auth.scopes
|
||||
print(scopes)
|
||||
if not Resource.shout_id in scopes:
|
||||
return {
|
||||
"error" : "access denied"
|
||||
}
|
||||
|
||||
shout.body = input["body"],
|
||||
shout.replyTo = input.get("replyTo"),
|
||||
shout.versionOf = input.get("versionOf"),
|
||||
shout.tags = input.get("tags"),
|
||||
shout.topics = input.get("topics")
|
||||
|
||||
with local_session() as session:
|
||||
session.commit()
|
||||
|
||||
task = GitTask(
|
||||
input,
|
||||
org.name,
|
||||
user.username,
|
||||
user.email,
|
||||
"update shout %s" % (shout.slug)
|
||||
)
|
||||
|
||||
return {
|
||||
"shout" : shout
|
||||
}
|
||||
|
||||
# TODO: paginate, get, update, delete
|
||||
|
|
|
@ -23,7 +23,7 @@ type MessageResult {
|
|||
}
|
||||
|
||||
input ShoutInput {
|
||||
org: String!
|
||||
org_id: Int!
|
||||
slug: String!
|
||||
body: String!
|
||||
replyTo: String # another shout
|
||||
|
@ -61,10 +61,11 @@ type Mutation {
|
|||
# invalidateTokenById(id: Int!): Boolean!
|
||||
# requestEmailConfirmation: User!
|
||||
# requestPasswordReset(email: String!): Boolean!
|
||||
registerUser(email: String!, password: String!): AuthResult!
|
||||
registerUser(email: String!, password: String!): AuthResult!
|
||||
|
||||
# shout
|
||||
createShout(input: ShoutInput!): ShoutResult!
|
||||
updateShout(input: ShoutInput!): ShoutResult!
|
||||
deleteShout(slug: String!): Result!
|
||||
rateShout(slug: String!, value: Int!): Result!
|
||||
|
||||
|
@ -151,7 +152,7 @@ type Message {
|
|||
|
||||
# is publication
|
||||
type Shout {
|
||||
org: String!
|
||||
org_id: Int!
|
||||
slug: String!
|
||||
author: Int!
|
||||
body: String!
|
||||
|
|
Loading…
Reference in New Issue
Block a user