use crate::user_handler::*; use crate::AppState; use crate::User; use axum::{ extract::{ ws::{Message, WebSocket}, ConnectInfo, State, WebSocketUpgrade, }, response::IntoResponse, }; use futures::{SinkExt, StreamExt}; use rand::prelude::SliceRandom; use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Arc; use tokio::sync::mpsc; /// Establish the WebSocket connection pub async fn websocket_connection_handler( ws: WebSocketUpgrade, // user_agent: Option>, ConnectInfo(addr): ConnectInfo, State(state): State>, ) -> impl IntoResponse { tracing::debug!("New connection from {}", &addr); ws.on_upgrade(move |socket| websocket_on_connection(socket, state, addr)) } /// This runs right after a WebSocket connection is established pub async fn websocket_on_connection(stream: WebSocket, state: Arc, addr: SocketAddr) { // Split channels to send and receive asynchronously. let (mut sender, mut receiver) = stream.split(); // Create channel for direct messages let (dm_tx, mut dm_rx) = mpsc::channel(30); let mut map = HashMap::new(); map.insert(addr, dm_tx.clone()); let _ = state .users_tx // add tx .send(UserHandlerMessage::NewUser { user: User::new( format!( "{} {}", state.first_names.choose(&mut rand::thread_rng()).unwrap(), state.last_names.choose(&mut rand::thread_rng()).unwrap(), ), dm_tx.clone(), ), addr, }) .await; // Subscribe to receive from global broadcast channel let mut rx = state.broadcast_tx.subscribe(); // Send messages to this client let mut send_task = tokio::spawn(async move { let mut broadcast = None; let mut dm = None; loop { tokio::select! { b = rx.recv() => broadcast = Some(b.unwrap()), d = dm_rx.recv() => dm = d, }; if let Some(msg) = &dm { if sender.send(Message::Text(msg.to_string())).await.is_err() { break; } else { dm = Option::None; } } else if let Some(msg) = &broadcast { if sender.send(Message::Text(msg.to_string())).await.is_err() { } else { broadcast = Option::None; } } } }); // Receive messages from this client let mut recv_task = tokio::spawn(async move { while let Some(Ok(message)) = receiver.next().await { state .messages_tx .send((addr.clone(), message.clone())) .await .unwrap(); } }); // If either task completes then abort the other tokio::select! { _ = (&mut send_task) => recv_task.abort(), _ = (&mut recv_task) => send_task.abort(), }; }