Feature/notifications (#77)
feature - notifications Co-authored-by: Igor Lobanov <igor.lobanov@onetwotrip.com>
This commit is contained in:
parent
702219769a
commit
889f802429
|
@ -36,7 +36,7 @@ class JWTCodec:
|
|||
issuer="discours"
|
||||
)
|
||||
r = TokenPayload(**payload)
|
||||
print('[auth.jwtcodec] debug token %r' % r)
|
||||
# print('[auth.jwtcodec] debug token %r' % r)
|
||||
return r
|
||||
except jwt.InvalidIssuedAtError:
|
||||
print('[auth.jwtcodec] invalid issued at: %r' % payload)
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
from ariadne import MutationType, QueryType, SubscriptionType, ScalarType
|
||||
|
||||
from ariadne import MutationType, QueryType, ScalarType
|
||||
|
||||
datetime_scalar = ScalarType("DateTime")
|
||||
|
||||
|
@ -11,5 +10,4 @@ def serialize_datetime(value):
|
|||
|
||||
query = QueryType()
|
||||
mutation = MutationType()
|
||||
subscription = SubscriptionType()
|
||||
resolvers = [query, mutation, subscription, datetime_scalar]
|
||||
resolvers = [query, mutation, datetime_scalar]
|
||||
|
|
36
main.py
36
main.py
|
@ -18,20 +18,18 @@ from base.resolvers import resolvers
|
|||
from resolvers.auth import confirm_email_handler
|
||||
from resolvers.upload import upload_handler
|
||||
from services.main import storages_init
|
||||
from services.notifications.notification_service import notification_service
|
||||
from services.stat.viewed import ViewedStorage
|
||||
from services.zine.gittask import GitTask
|
||||
from settings import DEV_SERVER_PID_FILE_NAME, SENTRY_DSN
|
||||
# from sse.transport import GraphQLSSEHandler
|
||||
from services.inbox.presence import on_connect, on_disconnect
|
||||
# from services.inbox.sse import sse_messages
|
||||
from ariadne.asgi.handlers import GraphQLTransportWSHandler
|
||||
# from services.zine.gittask import GitTask
|
||||
from settings import DEV_SERVER_PID_FILE_NAME, SENTRY_DSN, SESSION_SECRET_KEY
|
||||
from services.notifications.sse import sse_subscribe_handler
|
||||
|
||||
import_module("resolvers")
|
||||
schema = make_executable_schema(load_schema_from_path("schema.graphql"), resolvers) # type: ignore
|
||||
|
||||
middleware = [
|
||||
Middleware(AuthenticationMiddleware, backend=JWTAuthenticate()),
|
||||
Middleware(SessionMiddleware, secret_key="!secret"),
|
||||
Middleware(SessionMiddleware, secret_key=SESSION_SECRET_KEY),
|
||||
]
|
||||
|
||||
|
||||
|
@ -41,8 +39,11 @@ async def start_up():
|
|||
await storages_init()
|
||||
views_stat_task = asyncio.create_task(ViewedStorage().worker())
|
||||
print(views_stat_task)
|
||||
git_task = asyncio.create_task(GitTask.git_task_worker())
|
||||
print(git_task)
|
||||
# git_task = asyncio.create_task(GitTask.git_task_worker())
|
||||
# print(git_task)
|
||||
notification_service_task = asyncio.create_task(notification_service.worker())
|
||||
print(notification_service_task)
|
||||
|
||||
try:
|
||||
import sentry_sdk
|
||||
sentry_sdk.init(SENTRY_DSN)
|
||||
|
@ -71,7 +72,8 @@ routes = [
|
|||
Route("/oauth/{provider}", endpoint=oauth_login),
|
||||
Route("/oauth-authorize", endpoint=oauth_authorize),
|
||||
Route("/confirm/{token}", endpoint=confirm_email_handler),
|
||||
Route("/upload", endpoint=upload_handler, methods=['POST'])
|
||||
Route("/upload", endpoint=upload_handler, methods=['POST']),
|
||||
Route("/subscribe/{user_id}", endpoint=sse_subscribe_handler),
|
||||
]
|
||||
|
||||
app = Starlette(
|
||||
|
@ -83,14 +85,10 @@ app = Starlette(
|
|||
)
|
||||
app.mount("/", GraphQL(
|
||||
schema,
|
||||
debug=True,
|
||||
websocket_handler=GraphQLTransportWSHandler(
|
||||
on_connect=on_connect,
|
||||
on_disconnect=on_disconnect
|
||||
)
|
||||
debug=True
|
||||
))
|
||||
|
||||
dev_app = app = Starlette(
|
||||
dev_app = Starlette(
|
||||
debug=True,
|
||||
on_startup=[dev_start_up],
|
||||
on_shutdown=[shutdown],
|
||||
|
@ -99,9 +97,5 @@ dev_app = app = Starlette(
|
|||
)
|
||||
dev_app.mount("/", GraphQL(
|
||||
schema,
|
||||
debug=True,
|
||||
websocket_handler=GraphQLTransportWSHandler(
|
||||
on_connect=on_connect,
|
||||
on_disconnect=on_disconnect
|
||||
)
|
||||
debug=True
|
||||
))
|
||||
|
|
|
@ -7,7 +7,18 @@ from orm.shout import Shout
|
|||
from orm.topic import Topic, TopicFollower
|
||||
from orm.user import User, UserRating
|
||||
|
||||
# NOTE: keep orm module isolated
|
||||
|
||||
def init_tables():
|
||||
Base.metadata.create_all(engine)
|
||||
Operation.init_table()
|
||||
Resource.init_table()
|
||||
User.init_table()
|
||||
Community.init_table()
|
||||
Role.init_table()
|
||||
UserRating.init_table()
|
||||
Shout.init_table()
|
||||
print("[orm] tables initialized")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"User",
|
||||
|
@ -21,16 +32,5 @@ __all__ = [
|
|||
"Notification",
|
||||
"Reaction",
|
||||
"UserRating",
|
||||
"init_tables"
|
||||
]
|
||||
|
||||
|
||||
def init_tables():
|
||||
Base.metadata.create_all(engine)
|
||||
Operation.init_table()
|
||||
Resource.init_table()
|
||||
User.init_table()
|
||||
Community.init_table()
|
||||
Role.init_table()
|
||||
UserRating.init_table()
|
||||
Shout.init_table()
|
||||
print("[orm] tables initialized")
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
from datetime import datetime
|
||||
from sqlalchemy import Column, Enum, JSON, ForeignKey, DateTime, Boolean, Integer
|
||||
from sqlalchemy import Column, Enum, ForeignKey, DateTime, Boolean, Integer
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
|
||||
from base.orm import Base
|
||||
from enum import Enum as Enumeration
|
||||
|
||||
|
@ -18,5 +20,5 @@ class Notification(Base):
|
|||
createdAt = Column(DateTime, nullable=False, default=datetime.now, index=True)
|
||||
seen = Column(Boolean, nullable=False, default=False, index=True)
|
||||
type = Column(Enum(NotificationType), nullable=False)
|
||||
data = Column(JSON, nullable=True)
|
||||
data = Column(JSONB, nullable=True)
|
||||
occurrences = Column(Integer, default=1)
|
||||
|
|
|
@ -11,14 +11,12 @@ gql~=3.4.0
|
|||
uvicorn>=0.18.3
|
||||
pydantic>=1.10.2
|
||||
passlib~=1.7.4
|
||||
itsdangerous
|
||||
authlib>=1.1.0
|
||||
httpx>=0.23.0
|
||||
psycopg2-binary
|
||||
transliterate~=1.10.2
|
||||
requests~=2.28.1
|
||||
bcrypt>=4.0.0
|
||||
websockets
|
||||
bson~=0.5.10
|
||||
flake8
|
||||
DateTime~=4.7
|
||||
|
@ -38,3 +36,4 @@ python-multipart~=0.0.6
|
|||
alembic==1.11.3
|
||||
Mako==1.2.4
|
||||
MarkupSafe==2.1.3
|
||||
sse-starlette=1.6.5
|
||||
|
|
0
resetdb.sh
Normal file → Executable file
0
resetdb.sh
Normal file → Executable file
|
@ -55,7 +55,6 @@ from resolvers.inbox.messages import (
|
|||
create_message,
|
||||
delete_message,
|
||||
update_message,
|
||||
message_generator,
|
||||
mark_as_read
|
||||
)
|
||||
from resolvers.inbox.load import (
|
||||
|
@ -65,56 +64,4 @@ from resolvers.inbox.load import (
|
|||
)
|
||||
from resolvers.inbox.search import search_recipients
|
||||
|
||||
__all__ = [
|
||||
# auth
|
||||
"login",
|
||||
"register_by_email",
|
||||
"is_email_used",
|
||||
"confirm_email",
|
||||
"auth_send_link",
|
||||
"sign_out",
|
||||
"get_current_user",
|
||||
# zine.profile
|
||||
"load_authors_by",
|
||||
"rate_user",
|
||||
"update_profile",
|
||||
"get_authors_all",
|
||||
# zine.load
|
||||
"load_shout",
|
||||
"load_shouts_by",
|
||||
# zine.following
|
||||
"follow",
|
||||
"unfollow",
|
||||
# create
|
||||
"create_shout",
|
||||
"update_shout",
|
||||
"delete_shout",
|
||||
"markdown_body",
|
||||
# zine.topics
|
||||
"topics_all",
|
||||
"topics_by_community",
|
||||
"topics_by_author",
|
||||
"topic_follow",
|
||||
"topic_unfollow",
|
||||
"get_topic",
|
||||
# zine.reactions
|
||||
"reactions_follow",
|
||||
"reactions_unfollow",
|
||||
"create_reaction",
|
||||
"update_reaction",
|
||||
"delete_reaction",
|
||||
"load_reactions_by",
|
||||
# inbox
|
||||
"load_chats",
|
||||
"load_messages_by",
|
||||
"create_chat",
|
||||
"delete_chat",
|
||||
"update_chat",
|
||||
"create_message",
|
||||
"delete_message",
|
||||
"update_message",
|
||||
"message_generator",
|
||||
"mark_as_read",
|
||||
"load_recipients",
|
||||
"search_recipients"
|
||||
]
|
||||
from resolvers.notifications import load_notifications
|
||||
|
|
|
@ -6,7 +6,7 @@ from graphql.type import GraphQLResolveInfo
|
|||
from auth.authenticate import login_required
|
||||
from auth.credentials import AuthCredentials
|
||||
from base.redis import redis
|
||||
from base.resolvers import mutation, subscription
|
||||
from base.resolvers import mutation
|
||||
from services.following import FollowingManager, FollowingResult, Following
|
||||
from validations.inbox import Message
|
||||
|
||||
|
@ -140,40 +140,3 @@ async def mark_as_read(_, info, chat_id: str, messages: [int]):
|
|||
return {
|
||||
"error": None
|
||||
}
|
||||
|
||||
|
||||
@subscription.source("newMessage")
|
||||
async def message_generator(_, info: GraphQLResolveInfo):
|
||||
print(f"[resolvers.messages] generator {info}")
|
||||
auth: AuthCredentials = info.context["request"].auth
|
||||
user_id = auth.user_id
|
||||
try:
|
||||
user_following_chats = await redis.execute("GET", f"chats_by_user/{user_id}")
|
||||
if user_following_chats:
|
||||
user_following_chats = list(json.loads(user_following_chats)) # chat ids
|
||||
else:
|
||||
user_following_chats = []
|
||||
tasks = []
|
||||
updated = {}
|
||||
for chat_id in user_following_chats:
|
||||
chat = await redis.execute("GET", f"chats/{chat_id}")
|
||||
updated[chat_id] = chat['updatedAt']
|
||||
user_following_chats_sorted = sorted(user_following_chats, key=lambda x: updated[x], reverse=True)
|
||||
|
||||
for chat_id in user_following_chats_sorted:
|
||||
following_chat = Following('chat', chat_id)
|
||||
await FollowingManager.register('chat', following_chat)
|
||||
chat_task = following_chat.queue.get()
|
||||
tasks.append(chat_task)
|
||||
|
||||
while True:
|
||||
msg = await asyncio.gather(*tasks)
|
||||
yield msg
|
||||
finally:
|
||||
await FollowingManager.remove('chat', following_chat)
|
||||
|
||||
|
||||
@subscription.field("newMessage")
|
||||
@login_required
|
||||
async def message_resolver(message: Message, info: Any):
|
||||
return message
|
||||
|
|
84
resolvers/notifications.py
Normal file
84
resolvers/notifications.py
Normal file
|
@ -0,0 +1,84 @@
|
|||
from sqlalchemy import select, desc, and_, update
|
||||
|
||||
from auth.credentials import AuthCredentials
|
||||
from base.resolvers import query, mutation
|
||||
from auth.authenticate import login_required
|
||||
from base.orm import local_session
|
||||
from orm import Notification
|
||||
|
||||
|
||||
@query.field("loadNotifications")
|
||||
@login_required
|
||||
async def load_notifications(_, info, params=None):
|
||||
if params is None:
|
||||
params = {}
|
||||
|
||||
auth: AuthCredentials = info.context["request"].auth
|
||||
user_id = auth.user_id
|
||||
|
||||
limit = params.get('limit', 50)
|
||||
offset = params.get('offset', 0)
|
||||
|
||||
q = select(Notification).where(
|
||||
Notification.user == user_id
|
||||
).order_by(desc(Notification.createdAt)).limit(limit).offset(offset)
|
||||
|
||||
with local_session() as session:
|
||||
total_count = session.query(Notification).where(
|
||||
Notification.user == user_id
|
||||
).count()
|
||||
|
||||
total_unread_count = session.query(Notification).where(
|
||||
and_(
|
||||
Notification.user == user_id,
|
||||
Notification.seen is False
|
||||
)
|
||||
).count()
|
||||
|
||||
notifications = session.execute(q).fetchall()
|
||||
|
||||
return {
|
||||
"notifications": notifications,
|
||||
"totalCount": total_count,
|
||||
"totalUnreadCount": total_unread_count
|
||||
}
|
||||
|
||||
|
||||
@mutation.field("markNotificationAsRead")
|
||||
@login_required
|
||||
async def mark_notification_as_read(_, info, notification_id: int):
|
||||
auth: AuthCredentials = info.context["request"].auth
|
||||
user_id = auth.user_id
|
||||
|
||||
with local_session() as session:
|
||||
notification = session.query(Notification).where(
|
||||
and_(Notification.id == notification_id, Notification.user == user_id)
|
||||
).one()
|
||||
notification.seen = True
|
||||
session.commit()
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
@mutation.field("markAllNotificationsAsRead")
|
||||
@login_required
|
||||
async def mark_all_notifications_as_read(_, info):
|
||||
auth: AuthCredentials = info.context["request"].auth
|
||||
user_id = auth.user_id
|
||||
|
||||
statement = update(Notification).where(
|
||||
and_(
|
||||
Notification.user == user_id,
|
||||
Notification.seen == False
|
||||
)
|
||||
).values(seen=True)
|
||||
|
||||
with local_session() as session:
|
||||
try:
|
||||
session.execute(statement)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
print(f"[mark_all_notifications_as_read] error: {str(e)}")
|
||||
|
||||
return {}
|
|
@ -1,6 +1,6 @@
|
|||
import asyncio
|
||||
from base.orm import local_session
|
||||
from base.resolvers import mutation, subscription
|
||||
from base.resolvers import mutation
|
||||
from auth.authenticate import login_required
|
||||
from auth.credentials import AuthCredentials
|
||||
# from resolvers.community import community_follow, community_unfollow
|
||||
|
@ -69,79 +69,3 @@ async def unfollow(_, info, what, slug):
|
|||
return {"error": str(e)}
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
# by author and by topic
|
||||
@subscription.source("newShout")
|
||||
@login_required
|
||||
async def shout_generator(_, info: GraphQLResolveInfo):
|
||||
print(f"[resolvers.zine] shouts generator {info}")
|
||||
auth: AuthCredentials = info.context["request"].auth
|
||||
user_id = auth.user_id
|
||||
try:
|
||||
tasks = []
|
||||
|
||||
with local_session() as session:
|
||||
|
||||
# notify new shout by followed authors
|
||||
following_topics = session.query(TopicFollower).where(TopicFollower.follower == user_id).all()
|
||||
|
||||
for topic_id in following_topics:
|
||||
following_topic = Following('topic', topic_id)
|
||||
await FollowingManager.register('topic', following_topic)
|
||||
following_topic_task = following_topic.queue.get()
|
||||
tasks.append(following_topic_task)
|
||||
|
||||
# by followed topics
|
||||
following_authors = session.query(AuthorFollower).where(
|
||||
AuthorFollower.follower == user_id).all()
|
||||
|
||||
for author_id in following_authors:
|
||||
following_author = Following('author', author_id)
|
||||
await FollowingManager.register('author', following_author)
|
||||
following_author_task = following_author.queue.get()
|
||||
tasks.append(following_author_task)
|
||||
|
||||
# TODO: use communities
|
||||
# by followed communities
|
||||
# following_communities = session.query(CommunityFollower).where(
|
||||
# CommunityFollower.follower == user_id).all()
|
||||
|
||||
# for community_id in following_communities:
|
||||
# following_community = Following('community', author_id)
|
||||
# await FollowingManager.register('community', following_community)
|
||||
# following_community_task = following_community.queue.get()
|
||||
# tasks.append(following_community_task)
|
||||
|
||||
while True:
|
||||
shout = await asyncio.gather(*tasks)
|
||||
yield shout
|
||||
finally:
|
||||
pass
|
||||
|
||||
|
||||
@subscription.source("newReaction")
|
||||
@login_required
|
||||
async def reaction_generator(_, info):
|
||||
print(f"[resolvers.zine] reactions generator {info}")
|
||||
auth: AuthCredentials = info.context["request"].auth
|
||||
user_id = auth.user_id
|
||||
try:
|
||||
with local_session() as session:
|
||||
followings = session.query(ShoutReactionsFollower.shout).where(
|
||||
ShoutReactionsFollower.follower == user_id).unique()
|
||||
|
||||
# notify new reaction
|
||||
|
||||
tasks = []
|
||||
for shout_id in followings:
|
||||
following_shout = Following('shout', shout_id)
|
||||
await FollowingManager.register('shout', following_shout)
|
||||
following_author_task = following_shout.queue.get()
|
||||
tasks.append(following_author_task)
|
||||
|
||||
while True:
|
||||
reaction = await asyncio.gather(*tasks)
|
||||
yield reaction
|
||||
finally:
|
||||
pass
|
||||
|
|
|
@ -183,6 +183,7 @@ async def load_shouts_by(_, info, options):
|
|||
|
||||
|
||||
@query.field("loadDrafts")
|
||||
@login_required
|
||||
async def get_drafts(_, info):
|
||||
auth: AuthCredentials = info.context["request"].auth
|
||||
user_id = auth.user_id
|
||||
|
|
|
@ -10,6 +10,7 @@ from base.resolvers import mutation, query
|
|||
from orm.reaction import Reaction, ReactionKind
|
||||
from orm.shout import Shout, ShoutReactionsFollower
|
||||
from orm.user import User
|
||||
from services.notifications.notification_service import notification_service
|
||||
|
||||
|
||||
def add_reaction_stat_columns(q):
|
||||
|
@ -198,29 +199,32 @@ async def create_reaction(_, info, reaction):
|
|||
|
||||
r = Reaction.create(**reaction)
|
||||
|
||||
# Proposal accepting logix
|
||||
if r.replyTo is not None and \
|
||||
r.kind == ReactionKind.ACCEPT and \
|
||||
auth.user_id in shout.dict()['authors']:
|
||||
replied_reaction = session.query(Reaction).where(Reaction.id == r.replyTo).first()
|
||||
if replied_reaction and replied_reaction.kind == ReactionKind.PROPOSE:
|
||||
if replied_reaction.range:
|
||||
old_body = shout.body
|
||||
start, end = replied_reaction.range.split(':')
|
||||
start = int(start)
|
||||
end = int(end)
|
||||
new_body = old_body[:start] + replied_reaction.body + old_body[end:]
|
||||
shout.body = new_body
|
||||
# TODO: update git version control
|
||||
# # Proposal accepting logix
|
||||
# FIXME: will break if there will be 2 proposals, will break if shout will be changed
|
||||
# if r.replyTo is not None and \
|
||||
# r.kind == ReactionKind.ACCEPT and \
|
||||
# auth.user_id in shout.dict()['authors']:
|
||||
# replied_reaction = session.query(Reaction).where(Reaction.id == r.replyTo).first()
|
||||
# if replied_reaction and replied_reaction.kind == ReactionKind.PROPOSE:
|
||||
# if replied_reaction.range:
|
||||
# old_body = shout.body
|
||||
# start, end = replied_reaction.range.split(':')
|
||||
# start = int(start)
|
||||
# end = int(end)
|
||||
# new_body = old_body[:start] + replied_reaction.body + old_body[end:]
|
||||
# shout.body = new_body
|
||||
# # TODO: update git version control
|
||||
|
||||
session.add(r)
|
||||
session.commit()
|
||||
|
||||
await notification_service.handle_new_reaction(r.id)
|
||||
|
||||
rdict = r.dict()
|
||||
rdict['shout'] = shout.dict()
|
||||
rdict['createdBy'] = author.dict()
|
||||
|
||||
# self-regulation mechanics
|
||||
|
||||
if check_to_hide(session, auth.user_id, r):
|
||||
set_hidden(session, r.shout)
|
||||
elif check_to_publish(session, auth.user_id, r):
|
||||
|
|
|
@ -179,7 +179,6 @@ type Mutation {
|
|||
|
||||
# user profile
|
||||
rateUser(slug: String!, value: Int!): Result!
|
||||
updateOnlineStatus: Result!
|
||||
updateProfile(profile: ProfileInput!): Result!
|
||||
|
||||
# topics
|
||||
|
@ -196,6 +195,9 @@ type Mutation {
|
|||
# following
|
||||
follow(what: FollowingEntity!, slug: String!): Result!
|
||||
unfollow(what: FollowingEntity!, slug: String!): Result!
|
||||
|
||||
markNotificationAsRead(notification_id: Int!): Result!
|
||||
markAllNotificationsAsRead: Result!
|
||||
}
|
||||
|
||||
input MessagesBy {
|
||||
|
@ -249,7 +251,17 @@ input ReactionBy {
|
|||
days: Int # before
|
||||
sort: String # how to sort, default createdAt
|
||||
}
|
||||
################################### Query
|
||||
|
||||
input NotificationsQueryParams {
|
||||
limit: Int
|
||||
offset: Int
|
||||
}
|
||||
|
||||
type NotificationsQueryResult {
|
||||
notifications: [Notification]!
|
||||
totalCount: Int!
|
||||
totalUnreadCount: Int!
|
||||
}
|
||||
|
||||
type Query {
|
||||
# inbox
|
||||
|
@ -286,14 +298,8 @@ type Query {
|
|||
topicsRandom(amount: Int): [Topic]!
|
||||
topicsByCommunity(community: String!): [Topic]!
|
||||
topicsByAuthor(author: String!): [Topic]!
|
||||
}
|
||||
|
||||
############################################ Subscription
|
||||
|
||||
type Subscription {
|
||||
newMessage: Message # new messages in inbox
|
||||
newShout: Shout # personal feed new shout
|
||||
newReaction: Reaction # new reactions to notify
|
||||
loadNotifications(params: NotificationsQueryParams!): NotificationsQueryResult!
|
||||
}
|
||||
|
||||
############################################ Entities
|
||||
|
|
|
@ -55,7 +55,7 @@ log_settings = {
|
|||
|
||||
local_headers = [
|
||||
("Access-Control-Allow-Methods", "GET, POST, OPTIONS, HEAD"),
|
||||
("Access-Control-Allow-Origin", "http://localhost:3000"),
|
||||
("Access-Control-Allow-Origin", "https://localhost:3000"),
|
||||
(
|
||||
"Access-Control-Allow-Headers",
|
||||
"DNT,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Range,Authorization",
|
||||
|
|
|
@ -1,46 +0,0 @@
|
|||
# from base.exceptions import Unauthorized
|
||||
from auth.tokenstorage import SessionToken
|
||||
from base.redis import redis
|
||||
|
||||
|
||||
async def set_online_status(user_id, status):
|
||||
if user_id:
|
||||
if status:
|
||||
await redis.execute("SADD", "users-online", user_id)
|
||||
else:
|
||||
await redis.execute("SREM", "users-online", user_id)
|
||||
|
||||
|
||||
async def on_connect(req, params):
|
||||
if not isinstance(params, dict):
|
||||
req.scope["connection_params"] = {}
|
||||
return
|
||||
token = params.get('token')
|
||||
if not token:
|
||||
# raise Unauthorized("Please login")
|
||||
return {
|
||||
"error": "Please login first"
|
||||
}
|
||||
else:
|
||||
payload = await SessionToken.verify(token)
|
||||
if payload and payload.user_id:
|
||||
req.scope["user_id"] = payload.user_id
|
||||
await set_online_status(payload.user_id, True)
|
||||
|
||||
|
||||
async def on_disconnect(req):
|
||||
user_id = req.scope.get("user_id")
|
||||
await set_online_status(user_id, False)
|
||||
|
||||
|
||||
# FIXME: not used yet
|
||||
def context_value(request):
|
||||
context = {}
|
||||
print(f"[inbox.presense] request debug: {request}")
|
||||
if request.scope["type"] == "websocket":
|
||||
# request is an instance of WebSocket
|
||||
context.update(request.scope["connection_params"])
|
||||
else:
|
||||
context["token"] = request.META.get("authorization")
|
||||
|
||||
return context
|
|
@ -1,22 +0,0 @@
|
|||
from sse_starlette.sse import EventSourceResponse
|
||||
from starlette.requests import Request
|
||||
from graphql.type import GraphQLResolveInfo
|
||||
from resolvers.inbox.messages import message_generator
|
||||
# from base.exceptions import Unauthorized
|
||||
|
||||
# https://github.com/enisdenjo/graphql-sse/blob/master/PROTOCOL.md
|
||||
|
||||
|
||||
async def sse_messages(request: Request):
|
||||
print(f'[SSE] request\n{request}\n')
|
||||
info = GraphQLResolveInfo()
|
||||
info.context['request'] = request.scope
|
||||
user_id = request.scope['user'].user_id
|
||||
if user_id:
|
||||
event_generator = await message_generator(None, info)
|
||||
return EventSourceResponse(event_generator)
|
||||
else:
|
||||
# raise Unauthorized("Please login")
|
||||
return {
|
||||
"error": "Please login first"
|
||||
}
|
137
services/notifications/notification_service.py
Normal file
137
services/notifications/notification_service.py
Normal file
|
@ -0,0 +1,137 @@
|
|||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import and_
|
||||
|
||||
from base.orm import local_session
|
||||
from orm import Reaction, Shout, Notification, User
|
||||
from orm.notification import NotificationType
|
||||
from orm.reaction import ReactionKind
|
||||
from services.notifications.sse import connection_manager
|
||||
|
||||
|
||||
def update_prev_notification(notification, user):
|
||||
notification_data = json.loads(notification.data)
|
||||
|
||||
notification_data["users"] = [
|
||||
user for user in notification_data["users"] if user['id'] != user.id
|
||||
]
|
||||
notification_data["users"].append({
|
||||
"id": user.id,
|
||||
"name": user.name
|
||||
})
|
||||
|
||||
notification.data = json.dumps(notification_data, ensure_ascii=False)
|
||||
notification.seen = False
|
||||
notification.occurrences = notification.occurrences + 1
|
||||
notification.createdAt = datetime.now(tz=timezone.utc)
|
||||
|
||||
|
||||
class NewReactionNotificator:
|
||||
def __init__(self, reaction_id):
|
||||
self.reaction_id = reaction_id
|
||||
|
||||
async def run(self):
|
||||
with local_session() as session:
|
||||
reaction = session.query(Reaction).where(Reaction.id == self.reaction_id).one()
|
||||
shout = session.query(Shout).where(Shout.id == reaction.shout).one()
|
||||
user = session.query(User).where(User.id == reaction.createdBy).one()
|
||||
notify_user_ids = []
|
||||
|
||||
if reaction.kind == ReactionKind.COMMENT:
|
||||
parent_reaction = None
|
||||
if reaction.replyTo:
|
||||
parent_reaction = session.query(Reaction).where(Reaction.id == reaction.replyTo).one()
|
||||
if parent_reaction.createdBy != reaction.createdBy:
|
||||
prev_new_reply_notification = session.query(Notification).where(
|
||||
and_(
|
||||
Notification.user == shout.createdBy,
|
||||
Notification.type == NotificationType.NEW_REPLY,
|
||||
Notification.shout == shout.id,
|
||||
Notification.reaction == parent_reaction.id
|
||||
)
|
||||
).first()
|
||||
|
||||
if prev_new_reply_notification:
|
||||
update_prev_notification(prev_new_reply_notification, user)
|
||||
else:
|
||||
reply_notification_data = json.dumps({
|
||||
"shout": {
|
||||
"title": shout.title
|
||||
},
|
||||
"users": [
|
||||
{"id": user.id, "name": user.name}
|
||||
]
|
||||
}, ensure_ascii=False)
|
||||
|
||||
reply_notification = Notification.create(**{
|
||||
"user": parent_reaction.createdBy,
|
||||
"type": NotificationType.NEW_REPLY.name,
|
||||
"shout": shout.id,
|
||||
"reaction": parent_reaction.id,
|
||||
"data": reply_notification_data
|
||||
})
|
||||
|
||||
session.add(reply_notification)
|
||||
|
||||
notify_user_ids.append(parent_reaction.createdBy)
|
||||
|
||||
if reaction.createdBy != shout.createdBy and (
|
||||
parent_reaction is None or parent_reaction.createdBy != shout.createdBy
|
||||
):
|
||||
prev_new_comment_notification = session.query(Notification).where(
|
||||
and_(
|
||||
Notification.user == shout.createdBy,
|
||||
Notification.type == NotificationType.NEW_COMMENT,
|
||||
Notification.shout == shout.id
|
||||
)
|
||||
).first()
|
||||
|
||||
if prev_new_comment_notification:
|
||||
update_prev_notification(prev_new_comment_notification, user)
|
||||
else:
|
||||
notification_data_string = json.dumps({
|
||||
"shout": {
|
||||
"title": shout.title
|
||||
},
|
||||
"users": [
|
||||
{"id": user.id, "name": user.name}
|
||||
]
|
||||
}, ensure_ascii=False)
|
||||
|
||||
author_notification = Notification.create(**{
|
||||
"user": shout.createdBy,
|
||||
"type": NotificationType.NEW_COMMENT.name,
|
||||
"shout": shout.id,
|
||||
"data": notification_data_string
|
||||
})
|
||||
|
||||
session.add(author_notification)
|
||||
|
||||
notify_user_ids.append(shout.createdBy)
|
||||
|
||||
session.commit()
|
||||
|
||||
for user_id in notify_user_ids:
|
||||
await connection_manager.notify_user(user_id)
|
||||
|
||||
|
||||
class NotificationService:
|
||||
def __init__(self):
|
||||
self._queue = asyncio.Queue()
|
||||
|
||||
async def handle_new_reaction(self, reaction_id):
|
||||
notificator = NewReactionNotificator(reaction_id)
|
||||
await self._queue.put(notificator)
|
||||
|
||||
async def worker(self):
|
||||
while True:
|
||||
notificator = await self._queue.get()
|
||||
try:
|
||||
await notificator.run()
|
||||
except Exception as e:
|
||||
print(f'[NotificationService.worker] error: {str(e)}')
|
||||
|
||||
|
||||
notification_service = NotificationService()
|
72
services/notifications/sse.py
Normal file
72
services/notifications/sse.py
Normal file
|
@ -0,0 +1,72 @@
|
|||
import json
|
||||
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
from starlette.requests import Request
|
||||
import asyncio
|
||||
|
||||
|
||||
class ConnectionManager:
|
||||
def __init__(self):
|
||||
self.connections_by_user_id = {}
|
||||
|
||||
def add_connection(self, user_id, connection):
|
||||
if user_id not in self.connections_by_user_id:
|
||||
self.connections_by_user_id[user_id] = []
|
||||
self.connections_by_user_id[user_id].append(connection)
|
||||
|
||||
def remove_connection(self, user_id, connection):
|
||||
if user_id not in self.connections_by_user_id:
|
||||
return
|
||||
|
||||
self.connections_by_user_id[user_id].remove(connection)
|
||||
|
||||
if len(self.connections_by_user_id[user_id]) == 0:
|
||||
del self.connections_by_user_id[user_id]
|
||||
|
||||
async def notify_user(self, user_id):
|
||||
if user_id not in self.connections_by_user_id:
|
||||
return
|
||||
|
||||
for connection in self.connections_by_user_id[user_id]:
|
||||
data = {
|
||||
"type": "newNotifications"
|
||||
}
|
||||
data_string = json.dumps(data, ensure_ascii=False)
|
||||
await connection.put(data_string)
|
||||
|
||||
async def broadcast(self, data: str):
|
||||
for user_id in self.connections_by_user_id:
|
||||
for connection in self.connections_by_user_id[user_id]:
|
||||
await connection.put(data)
|
||||
|
||||
|
||||
class Connection:
|
||||
def __init__(self):
|
||||
self._queue = asyncio.Queue()
|
||||
|
||||
async def put(self, data: str):
|
||||
await self._queue.put(data)
|
||||
|
||||
async def listen(self):
|
||||
data = await self._queue.get()
|
||||
return data
|
||||
|
||||
|
||||
connection_manager = ConnectionManager()
|
||||
|
||||
|
||||
async def sse_subscribe_handler(request: Request):
|
||||
user_id = int(request.path_params["user_id"])
|
||||
connection = Connection()
|
||||
connection_manager.add_connection(user_id, connection)
|
||||
|
||||
async def event_publisher():
|
||||
try:
|
||||
while True:
|
||||
data = await connection.listen()
|
||||
yield data
|
||||
except asyncio.CancelledError as e:
|
||||
connection_manager.remove_connection(user_id, connection)
|
||||
raise e
|
||||
|
||||
return EventSourceResponse(event_publisher())
|
|
@ -27,6 +27,7 @@ SHOUTS_REPO = "content"
|
|||
SESSION_TOKEN_HEADER = "Authorization"
|
||||
|
||||
SENTRY_DSN = environ.get("SENTRY_DSN")
|
||||
SESSION_SECRET_KEY = environ.get("SESSION_SECRET_KEY") or "!secret"
|
||||
|
||||
# for local development
|
||||
DEV_SERVER_PID_FILE_NAME = 'dev-server.pid'
|
||||
|
|
43
test/test.json
Normal file
43
test/test.json
Normal file
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user