refactored
This commit is contained in:
@@ -9,7 +9,7 @@ from orm.reaction import Reaction
|
||||
from storages.topics import TopicStorage
|
||||
from storages.users import UserStorage
|
||||
from storages.viewed import ViewedStorage
|
||||
from orm.base import Base, engine, local_session
|
||||
from base.orm import Base, engine, local_session
|
||||
|
||||
__all__ = ["User", "Role", "Operation", "Permission", \
|
||||
"Community", "Shout", "Topic", "TopicFollower", \
|
||||
|
56
orm/base.py
56
orm/base.py
@@ -1,56 +0,0 @@
|
||||
from typing import TypeVar, Any, Dict, Generic, Callable
|
||||
|
||||
from sqlalchemy import create_engine, Column, Integer
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql.schema import Table
|
||||
|
||||
from settings import DB_URL
|
||||
|
||||
if DB_URL.startswith('sqlite'):
|
||||
engine = create_engine(DB_URL)
|
||||
else:
|
||||
engine = create_engine(DB_URL, convert_unicode=True, echo=False, \
|
||||
pool_size=10, max_overflow=20)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
REGISTRY: Dict[str, type] = {}
|
||||
|
||||
def local_session():
|
||||
return Session(bind=engine, expire_on_commit=False)
|
||||
|
||||
|
||||
class Base(declarative_base()):
|
||||
__table__: Table
|
||||
__tablename__: str
|
||||
__new__: Callable
|
||||
__init__: Callable
|
||||
|
||||
__abstract__: bool = True
|
||||
__table_args__ = {"extend_existing": True}
|
||||
id: int = Column(Integer, primary_key=True)
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
REGISTRY[cls.__name__] = cls
|
||||
|
||||
@classmethod
|
||||
def create(cls: Generic[T], **kwargs) -> Generic[T]:
|
||||
instance = cls(**kwargs)
|
||||
return instance.save()
|
||||
|
||||
def save(self) -> Generic[T]:
|
||||
with local_session() as session:
|
||||
session.add(self)
|
||||
session.commit()
|
||||
return self
|
||||
|
||||
def update(self, input):
|
||||
column_names = self.__table__.columns.keys()
|
||||
for (name, value) in input.items():
|
||||
if name in column_names:
|
||||
setattr(self, name, value)
|
||||
|
||||
def dict(self) -> Dict[str, Any]:
|
||||
column_names = self.__table__.columns.keys()
|
||||
return {c: getattr(self, c) for c in column_names}
|
@@ -1,6 +1,6 @@
|
||||
from datetime import datetime
|
||||
from sqlalchemy import Column, String, ForeignKey, DateTime
|
||||
from orm.base import Base, local_session
|
||||
from base.orm import Base, local_session
|
||||
|
||||
class CommunityFollower(Base):
|
||||
__tablename__ = 'community_followers'
|
||||
|
@@ -1,5 +1,5 @@
|
||||
from sqlalchemy import Column, String, JSON as JSONType
|
||||
from orm.base import Base
|
||||
from base.orm import Base
|
||||
|
||||
class Notification(Base):
|
||||
__tablename__ = 'notification'
|
||||
|
@@ -1,7 +1,7 @@
|
||||
import warnings
|
||||
from sqlalchemy import String, Column, ForeignKey, UniqueConstraint, TypeDecorator
|
||||
from sqlalchemy.orm import relationship
|
||||
from orm.base import Base, REGISTRY, engine, local_session
|
||||
from base.orm import Base, REGISTRY, engine, local_session
|
||||
from orm.community import Community
|
||||
|
||||
|
||||
|
@@ -1,6 +1,6 @@
|
||||
from datetime import datetime
|
||||
from sqlalchemy import Column, String, ForeignKey, DateTime
|
||||
from orm.base import Base, local_session
|
||||
from base.orm import Base, local_session
|
||||
import enum
|
||||
from sqlalchemy import Enum
|
||||
from storages.viewed import ViewedStorage
|
||||
|
@@ -6,7 +6,7 @@ from orm.topic import Topic, ShoutTopic
|
||||
from orm.reaction import Reaction
|
||||
from storages.reactions import ReactionsStorage
|
||||
from storages.viewed import ViewedStorage
|
||||
from orm.base import Base
|
||||
from base.orm import Base
|
||||
|
||||
|
||||
class ShoutReactionsFollower(Base):
|
||||
|
@@ -1,6 +1,6 @@
|
||||
from datetime import datetime
|
||||
from sqlalchemy import Column, String, ForeignKey, DateTime, JSON as JSONType
|
||||
from orm.base import Base
|
||||
from base.orm import Base
|
||||
|
||||
class ShoutTopic(Base):
|
||||
__tablename__ = 'shout_topic'
|
||||
|
@@ -1,7 +1,7 @@
|
||||
from datetime import datetime
|
||||
from sqlalchemy import Column, Integer, String, ForeignKey, Boolean, DateTime, JSON as JSONType
|
||||
from sqlalchemy.orm import relationship
|
||||
from orm.base import Base, local_session
|
||||
from base.orm import Base, local_session
|
||||
from orm.rbac import Role
|
||||
from storages.roles import RoleStorage
|
||||
|
||||
|
Reference in New Issue
Block a user