diff --git a/Cargo.lock b/Cargo.lock index 24a5d55..27ae857 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -169,6 +169,7 @@ checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" dependencies = [ "async-trait", "axum-core 0.4.5", + "base64", "bytes", "futures-util", "http", @@ -184,8 +185,10 @@ dependencies = [ "pin-project-lite", "rustversion", "serde", + "sha1", "sync_wrapper", "tokio", + "tokio-tungstenite 0.24.0", "tower", "tower-layer", "tower-service", @@ -220,7 +223,7 @@ dependencies = [ "sha1", "sync_wrapper", "tokio", - "tokio-tungstenite", + "tokio-tungstenite 0.28.0", "tower", "tower-layer", "tower-service", @@ -380,6 +383,12 @@ version = "3.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + [[package]] name = "bytes" version = "1.11.1" @@ -3140,6 +3149,18 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-tungstenite" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edc5f74e248dc973e0dbb7b74c7e0d6fcc301c694ff50049504004ef4d0cdcd9" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite 0.24.0", +] + [[package]] name = "tokio-tungstenite" version = "0.28.0" @@ -3149,7 +3170,7 @@ dependencies = [ "futures-util", "log", "tokio", - "tungstenite", + "tungstenite 0.28.0", ] [[package]] @@ -3366,6 +3387,24 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "tungstenite" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18e5b8366ee7a95b16d32197d0b2604b43a0be89dc5fac9f8e96ccafbaedda8a" +dependencies = [ + "byteorder", + "bytes", + "data-encoding", + "http", + "httparse", + "log", + "rand 0.8.5", + "sha1", + "thiserror 1.0.69", + "utf-8", +] + [[package]] name = "tungstenite" version = "0.28.0" @@ -4228,6 +4267,7 @@ dependencies = [ "async-trait", "axum 0.7.9", "bytes", + "futures-util", "prometheus", "quinn", "reqwest", @@ -4236,6 +4276,7 @@ dependencies = [ "serde_json", "tokio", "toml", + "tower-http", "tracing", "tracing-subscriber", "wzp-client", diff --git a/crates/wzp-proto/src/packet.rs b/crates/wzp-proto/src/packet.rs index 3a27376..9b326e8 100644 --- a/crates/wzp-proto/src/packet.rs +++ b/crates/wzp-proto/src/packet.rs @@ -46,6 +46,23 @@ impl MediaHeader { /// Header size in bytes on the wire. pub const WIRE_SIZE: usize = 12; + /// Create a default header for raw PCM relay (used by WebSocket bridge). + pub fn default_pcm() -> Self { + Self { + version: 0, + is_repair: false, + codec_id: CodecId::Opus24k, + has_quality_report: false, + fec_ratio_encoded: 0, + seq: 0, + timestamp: 0, + fec_block: 0, + fec_symbol: 0, + reserved: 0, + csrc_count: 0, + } + } + /// Encode the FEC ratio float (0.0-2.0+) to a 7-bit value (0-127). pub fn encode_fec_ratio(ratio: f32) -> u8 { // Map 0.0-2.0 to 0-127, clamping at 127 diff --git a/crates/wzp-relay/Cargo.toml b/crates/wzp-relay/Cargo.toml index b415727..df6dae4 100644 --- a/crates/wzp-relay/Cargo.toml +++ b/crates/wzp-relay/Cargo.toml @@ -25,7 +25,9 @@ serde_json = "1" rustls = { version = "0.23", default-features = false, features = ["ring", "std"] } quinn = { workspace = true } prometheus = "0.13" -axum = { version = "0.7", default-features = false, features = ["tokio", "http1"] } +axum = { version = "0.7", default-features = false, features = ["tokio", "http1", "ws"] } +tower-http = { version = "0.6", features = ["fs"] } +futures-util = "0.3" [[bin]] name = "wzp-relay" diff --git a/crates/wzp-relay/src/config.rs b/crates/wzp-relay/src/config.rs index 493d700..01d9e14 100644 --- a/crates/wzp-relay/src/config.rs +++ b/crates/wzp-relay/src/config.rs @@ -39,6 +39,11 @@ pub struct RelayConfig { /// reducing per-packet QUIC datagram overhead. #[serde(default)] pub trunking_enabled: bool, + /// Port for the WebSocket listener (browser clients connect here). + /// If None, WebSocket support is disabled. + pub ws_port: Option, + /// Directory to serve static files from (HTML/JS/WASM for web clients). + pub static_dir: Option, } impl Default for RelayConfig { @@ -55,6 +60,8 @@ impl Default for RelayConfig { probe_targets: Vec::new(), probe_mesh: false, trunking_enabled: false, + ws_port: None, + static_dir: None, } } } diff --git a/crates/wzp-relay/src/lib.rs b/crates/wzp-relay/src/lib.rs index 1416cc8..a798c3a 100644 --- a/crates/wzp-relay/src/lib.rs +++ b/crates/wzp-relay/src/lib.rs @@ -19,6 +19,7 @@ pub mod room; pub mod route; pub mod session_mgr; pub mod trunk; +pub mod ws; pub use config::RelayConfig; pub use handshake::accept_handshake; diff --git a/crates/wzp-relay/src/main.rs b/crates/wzp-relay/src/main.rs index 531ebc9..4095ca7 100644 --- a/crates/wzp-relay/src/main.rs +++ b/crates/wzp-relay/src/main.rs @@ -68,6 +68,19 @@ fn parse_args() -> RelayConfig { "--trunking" => { config.trunking_enabled = true; } + "--ws-port" => { + i += 1; + config.ws_port = Some( + args.get(i).expect("--ws-port requires a port number") + .parse().expect("invalid --ws-port number"), + ); + } + "--static-dir" => { + i += 1; + config.static_dir = Some( + args.get(i).expect("--static-dir requires a directory path").to_string(), + ); + } "--mesh-status" => { // Print mesh table from a fresh registry and exit. // In practice this is useful after the relay has been running; @@ -89,6 +102,8 @@ fn parse_args() -> RelayConfig { eprintln!(" --probe-mesh Enable mesh mode (mark config flag, probes all --probe targets)."); eprintln!(" --mesh-status Print mesh health table and exit (diagnostic)."); eprintln!(" --trunking Enable trunk batching for outgoing media in room mode."); + eprintln!(" --ws-port WebSocket listener port for browser clients (e.g., 8080)."); + eprintln!(" --static-dir Directory to serve static files from (HTML/JS/WASM)."); eprintln!(); eprintln!("Room mode (default):"); eprintln!(" Clients join rooms by name. Packets forwarded to all others (SFU)."); @@ -233,6 +248,20 @@ async fn main() -> anyhow::Result<()> { tokio::spawn(async move { mesh.run_all().await }); } + // WebSocket server for browser clients + if let Some(ws_port) = config.ws_port { + let ws_state = wzp_relay::ws::WsState { + room_mgr: room_mgr.clone(), + session_mgr: session_mgr.clone(), + auth_url: config.auth_url.clone(), + metrics: metrics.clone(), + presence: presence.clone(), + }; + let static_dir = config.static_dir.clone(); + tokio::spawn(wzp_relay::ws::run_ws_server(ws_port, ws_state, static_dir)); + info!(ws_port, "WebSocket listener enabled for browser clients"); + } + if let Some(ref url) = config.auth_url { info!(url, "auth enabled — clients must present featherChat token"); } else { @@ -473,7 +502,7 @@ async fn main() -> anyhow::Result<()> { let participant_id = { let mut mgr = room_mgr.lock().await; - match mgr.join(&room_name, addr, transport.clone(), authenticated_fp.as_deref()) { + match mgr.join(&room_name, addr, room::ParticipantSender::Quic(transport.clone()), authenticated_fp.as_deref()) { Ok(id) => { metrics.active_rooms.set(mgr.list().len() as i64); id diff --git a/crates/wzp-relay/src/room.rs b/crates/wzp-relay/src/room.rs index 6a2e08c..0cb175d 100644 --- a/crates/wzp-relay/src/room.rs +++ b/crates/wzp-relay/src/room.rs @@ -27,11 +27,51 @@ fn next_id() -> ParticipantId { NEXT_PARTICIPANT_ID.fetch_add(1, Ordering::Relaxed) } +/// How to send data to a participant — either via QUIC transport or WebSocket channel. +#[derive(Clone)] +pub enum ParticipantSender { + Quic(Arc), + WebSocket(tokio::sync::mpsc::Sender), +} + +impl ParticipantSender { + /// Send raw bytes to this participant. + pub async fn send_raw(&self, data: &[u8]) -> Result<(), String> { + match self { + ParticipantSender::WebSocket(tx) => { + tx.try_send(Bytes::copy_from_slice(data)) + .map_err(|e| format!("ws send: {e}")) + } + ParticipantSender::Quic(transport) => { + let pkt = wzp_proto::MediaPacket { + header: wzp_proto::packet::MediaHeader::default_pcm(), + payload: Bytes::copy_from_slice(data), + quality_report: None, + }; + transport.send_media(&pkt).await.map_err(|e| format!("quic send: {e}")) + } + } + } + + /// Check if this is a QUIC participant. + pub fn is_quic(&self) -> bool { + matches!(self, ParticipantSender::Quic(_)) + } + + /// Get the QUIC transport if this is a QUIC participant. + pub fn as_quic(&self) -> Option<&Arc> { + match self { + ParticipantSender::Quic(t) => Some(t), + _ => None, + } + } +} + /// A participant in a room. struct Participant { id: ParticipantId, _addr: std::net::SocketAddr, - transport: Arc, + sender: ParticipantSender, } /// A room holding multiple participants. @@ -46,10 +86,10 @@ impl Room { } } - fn add(&mut self, addr: std::net::SocketAddr, transport: Arc) -> ParticipantId { + fn add(&mut self, addr: std::net::SocketAddr, sender: ParticipantSender) -> ParticipantId { let id = next_id(); info!(room_size = self.participants.len() + 1, participant = id, %addr, "joined room"); - self.participants.push(Participant { id, _addr: addr, transport }); + self.participants.push(Participant { id, _addr: addr, sender }); id } @@ -58,11 +98,11 @@ impl Room { info!(room_size = self.participants.len(), participant = id, "left room"); } - fn others(&self, exclude_id: ParticipantId) -> Vec> { + fn others(&self, exclude_id: ParticipantId) -> Vec { self.participants .iter() .filter(|p| p.id != exclude_id) - .map(|p| p.transport.clone()) + .map(|p| p.sender.clone()) .collect() } @@ -130,7 +170,7 @@ impl RoomManager { &mut self, room_name: &str, addr: std::net::SocketAddr, - transport: Arc, + sender: ParticipantSender, fingerprint: Option<&str>, ) -> Result { if !self.is_authorized(room_name, fingerprint) { @@ -138,7 +178,18 @@ impl RoomManager { return Err("not authorized for this room".to_string()); } let room = self.rooms.entry(room_name.to_string()).or_insert_with(Room::new); - Ok(room.add(addr, transport)) + Ok(room.add(addr, sender)) + } + + /// Join a room via WebSocket. Convenience wrapper around `join()`. + pub fn join_ws( + &mut self, + room_name: &str, + addr: std::net::SocketAddr, + sender: tokio::sync::mpsc::Sender, + fingerprint: Option<&str>, + ) -> Result { + self.join(room_name, addr, ParticipantSender::WebSocket(sender), fingerprint) } /// Leave a room. Removes the room if empty. @@ -152,12 +203,12 @@ impl RoomManager { } } - /// Get transports for all OTHER participants in a room. + /// Get senders for all OTHER participants in a room. pub fn others( &self, room_name: &str, participant_id: ParticipantId, - ) -> Vec> { + ) -> Vec { self.rooms .get(room_name) .map(|r| r.others(participant_id)) @@ -305,10 +356,14 @@ async fn run_participant_plain( // Forward to all others let pkt_bytes = pkt.payload.len() as u64; for other in &others { - // Best-effort: if one send fails, continue to others - if let Err(e) = other.send_media(&pkt).await { - // Don't log every failure — they'll be cleaned up when their recv loop breaks - let _ = e; + match other { + ParticipantSender::Quic(t) => { + let _ = t.send_media(&pkt).await; + } + ParticipantSender::WebSocket(_) => { + // WS clients receive raw payload bytes + let _ = other.send_raw(&pkt.payload).await; + } } } @@ -390,12 +445,20 @@ async fn run_participant_trunked( let pkt_bytes = pkt.payload.len() as u64; for other in &others { - let peer_addr = other.connection().remote_address(); - let fwd = forwarders - .entry(peer_addr) - .or_insert_with(|| TrunkedForwarder::new(other.clone(), sid_bytes)); - if let Err(e) = fwd.send(&pkt).await { - let _ = e; + match other { + ParticipantSender::Quic(t) => { + let peer_addr = t.connection().remote_address(); + let fwd = forwarders + .entry(peer_addr) + .or_insert_with(|| TrunkedForwarder::new(t.clone(), sid_bytes)); + if let Err(e) = fwd.send(&pkt).await { + let _ = e; + } + } + ParticipantSender::WebSocket(_) => { + // WS clients bypass trunking — send raw payload directly + let _ = other.send_raw(&pkt.payload).await; + } } } diff --git a/crates/wzp-relay/src/ws.rs b/crates/wzp-relay/src/ws.rs new file mode 100644 index 0000000..58d32cf --- /dev/null +++ b/crates/wzp-relay/src/ws.rs @@ -0,0 +1,243 @@ +//! WebSocket transport for browser clients. +//! +//! Browsers connect via `GET /ws/{room}` → WebSocket upgrade. +//! First message must be auth JSON (if auth is enabled). +//! Subsequent messages are binary PCM frames forwarded to/from the room. + +use std::net::SocketAddr; +use std::sync::Arc; + +use axum::{ + extract::{ + ws::{Message, WebSocket}, + Path, State, WebSocketUpgrade, + }, + response::IntoResponse, + routing::get, + Router, +}; +use bytes::Bytes; +use futures_util::{SinkExt, StreamExt}; +use tokio::sync::{mpsc, Mutex}; +use tower_http::services::ServeDir; +use tracing::{error, info, warn}; + +use crate::auth; +use crate::metrics::RelayMetrics; +use crate::presence::PresenceRegistry; +use crate::room::RoomManager; +use crate::session_mgr::SessionManager; + +/// Shared state for WebSocket handlers. +#[derive(Clone)] +pub struct WsState { + pub room_mgr: Arc>, + pub session_mgr: Arc>, + pub auth_url: Option, + pub metrics: Arc, + pub presence: Arc>, +} + +/// Start the WebSocket + static file server. +pub async fn run_ws_server(port: u16, state: WsState, static_dir: Option) { + let mut app = Router::new() + .route("/ws/{room}", get(ws_upgrade_handler)) + .with_state(state); + + if let Some(dir) = static_dir { + info!(dir = %dir, "serving static files"); + app = app.fallback_service(ServeDir::new(dir)); + } + + let addr: SocketAddr = ([0, 0, 0, 0], port).into(); + info!(%addr, "WebSocket server listening"); + + let listener = tokio::net::TcpListener::bind(addr) + .await + .expect("failed to bind WS listener"); + axum::serve(listener, app).await.expect("WS server failed"); +} + +async fn ws_upgrade_handler( + Path(room): Path, + State(state): State, + ws: WebSocketUpgrade, +) -> impl IntoResponse { + ws.on_upgrade(move |socket| handle_ws_connection(socket, room, state)) +} + +async fn handle_ws_connection(socket: WebSocket, room: String, state: WsState) { + let (mut ws_tx, mut ws_rx) = socket.split(); + + // 1. Auth: if auth_url is set, first message must be {"type":"auth","token":"..."} + let fingerprint: Option = if let Some(ref auth_url) = state.auth_url { + match ws_rx.next().await { + Some(Ok(Message::Text(text))) => { + match serde_json::from_str::(&text) { + Ok(parsed) if parsed["type"] == "auth" => { + if let Some(token) = parsed["token"].as_str() { + match auth::validate_token(auth_url, token).await { + Ok(client) => { + state.metrics.auth_attempts.with_label_values(&["ok"]).inc(); + info!(fingerprint = %client.fingerprint, "WS authenticated"); + let _ = ws_tx + .send(Message::Text(r#"{"type":"auth_ok"}"#.into())) + .await; + Some(client.fingerprint) + } + Err(e) => { + state + .metrics + .auth_attempts + .with_label_values(&["fail"]) + .inc(); + let _ = ws_tx + .send(Message::Text( + format!(r#"{{"type":"auth_error","error":"{e}"}}"#) + .into(), + )) + .await; + warn!("WS auth failed: {e}"); + return; + } + } + } else { + warn!("WS auth: missing token field"); + return; + } + } + _ => { + warn!("WS: expected auth message as first frame"); + return; + } + } + } + _ => { + warn!("WS: connection closed before auth"); + return; + } + } + } else { + let _ = ws_tx + .send(Message::Text(r#"{"type":"auth_ok"}"#.into())) + .await; + None + }; + + // 2. Create mpsc channel for outbound frames (room → browser) + let (tx, mut rx) = mpsc::channel::(64); + + // 3. Create session + let session_id = { + let mut smgr = state.session_mgr.lock().await; + match smgr.create_session(&room, fingerprint.clone()) { + Ok(id) => id, + Err(e) => { + error!(room = %room, "WS session rejected: {e}"); + return; + } + } + }; + state.metrics.active_sessions.inc(); + + // 4. Join room with WS sender + let addr: SocketAddr = ([0, 0, 0, 0], 0).into(); + let participant_id = { + let mut mgr = state.room_mgr.lock().await; + match mgr.join_ws(&room, addr, tx, fingerprint.as_deref()) { + Ok(id) => { + state.metrics.active_rooms.set(mgr.list().len() as i64); + id + } + Err(e) => { + error!(room = %room, "WS room join denied: {e}"); + state.metrics.active_sessions.dec(); + let mut smgr = state.session_mgr.lock().await; + smgr.remove_session(session_id); + return; + } + } + }; + + // 5. Register presence + if let Some(ref fp) = fingerprint { + let mut reg = state.presence.lock().await; + reg.register_local(fp, None, Some(room.clone())); + } + + info!(room = %room, participant = participant_id, "WS client joined"); + + // 6. Outbound task: mpsc rx → WS binary frames + let send_task = tokio::spawn(async move { + while let Some(data) = rx.recv().await { + if ws_tx + .send(Message::Binary(data.to_vec().into())) + .await + .is_err() + { + break; + } + } + }); + + // 7. Inbound: WS recv → fan-out to room + loop { + match ws_rx.next().await { + Some(Ok(Message::Binary(data))) => { + let others = { + let mgr = state.room_mgr.lock().await; + mgr.others(&room, participant_id) + }; + for other in &others { + let _ = other.send_raw(&data).await; + } + state + .metrics + .packets_forwarded + .inc_by(others.len() as u64); + state + .metrics + .bytes_forwarded + .inc_by(data.len() as u64 * others.len() as u64); + } + Some(Ok(Message::Close(_))) | None => break, + _ => continue, + } + } + + // 8. Cleanup + send_task.abort(); + info!(room = %room, participant = participant_id, "WS client disconnected"); + + if let Some(ref fp) = fingerprint { + let mut reg = state.presence.lock().await; + reg.unregister_local(fp); + } + + { + let mut mgr = state.room_mgr.lock().await; + mgr.leave(&room, participant_id); + state.metrics.active_rooms.set(mgr.list().len() as i64); + } + + let session_id_str: String = session_id.iter().map(|b| format!("{b:02x}")).collect(); + state.metrics.remove_session_metrics(&session_id_str); + state.metrics.active_sessions.dec(); + + { + let mut smgr = state.session_mgr.lock().await; + smgr.remove_session(session_id); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn ws_state_is_clone() { + // WsState must be Clone for axum's State extractor + fn assert_clone() {} + assert_clone::(); + } +}