cards/server/src/websocket.rs
2024-08-22 17:05:26 -04:00

100 lines
3 KiB
Rust

use crate::user_handler::*;
use crate::AppState;
use crate::User;
use axum::{
extract::{
ws::{Message, WebSocket},
ConnectInfo, State, WebSocketUpgrade,
},
response::IntoResponse,
};
use futures::{SinkExt, StreamExt};
use rand::prelude::SliceRandom;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::mpsc;
/// Establish the WebSocket connection
pub async fn websocket_connection_handler(
ws: WebSocketUpgrade,
// user_agent: Option<TypedHeader<headers::UserAgent>>,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
State(state): State<Arc<AppState>>,
) -> impl IntoResponse {
tracing::debug!("New connection from {}", &addr);
ws.on_upgrade(move |socket| websocket_on_connection(socket, state, addr))
}
/// This runs right after a WebSocket connection is established
pub async fn websocket_on_connection(stream: WebSocket, state: Arc<AppState>, addr: SocketAddr) {
// Split channels to send and receive asynchronously.
let (mut sender, mut receiver) = stream.split();
// Create channel for direct messages
let (dm_tx, mut dm_rx) = mpsc::channel(30);
let mut map = HashMap::new();
map.insert(addr, dm_tx.clone());
state
.users_tx
.send(UserHandlerMessage::NewUser(
User::new(
format!(
"{} {}",
state.first_names.choose(&mut rand::thread_rng()).unwrap(),
state.last_names.choose(&mut rand::thread_rng()).unwrap(),
),
dm_tx.clone(),
),
addr,
))
.await
.unwrap();
// Subscribe to receive from global broadcast channel
let mut rx = state.broadcast_tx.subscribe();
// Send messages to this client
let mut send_task = tokio::spawn(async move {
let mut broadcast = None;
let mut dm = None;
loop {
tokio::select! {
b = rx.recv() => broadcast = Some(b.unwrap()),
d = dm_rx.recv() => dm = d,
};
if let Some(msg) = &dm {
if sender.send(Message::Text(msg.to_string())).await.is_err() {
break;
} else {
dm = Option::None;
}
} else if let Some(msg) = &broadcast {
if sender.send(Message::Text(msg.to_string())).await.is_err() {
} else {
broadcast = Option::None;
}
}
}
});
// Receive messages from this client
let mut recv_task = tokio::spawn(async move {
while let Some(Ok(message)) = receiver.next().await {
state
.messages_tx
.send((addr.clone(), message.clone()))
.await
.unwrap();
}
});
// If either task completes then abort the other
tokio::select! {
_ = (&mut send_task) => recv_task.abort(),
_ = (&mut recv_task) => send_task.abort(),
};
}