disconnect-handler

This commit is contained in:
Tony Rewin 2023-10-06 17:57:54 +03:00
parent d58aed3648
commit 531cbb4458

View File

@ -5,19 +5,27 @@ use std::env;
use futures::StreamExt; use futures::StreamExt;
use tokio::sync::broadcast; use tokio::sync::broadcast;
use actix_web::error::{ErrorUnauthorized, ErrorInternalServerError as ServerError}; use actix_web::error::{ErrorUnauthorized, ErrorInternalServerError as ServerError};
use std::sync::{Arc, Mutex};
use tokio::task::JoinHandle;
mod data; mod data;
async fn sse_handler( #[derive(Clone)]
struct AppState {
tasks: Arc<Mutex<HashMap<String, JoinHandle<()>>>>,
redis: Client,
}
async fn connect_handler(
token: web::Path<String>, token: web::Path<String>,
redis: web::Data<Client>, state: web::Data<AppState>,
) -> Result<HttpResponse, actix_web::Error> { ) -> Result<HttpResponse, actix_web::Error> {
let listener_id = data::get_auth_id(&token).await.map_err(|e| { let listener_id = data::get_auth_id(&token).await.map_err(|e| {
eprintln!("TOKEN check failed: {}", e); eprintln!("TOKEN check failed: {}", e);
ErrorUnauthorized("Unauthorized") ErrorUnauthorized("Unauthorized")
})?; })?;
let mut con = redis.get_async_connection().await.map_err(|e| { let mut con = state.redis.get_async_connection().await.map_err(|e| {
eprintln!("Failed to get async connection: {}", e); eprintln!("Failed to get async connection: {}", e);
ServerError("Internal Server Error") ServerError("Internal Server Error")
})?; })?;
@ -33,8 +41,8 @@ async fn sse_handler(
})?; })?;
let (tx, mut rx) = broadcast::channel(100); let (tx, mut rx) = broadcast::channel(100);
let _handle = tokio::spawn(async move { let handle = tokio::spawn(async move {
let conn = redis.get_async_connection().await.unwrap(); let conn = state.redis.get_async_connection().await.unwrap();
let mut pubsub = conn.into_pubsub(); let mut pubsub = conn.into_pubsub();
pubsub.subscribe("new_follower").await.unwrap(); pubsub.subscribe("new_follower").await.unwrap();
@ -57,6 +65,10 @@ async fn sse_handler(
}; };
} }
}); });
state.tasks
.lock()
.unwrap()
.insert(format!("{}", listener_id.clone()), handle);
let server_event = rx.recv().await.map_err(|e| { let server_event = rx.recv().await.map_err(|e| {
eprintln!("Failed to receive server event: {}", e); eprintln!("Failed to receive server event: {}", e);
@ -70,17 +82,44 @@ async fn sse_handler(
.streaming(server_event_stream)) .streaming(server_event_stream))
} }
async fn disconnect_handler(
token: web::Path<String>,
state: web::Data<AppState>,
) -> Result<HttpResponse, actix_web::Error> {
let listener_id = data::get_auth_id(&token).await.map_err(|e| {
eprintln!("TOKEN check failed: {}", e);
ErrorUnauthorized("Unauthorized")
})?;
if let Some(handle) = state.tasks.lock().unwrap().remove(&format!("{}", listener_id)) {
handle.abort();
let mut con = state.redis.get_async_connection().await.map_err(|e| {
eprintln!("Failed to get async connection: {}", e);
ServerError("Internal Server Error")
})?;
con.srem::<&str, &i32, usize>("authors-online", &listener_id).await.map_err(|e| {
eprintln!("Failed to remove author from online list: {}", e);
ServerError("Internal Server Error")
})?;
}
Ok(HttpResponse::Ok().finish())
}
#[actix_web::main] #[actix_web::main]
async fn main() -> std::io::Result<()> { async fn main() -> std::io::Result<()> {
let redis_url = env::var("REDIS_URL").unwrap_or_else(|_| String::from("redis://127.0.0.1/")); let redis_url = env::var("REDIS_URL").unwrap_or_else(|_| String::from("redis://127.0.0.1/"));
let client = redis::Client::open(redis_url.clone()).unwrap(); let client = redis::Client::open(redis_url.clone()).unwrap();
let tasks = Arc::new(Mutex::new(HashMap::new()));
let state = AppState {
tasks: tasks.clone(),
redis: client.clone(),
};
println!("Connecting to Redis: {}", redis_url); println!("Connecting to Redis: {}", redis_url);
HttpServer::new(move || { HttpServer::new(move || {
App::new() App::new()
.app_data(web::Data::new(client.clone())) .app_data(web::Data::new(state.clone()))
.route("/connect", web::get().to(sse_handler)) .route("/connect", web::get().to(connect_handler))
.route("/disconnect", web::get().to(sse_handler)) .route("/disconnect", web::get().to(disconnect_handler))
}) })
.bind("127.0.0.1:8080")? .bind("127.0.0.1:8080")?
.run() .run()