use crate::user_handler::*; use crate::AppState; use crate::User; use axum::{ extract::{ ws::{Message, WebSocket}, ConnectInfo, State, WebSocketUpgrade, }, response::IntoResponse, }; use futures::{SinkExt, StreamExt}; use rand::prelude::SliceRandom; use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Arc; use tokio::sync::mpsc; /// Establish the WebSocket connection pub async fn websocket_connection_handler( ws: WebSocketUpgrade, // user_agent: Option>, ConnectInfo(addr): ConnectInfo, State(state): State>, ) -> impl IntoResponse { ws.on_upgrade(move |socket| websocket_on_connection(socket, state, addr)) } /// This runs right after a WebSocket connection is established pub async fn websocket_on_connection(stream: WebSocket, state: Arc, addr: SocketAddr) { // Split channels to send and receive asynchronously. let (mut sender, mut receiver) = stream.split(); // Create channel for direct messages let (dm_tx, mut dm_rx) = mpsc::channel(1000000); let mut map = HashMap::new(); map.insert(addr, dm_tx.clone()); // Roll for username and re-roll if taken let mut username; loop { username = format!( "{} {}", state.first_names.choose(&mut rand::thread_rng()).unwrap(), state.last_names.choose(&mut rand::thread_rng()).unwrap(), ); if !state.reserved_names.read().unwrap().contains(&username) { break; } } let tx = state.users_tx.clone(); let msg = UserHandlerMessage::NewUser(User::new(username, dm_tx.clone()), addr); tokio::spawn(async move { tx.send(msg).await.expect("User handler is down") }); // Subscribe to receive from global broadcast channel let mut rx = state.broadcast_tx.subscribe(); // Send messages to this client let mut send_task = tokio::spawn(async move { let mut broadcast = None; let mut dm = None; loop { tokio::select! { b = rx.recv() => broadcast = Some(b.unwrap()), d = dm_rx.recv() => dm = d, }; if let Some(msg) = &dm { if sender.send(Message::Text(msg.to_string())).await.is_err() { break; } else { dm = Option::None; } } else if let Some(msg) = &broadcast { if sender.send(Message::Text(msg.to_string())).await.is_err() { } else { broadcast = Option::None; } } } }); // Receive messages from this client let mut recv_task = tokio::spawn(async move { while let Some(Ok(message)) = receiver.next().await { if let Err(e) = state .messages_tx .send((addr.clone(), message.clone())) .await { tracing::error!("Error relaying received message: {}", e) }; } }); // If either task completes then abort the other tokio::select! { _ = (&mut send_task) => recv_task.abort(), _ = (&mut recv_task) => send_task.abort(), }; }