104 lines
3.2 KiB
Rust
104 lines
3.2 KiB
Rust
use crate::user_handler::*;
|
|
use crate::AppState;
|
|
use crate::User;
|
|
use axum::{
|
|
extract::{
|
|
ws::{Message, WebSocket},
|
|
ConnectInfo, State, WebSocketUpgrade,
|
|
},
|
|
response::IntoResponse,
|
|
};
|
|
use futures::{SinkExt, StreamExt};
|
|
use rand::prelude::SliceRandom;
|
|
use std::collections::HashMap;
|
|
use std::net::SocketAddr;
|
|
use std::sync::Arc;
|
|
use tokio::sync::mpsc;
|
|
|
|
/// Establish the WebSocket connection
|
|
pub async fn websocket_connection_handler(
|
|
ws: WebSocketUpgrade,
|
|
// user_agent: Option<TypedHeader<headers::UserAgent>>,
|
|
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
|
State(state): State<Arc<AppState>>,
|
|
) -> impl IntoResponse {
|
|
ws.on_upgrade(move |socket| websocket_on_connection(socket, state, addr))
|
|
}
|
|
|
|
/// This runs right after a WebSocket connection is established
|
|
pub async fn websocket_on_connection(stream: WebSocket, state: Arc<AppState>, addr: SocketAddr) {
|
|
// Split channels to send and receive asynchronously.
|
|
let (mut sender, mut receiver) = stream.split();
|
|
|
|
// Create channel for direct messages
|
|
let (dm_tx, mut dm_rx) = mpsc::channel(1000000);
|
|
|
|
let mut map = HashMap::new();
|
|
map.insert(addr, dm_tx.clone());
|
|
|
|
// Roll for username and re-roll if taken
|
|
let mut username;
|
|
|
|
loop {
|
|
username = format!(
|
|
"{} {}",
|
|
state.first_names.choose(&mut rand::thread_rng()).unwrap(),
|
|
state.last_names.choose(&mut rand::thread_rng()).unwrap(),
|
|
);
|
|
|
|
if !state.reserved_names.read().unwrap().contains(&username) {
|
|
break;
|
|
}
|
|
}
|
|
|
|
let tx = state.users_tx.clone();
|
|
let msg = UserHandlerMessage::NewUser(User::new(username, dm_tx.clone()), addr);
|
|
tokio::spawn(async move { tx.send(msg).await.expect("User handler is down") });
|
|
|
|
// Subscribe to receive from global broadcast channel
|
|
let mut rx = state.broadcast_tx.subscribe();
|
|
|
|
// Send messages to this client
|
|
let mut send_task = tokio::spawn(async move {
|
|
let mut broadcast = None;
|
|
let mut dm = None;
|
|
loop {
|
|
tokio::select! {
|
|
b = rx.recv() => broadcast = Some(b.unwrap()),
|
|
d = dm_rx.recv() => dm = d,
|
|
};
|
|
|
|
if let Some(msg) = &dm {
|
|
if sender.send(Message::Text(msg.to_string())).await.is_err() {
|
|
break;
|
|
} else {
|
|
dm = Option::None;
|
|
}
|
|
} else if let Some(msg) = &broadcast {
|
|
if sender.send(Message::Text(msg.to_string())).await.is_err() {
|
|
} else {
|
|
broadcast = Option::None;
|
|
}
|
|
}
|
|
}
|
|
});
|
|
|
|
// Receive messages from this client
|
|
let mut recv_task = tokio::spawn(async move {
|
|
while let Some(Ok(message)) = receiver.next().await {
|
|
if let Err(e) = state
|
|
.messages_tx
|
|
.send((addr.clone(), message.clone()))
|
|
.await
|
|
{
|
|
tracing::error!("Error relaying received message: {}", e)
|
|
};
|
|
}
|
|
});
|
|
|
|
// If either task completes then abort the other
|
|
tokio::select! {
|
|
_ = (&mut send_task) => recv_task.abort(),
|
|
_ = (&mut recv_task) => send_task.abort(),
|
|
};
|
|
}
|