diff --git a/Cargo.lock b/Cargo.lock index c035b32..67e3a63 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -455,7 +455,7 @@ dependencies = [ [[package]] name = "discoursio-presence" -version = "0.2.12" +version = "0.2.14" dependencies = [ "actix-web", "futures", diff --git a/Cargo.toml b/Cargo.toml index 29ed16a..72a18fe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "discoursio-presence" -version = "0.2.12" +version = "0.2.14" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/README.md b/README.md index 96d8c1e..22b4d1b 100644 --- a/README.md +++ b/README.md @@ -12,4 +12,10 @@ ### Как это работает -Сервис подписывается на Redus PubSub каналы `new_reaction`, `new_follower`, `new_shout` и `chat:` и пересылает из них те сообщения, которые предназначены пользователю, который подписался на SSE по адресу `/presence/` \ No newline at end of file +Сервис подписывается на Redus PubSub каналы + - `new_reaction`, + - `new_follower:`, + - `new_shout` + - `chat:` + + Сервис пересылает из этих каналов те сообщения, которые предназначены пользователю, который подписался на SSE по адресу `/connect` токеном авторизации в заголовке `Authorization` \ No newline at end of file diff --git a/src/data.rs b/src/data.rs index 2fd4af3..7d4f909 100644 --- a/src/data.rs +++ b/src/data.rs @@ -1,15 +1,15 @@ -use reqwest::Client as HTTPClient; use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE}; +use reqwest::Client as HTTPClient; use serde_json::json; use std::collections::HashMap; -use std::error::Error; use std::env; +use std::error::Error; pub async fn get_auth_id(token: &str) -> Result> { let auth_api_base = env::var("AUTH_URL")?; let (query_name, query_type) = match auth_api_base.contains("auth.discours.io") { + true => ("session", "query"), // authorizer _ => ("getSession", "mutation"), // v2 - true => ("session", "query") // authorizer }; let operation = "GetUserId"; let mut headers = HeaderMap::new(); @@ -22,8 +22,9 @@ pub async fn get_auth_id(token: &str) -> Result> { "variables": HashMap::::new() }); - let client = reqwest::Client::new(); - let response = client.post(&auth_api_base) + let client = HTTPClient::new(); + let response = client + .post(&auth_api_base) .headers(headers) .json(&gql) .send() @@ -31,7 +32,8 @@ pub async fn get_auth_id(token: &str) -> Result> { if response.status().is_success() { let r: HashMap = response.json().await?; - let user_id = r.get("data") + let user_id = r + .get("data") .and_then(|data| data.get(query_name)) .and_then(|query| query.get("user")) .and_then(|user| user.get("id")) @@ -41,19 +43,24 @@ pub async fn get_auth_id(token: &str) -> Result> { Some(id) => { println!("User ID retrieved: {}", id); Ok(id as i32) - }, + } None => { println!("No user ID found in the response"); - Err(Box::new(std::io::Error::new(std::io::ErrorKind::Other, "No user ID found in the response"))) + Err(Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + "No user ID found in the response", + ))) } } } else { println!("Request failed with status: {}", response.status()); - Err(Box::new(std::io::Error::new(std::io::ErrorKind::Other, format!("Request failed with status: {}", response.status())))) + Err(Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + format!("Request failed with status: {}", response.status()), + ))) } } - async fn get_shout_followers(shout_id: &str) -> Result, Box> { let api_base = env::var("API_BASE")?; let query = r#"query ShoutFollowers($shout: Int!) { @@ -75,11 +82,7 @@ async fn get_shout_followers(shout_id: &str) -> Result, Box> }); let client = reqwest::Client::new(); - let response = client - .post(&api_base) - .json(&body) - .send() - .await?; + let response = client.post(&api_base).json(&body).send().await?; if response.status().is_success() { let response_body: serde_json::Value = response.json().await?; @@ -93,50 +96,47 @@ async fn get_shout_followers(shout_id: &str) -> Result, Box> Ok(ids) } else { println!("Request failed with status: {}", response.status()); - Err(Box::new(std::io::Error::new(std::io::ErrorKind::Other, format!("Request failed with status: {}", response.status())))) + Err(Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + format!("Request failed with status: {}", response.status()), + ))) } } +pub async fn is_fitting( + listener_id: i32, + kind: String, + payload: HashMap, +) -> Result { + match kind.as_str() { + "new_follower" => { + // payload is AuthorFollower + Ok(payload.get("author").unwrap().to_string() == listener_id.to_string()) + } + "new_reaction" => { + // payload is Reaction + let shout_id = payload.get("shout").unwrap(); + let recipients = get_shout_followers(shout_id).await.unwrap(); -pub async fn is_fitting(listener_id: i32, payload: HashMap) -> Result { - match payload.get("kind") { - Some(kind) => { - match kind.as_str() { - "new_follower" => { - // payload is AuthorFollower - Ok(payload.get("author").unwrap().to_string() == listener_id.to_string()) - }, - "new_reaction" => { - // payload is Reaction - let shout_id = payload.get("shout").unwrap(); - let recipients = get_shout_followers(shout_id).await.unwrap(); - - Ok(recipients.contains(&listener_id)) - }, - "new_shout" => { - // payload is Shout - // TODO: check all community subscribers if no then - // check all topics subscribers if no then - // check all authors subscribers - Ok(true) - }, - "new_message" => { - // payload is Chat - let members_str = payload.get("members").unwrap(); - let members = serde_json::from_str::>(members_str).unwrap(); - Ok(members.contains(&listener_id.to_string())) - }, - _ => { - eprintln!("unknown payload kind"); - eprintln!("{:?}", payload); - Ok(false) - }, - } - }, - None => { - eprintln!("payload has no kind"); + Ok(recipients.contains(&listener_id)) + } + "new_shout" => { + // payload is Shout + // TODO: check all community subscribers if no then + // check all topics subscribers if no then + // check all authors subscribers + Ok(true) + } + "new_message" => { + // payload is Chat + let members_str = payload.get("members").unwrap(); + let members = serde_json::from_str::>(members_str).unwrap(); + Ok(members.contains(&listener_id.to_string())) + } + _ => { + eprintln!("unknown payload kind"); eprintln!("{:?}", payload); Ok(false) - }, + } } -} \ No newline at end of file +} diff --git a/src/main.rs b/src/main.rs index 0656489..b8e6d2f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,7 @@ use actix_web::{HttpRequest, web, App, HttpResponse, HttpServer, web::Bytes}; use actix_web::middleware::Logger; use redis::{Client, AsyncCommands}; +use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::env; use futures::StreamExt; @@ -8,27 +9,32 @@ use tokio::sync::broadcast; use actix_web::error::{ErrorUnauthorized, ErrorInternalServerError as ServerError}; use std::sync::{Arc, Mutex}; use tokio::task::JoinHandle; - - -async fn test_handler() -> Result { - Ok(HttpResponse::Ok().body("Hello, World!")) -} - - mod data; + #[derive(Clone)] struct AppState { tasks: Arc>>>, redis: Client, } + +#[derive(Serialize, Deserialize)] +struct RedisMessageData { + payload: HashMap, + kind: String +} + async fn connect_handler( req: HttpRequest, state: web::Data, ) -> Result { let token = match req.headers().get("Authorization") { - Some(val) => val.to_str().unwrap_or(""), + Some(val) => val.to_str() + .unwrap_or("") + .split(" ") + .last() + .unwrap_or(""), None => return Err(ErrorUnauthorized("Unauthorized")), }; let listener_id = data::get_auth_id(&token).await.map_err(|e| { @@ -56,9 +62,9 @@ async fn connect_handler( let handle = tokio::spawn(async move { let conn = state_clone.redis.get_async_connection().await.unwrap(); let mut pubsub = conn.into_pubsub(); - - pubsub.subscribe("new_follower").await.unwrap(); - println!("'new_follower' subscribed"); + let followers_channel = format!("new_follower:{}", listener_id); + pubsub.subscribe(followers_channel.clone()).await.unwrap(); + println!("'{}' subscribed", followers_channel); pubsub.subscribe("new_shout").await.unwrap(); println!("'new_shout' subscribed"); pubsub.subscribe("new_reaction").await.unwrap(); @@ -71,9 +77,17 @@ async fn connect_handler( } while let Some(msg) = pubsub.on_message().next().await { - let payload: HashMap = msg.get_payload().unwrap(); - if data::is_fitting(listener_id, payload.clone()).await.is_ok() { - let _ = tx.send(serde_json::to_string(&payload).unwrap()); + let message_str: String = msg.get_payload().unwrap(); + let message_data: RedisMessageData = serde_json::from_str(&message_str).unwrap(); + if data::is_fitting(listener_id, message_data.kind.to_string(), message_data.payload).await.is_ok() { + let send_result = tx.send(message_str); + if send_result.is_err() { + let _ = con.srem::<&str, &i32, usize>("authors-online", &listener_id).await.map_err(|e| { + eprintln!("Failed to remove author from online list: {}", e); + ServerError("Internal Server Error") + }); + break; + } }; } }); @@ -95,32 +109,6 @@ async fn connect_handler( } -async fn disconnect_handler( - req: HttpRequest, - state: web::Data, -) -> Result { - let token = match req.headers().get("Authorization") { - Some(val) => val.to_str().unwrap_or(""), - None => return Err(ErrorUnauthorized("Unauthorized")), - }; - let listener_id = data::get_auth_id(&token).await.map_err(|e| { - eprintln!("TOKEN check failed: {}", e); - ErrorUnauthorized("Unauthorized") - })?; - if let Some(handle) = state.tasks.lock().unwrap().remove(&format!("{}", listener_id)) { - handle.abort(); - let mut con = state.redis.get_async_connection().await.map_err(|e| { - eprintln!("Failed to get async connection: {}", e); - ServerError("Internal Server Error") - })?; - con.srem::<&str, &i32, usize>("authors-online", &listener_id).await.map_err(|e| { - eprintln!("Failed to remove author from online list: {}", e); - ServerError("Internal Server Error") - })?; - } - Ok(HttpResponse::Ok().finish()) -} - #[actix_web::main] async fn main() -> std::io::Result<()> { let redis_url = env::var("REDIS_URL").unwrap_or_else(|_| String::from("redis://127.0.0.1/")); @@ -136,7 +124,6 @@ async fn main() -> std::io::Result<()> { .wrap(Logger::default()) .app_data(web::Data::new(state.clone())) .route("/", web::get().to(connect_handler)) - .route("/test", web::post().to(test_handler)) }) .bind("0.0.0.0:8080")? .run()