From 2ca25fd2bfa4a02615689f4d00311cb5d4b9adb6 Mon Sep 17 00:00:00 2001 From: Siavash Sameni Date: Fri, 27 Mar 2026 09:41:50 +0400 Subject: [PATCH] v0.0.5: WebSocket real-time messaging MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Server: - WS endpoint: /v1/ws/:fingerprint - Connection registry in AppState (fingerprint → WS senders) - On connect: flushes queued DB messages, then pushes in real-time - send_message: pushes to WS if connected, falls back to DB queue - Auto-cleanup on disconnect - WS accepts both binary and JSON text frames for sending Web client: - Replaces 2-second HTTP polling with persistent WebSocket - Auto-reconnects on disconnect (3-second backoff) - Sends via WS when connected, HTTP fallback - Messages arrive instantly (no polling delay) - "Real-time connection established" shown on connect HTTP polling still works: - CLI recv command uses HTTP (unchanged) - Web falls back to HTTP if WS fails - Mules/scripts can still use HTTP API Co-Authored-By: Claude Opus 4.6 (1M context) --- warzone/Cargo.lock | 106 +++++++++- warzone/Cargo.toml | 4 +- warzone/crates/warzone-server/Cargo.toml | 1 + .../warzone-server/src/routes/messages.rs | 15 +- .../crates/warzone-server/src/routes/mod.rs | 4 +- .../crates/warzone-server/src/routes/web.rs | 188 ++++++++++-------- .../crates/warzone-server/src/routes/ws.rs | 154 ++++++++++++++ warzone/crates/warzone-server/src/state.rs | 52 ++++- 8 files changed, 425 insertions(+), 99 deletions(-) create mode 100644 warzone/crates/warzone-server/src/routes/ws.rs diff --git a/warzone/Cargo.lock b/warzone/Cargo.lock index d43abab..4d6da74 100644 --- a/warzone/Cargo.lock +++ b/warzone/Cargo.lock @@ -141,6 +141,7 @@ checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" dependencies = [ "async-trait", "axum-core", + "base64", "bytes", "futures-util", "http", @@ -159,8 +160,10 @@ dependencies = [ "serde_json", "serde_path_to_error", "serde_urlencoded", + "sha1", "sync_wrapper", "tokio", + "tokio-tungstenite", "tower 0.5.3", "tower-layer", "tower-service", @@ -580,6 +583,12 @@ dependencies = [ "syn", ] +[[package]] +name = "data-encoding" +version = "2.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7a1e2f27636f116493b8b860f5546edb47c8d8f8ea73e1d2a20be88e28d1fea" + [[package]] name = "der" version = "0.7.10" @@ -748,6 +757,17 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" +[[package]] +name = "futures-macro" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.32" @@ -767,6 +787,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" dependencies = [ "futures-core", + "futures-macro", + "futures-sink", "futures-task", "pin-project-lite", "slab", @@ -1942,6 +1964,17 @@ dependencies = [ "serde", ] +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sha2" version = "0.10.9" @@ -2167,13 +2200,33 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "thiserror" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl 1.0.69", +] + [[package]] name = "thiserror" version = "2.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" dependencies = [ - "thiserror-impl", + "thiserror-impl 2.0.18", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn", ] [[package]] @@ -2269,6 +2322,18 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-tungstenite" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edc5f74e248dc973e0dbb7b74c7e0d6fcc301c694ff50049504004ef4d0cdcd9" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite", +] + [[package]] name = "tokio-util" version = "0.7.18" @@ -2424,6 +2489,24 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "tungstenite" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18e5b8366ee7a95b16d32197d0b2604b43a0be89dc5fac9f8e96ccafbaedda8a" +dependencies = [ + "byteorder", + "bytes", + "data-encoding", + "http", + "httparse", + "log", + "rand", + "sha1", + "thiserror 1.0.69", + "utf-8", +] + [[package]] name = "typenum" version = "1.19.0" @@ -2502,6 +2585,12 @@ dependencies = [ "serde", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf8_iter" version = "1.0.4" @@ -2555,7 +2644,7 @@ dependencies = [ [[package]] name = "warzone-client" -version = "0.0.4" +version = "0.0.5" dependencies = [ "anyhow", "argon2", @@ -2584,7 +2673,7 @@ dependencies = [ [[package]] name = "warzone-mule" -version = "0.0.4" +version = "0.0.5" dependencies = [ "anyhow", "clap", @@ -2593,7 +2682,7 @@ dependencies = [ [[package]] name = "warzone-protocol" -version = "0.0.4" +version = "0.0.5" dependencies = [ "base64", "bincode", @@ -2608,7 +2697,7 @@ dependencies = [ "serde", "serde_json", "sha2", - "thiserror", + "thiserror 2.0.18", "uuid", "x25519-dalek", "zeroize", @@ -2616,7 +2705,7 @@ dependencies = [ [[package]] name = "warzone-server" -version = "0.0.4" +version = "0.0.5" dependencies = [ "anyhow", "axum", @@ -2625,12 +2714,13 @@ dependencies = [ "chrono", "clap", "ed25519-dalek", + "futures-util", "hex", "rand", "serde", "serde_json", "sled", - "thiserror", + "thiserror 2.0.18", "tokio", "tower 0.4.13", "tower-http 0.5.2", @@ -2642,7 +2732,7 @@ dependencies = [ [[package]] name = "warzone-wasm" -version = "0.0.4" +version = "0.0.5" dependencies = [ "base64", "bincode", diff --git a/warzone/Cargo.toml b/warzone/Cargo.toml index a7e6a79..c30babb 100644 --- a/warzone/Cargo.toml +++ b/warzone/Cargo.toml @@ -9,7 +9,7 @@ members = [ ] [workspace.package] -version = "0.0.4" +version = "0.0.5" edition = "2021" license = "MIT" rust-version = "1.75" @@ -37,7 +37,7 @@ bincode = "1" tokio = { version = "1", features = ["full"] } # Server -axum = "0.7" +axum = { version = "0.7", features = ["ws"] } tower = "0.4" tower-http = { version = "0.5", features = ["cors", "trace"] } diff --git a/warzone/crates/warzone-server/Cargo.toml b/warzone/crates/warzone-server/Cargo.toml index c2257f7..d1e6dd1 100644 --- a/warzone/crates/warzone-server/Cargo.toml +++ b/warzone/crates/warzone-server/Cargo.toml @@ -22,5 +22,6 @@ chrono.workspace = true hex.workspace = true base64.workspace = true rand.workspace = true +futures-util = "0.3" ed25519-dalek.workspace = true bincode.workspace = true diff --git a/warzone/crates/warzone-server/src/routes/messages.rs b/warzone/crates/warzone-server/src/routes/messages.rs index e4091e6..ed09511 100644 --- a/warzone/crates/warzone-server/src/routes/messages.rs +++ b/warzone/crates/warzone-server/src/routes/messages.rs @@ -9,7 +9,7 @@ use crate::errors::AppResult; use crate::state::AppState; /// Touch the alias TTL for a fingerprint (renew on authenticated action). -fn renew_alias_ttl(db: &sled::Tree, fp: &str) { +pub fn renew_alias_ttl(db: &sled::Tree, fp: &str) { let alias_key = format!("fp:{}", fp); if let Ok(Some(alias_bytes)) = db.get(alias_key.as_bytes()) { let alias = String::from_utf8_lossy(&alias_bytes).to_string(); @@ -54,9 +54,16 @@ async fn send_message( Json(req): Json, ) -> AppResult> { let to = normalize_fp(&req.to); - let key = format!("queue:{}:{}", to, uuid::Uuid::new_v4()); - tracing::info!("Queuing message for {} ({} bytes)", to, req.message.len()); - state.db.messages.insert(key.as_bytes(), req.message)?; + + // Try WebSocket push first (instant delivery) + if state.push_to_client(&to, &req.message).await { + tracing::info!("Pushed message to {} via WS ({} bytes)", to, req.message.len()); + } else { + // Queue in DB (offline delivery) + let key = format!("queue:{}:{}", to, uuid::Uuid::new_v4()); + tracing::info!("Queuing message for {} ({} bytes)", to, req.message.len()); + state.db.messages.insert(key.as_bytes(), req.message)?; + } // Renew sender's alias TTL (sending = authenticated action) if let Some(ref from) = req.from { diff --git a/warzone/crates/warzone-server/src/routes/mod.rs b/warzone/crates/warzone-server/src/routes/mod.rs index 37f3cf2..3a8ddbf 100644 --- a/warzone/crates/warzone-server/src/routes/mod.rs +++ b/warzone/crates/warzone-server/src/routes/mod.rs @@ -3,8 +3,9 @@ pub mod auth; mod groups; mod health; mod keys; -mod messages; +pub mod messages; mod web; +mod ws; use axum::Router; @@ -18,6 +19,7 @@ pub fn router() -> Router { .merge(groups::routes()) .merge(aliases::routes()) .merge(auth::routes()) + .merge(ws::routes()) } /// Web UI router (served at root, outside /v1) diff --git a/warzone/crates/warzone-server/src/routes/web.rs b/warzone/crates/warzone-server/src/routes/web.rs index 919e4d7..e09b837 100644 --- a/warzone/crates/warzone-server/src/routes/web.rs +++ b/warzone/crates/warzone-server/src/routes/web.rs @@ -157,9 +157,10 @@ let mySeedHex = ''; let sessions = {}; // peerFP -> { session: WasmSession, data: base64 } let peerBundles = {}; // peerFP -> bundle bytes let pollTimer = null; +let ws = null; // WebSocket connection let wasmReady = false; -const VERSION = '0.0.4'; +const VERSION = '0.0.5'; let DEBUG = true; // toggle with /debug command function dbg(...args) { @@ -281,101 +282,122 @@ async function sendEncrypted(peerFP, plaintext) { )); dbg('Sending wire message, size:', wireBytes.length); - await fetch(SERVER + '/v1/messages/send', { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ + + // Prefer WebSocket, fall back to HTTP + if (ws && ws.readyState === WebSocket.OPEN) { + ws.send(JSON.stringify({ to: fp, from: normFP(myFingerprint), message: Array.from(wireBytes) - }) - }); - dbg('Message sent'); + })); + dbg('Sent via WebSocket'); + } else { + await fetch(SERVER + '/v1/messages/send', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + to: fp, + from: normFP(myFingerprint), + message: Array.from(wireBytes) + }) + }); + dbg('Sent via HTTP (WS not connected)'); + } } -async function pollMessages() { - if (!wasmReady) return; +function connectWebSocket() { + const fp = normFP(myFingerprint); + const proto = location.protocol === 'https:' ? 'wss:' : 'ws:'; + const wsUrl = proto + '//' + location.host + '/v1/ws/' + fp; + dbg('Connecting WebSocket:', wsUrl); + + ws = new WebSocket(wsUrl); + ws.binaryType = 'arraybuffer'; + + ws.onopen = () => { + dbg('WebSocket connected'); + addSys('Real-time connection established'); + }; + + ws.onmessage = async (event) => { + const bytes = new Uint8Array(event.data); + dbg('WS received', bytes.length, 'bytes'); + await handleIncomingMessage(bytes); + }; + + ws.onclose = () => { + dbg('WebSocket closed, reconnecting in 3s...'); + addSys('Connection lost, reconnecting...'); + setTimeout(connectWebSocket, 3000); + }; + + ws.onerror = (e) => { + dbg('WebSocket error:', e); + }; +} + +async function handleIncomingMessage(bytes) { + dbg('Processing message,', bytes.length, 'bytes, sessions:', Object.keys(sessions)); + + // First try: KeyExchange (no existing session needed) + let decrypted = false; try { - const fp = normFP(myFingerprint); - const resp = await fetch(SERVER + '/v1/messages/poll/' + fp); - if (!resp.ok) return; - const msgs = await resp.json(); + const resultStr = decrypt_wire_message(mySeedHex, mySpkSecretHex, bytes, null); + const result = JSON.parse(resultStr); + dbg('Decrypted (KeyExchange) from:', result.sender); - dbg('Poll got', msgs.length, 'messages, sessions:', Object.keys(sessions)); + const senderFP = normFP(result.sender); + sessions[senderFP] = { data: result.session_data }; + localStorage.setItem('wz-sessions', JSON.stringify( + Object.fromEntries(Object.entries(sessions).map(([k,v]) => [k, v.data])) + )); - for (let i = 0; i < msgs.length; i++) { - const b64 = msgs[i]; + let fromLabel = result.sender.slice(0, 19); + try { + const ar = await fetch(SERVER + '/v1/alias/whois/' + senderFP); + const ad = await ar.json(); + if (ad.alias) fromLabel = '@' + ad.alias; + } catch(e) {} + + addMsg(fromLabel, result.text, false); + decrypted = true; + } catch(e) { + dbg('KeyExchange failed:', e.message || e); + } + + // Second try: existing sessions + if (!decrypted) { + for (const [senderFP, sessData] of Object.entries(sessions)) { try { - const bytes = Uint8Array.from(atob(b64), c => c.charCodeAt(0)); - dbg('Msg', i, ':', bytes.length, 'bytes, first 4:', Array.from(bytes.slice(0, 4))); + const resultStr = decrypt_wire_message(mySeedHex, mySpkSecretHex, bytes, sessData.data); + const result = JSON.parse(resultStr); + dbg('Decrypted with session', senderFP); - // First try: KeyExchange (no existing session needed) - let decrypted = false; + sessions[senderFP] = { data: result.session_data }; + localStorage.setItem('wz-sessions', JSON.stringify( + Object.fromEntries(Object.entries(sessions).map(([k,v]) => [k, v.data])) + )); + + let fromLabel = result.sender.slice(0, 19); try { - dbg('Trying decrypt as KeyExchange (no session)...'); - const resultStr = decrypt_wire_message(mySeedHex, mySpkSecretHex, bytes, null); - const result = JSON.parse(resultStr); - dbg('Decrypted!', result.new_session ? 'new session' : 'existing', 'from:', result.sender); + const ar = await fetch(SERVER + '/v1/alias/whois/' + normFP(result.sender)); + const ad = await ar.json(); + if (ad.alias) fromLabel = '@' + ad.alias; + } catch(e2) {} - const senderFP = normFP(result.sender); - sessions[senderFP] = { data: result.session_data }; - localStorage.setItem('wz-sessions', JSON.stringify( - Object.fromEntries(Object.entries(sessions).map(([k,v]) => [k, v.data])) - )); - - let fromLabel = result.sender.slice(0, 19); - try { - const ar = await fetch(SERVER + '/v1/alias/whois/' + senderFP); - const ad = await ar.json(); - if (ad.alias) fromLabel = '@' + ad.alias; - } catch(e) {} - - addMsg(fromLabel, result.text, false); - decrypted = true; - } catch(e) { - dbg('KeyExchange decrypt failed:', e.message || e); - } - - // Second try: existing sessions - if (!decrypted) { - for (const [senderFP, sessData] of Object.entries(sessions)) { - try { - dbg('Trying session for', senderFP); - const resultStr = decrypt_wire_message(mySeedHex, mySpkSecretHex, bytes, sessData.data); - const result = JSON.parse(resultStr); - dbg('Decrypted with session', senderFP, ':', result.text.slice(0, 30)); - - sessions[senderFP] = { data: result.session_data }; - localStorage.setItem('wz-sessions', JSON.stringify( - Object.fromEntries(Object.entries(sessions).map(([k,v]) => [k, v.data])) - )); - - let fromLabel = result.sender.slice(0, 19); - try { - const ar = await fetch(SERVER + '/v1/alias/whois/' + normFP(result.sender)); - const ad = await ar.json(); - if (ad.alias) fromLabel = '@' + ad.alias; - } catch(e2) {} - - addMsg(fromLabel, result.text, false); - decrypted = true; - break; - } catch(e2) { - dbg('Session', senderFP, 'failed:', e2.message || e2); - } - } - } - - if (!decrypted) { - dbg('ALL decrypt attempts failed for msg', i); - addSys('[message could not be decrypted]'); - } - } catch(e) { - dbg('Message processing error:', e); - addSys('[failed to process message]'); + addMsg(fromLabel, result.text, false); + decrypted = true; + break; + } catch(e2) { + dbg('Session', senderFP, 'failed:', e2.message || e2); } } - } catch(e) { /* server offline */ } + } + + if (!decrypted) { + dbg('ALL decrypt attempts failed'); + addSys('[message could not be decrypted]'); + } } // Load saved sessions @@ -467,7 +489,7 @@ async function enterChat() { const savedPeer = localStorage.getItem('wz-peer'); if (savedPeer) $peerInput.value = savedPeer; - pollTimer = setInterval(pollMessages, 2000); + connectWebSocket(); $input.focus(); } diff --git a/warzone/crates/warzone-server/src/routes/ws.rs b/warzone/crates/warzone-server/src/routes/ws.rs new file mode 100644 index 0000000..6b0be7e --- /dev/null +++ b/warzone/crates/warzone-server/src/routes/ws.rs @@ -0,0 +1,154 @@ +//! WebSocket endpoint for real-time message delivery. +//! +//! Protocol: +//! 1. Client connects to /v1/ws/:fingerprint +//! 2. Server sends any queued messages (from DB) +//! 3. Server pushes new messages in real-time +//! 4. Client sends messages as binary WireMessage frames +//! 5. Server routes to recipient's WS or queues in DB + +use axum::{ + extract::{ + ws::{Message, WebSocket}, + Path, State, WebSocketUpgrade, + }, + response::IntoResponse, + routing::get, + Router, +}; +use futures_util::{SinkExt, StreamExt}; + +use crate::state::AppState; + +pub fn routes() -> Router { + Router::new().route("/ws/:fingerprint", get(ws_handler)) +} + +fn normalize_fp(fp: &str) -> String { + fp.chars() + .filter(|c| c.is_ascii_hexdigit()) + .collect::() + .to_lowercase() +} + +async fn ws_handler( + ws: WebSocketUpgrade, + State(state): State, + Path(fingerprint): Path, +) -> impl IntoResponse { + let fp = normalize_fp(&fingerprint); + tracing::info!("WS upgrade request from {}", fp); + ws.on_upgrade(move |socket| handle_socket(socket, state, fp)) +} + +async fn handle_socket(socket: WebSocket, state: AppState, fingerprint: String) { + let (mut ws_tx, mut ws_rx) = socket.split(); + + // Register for push delivery + let mut push_rx = state.register_ws(&fingerprint).await; + + // Send any queued messages from DB + let prefix = format!("queue:{}", fingerprint); + let mut keys_to_delete = Vec::new(); + for item in state.db.messages.scan_prefix(prefix.as_bytes()) { + if let Ok((key, value)) = item { + if ws_tx.send(Message::Binary(value.to_vec().into())).await.is_ok() { + keys_to_delete.push(key); + } + } + } + for key in &keys_to_delete { + let _ = state.db.messages.remove(key); + } + if !keys_to_delete.is_empty() { + tracing::info!("WS {}: flushed {} queued messages", fingerprint, keys_to_delete.len()); + } + + // Spawn task to forward push messages to WS + let fp_clone = fingerprint.clone(); + let mut push_task = tokio::spawn(async move { + while let Some(msg) = push_rx.recv().await { + if ws_tx.send(Message::Binary(msg.into())).await.is_err() { + break; + } + } + ws_tx + }); + + // Handle incoming messages from client + let state_clone = state.clone(); + let fp_clone2 = fingerprint.clone(); + let mut recv_task = tokio::spawn(async move { + while let Some(Ok(msg)) = ws_rx.next().await { + match msg { + Message::Binary(data) => { + // Parse as a simple { to: "fp", message: bytes } JSON + // Or just raw WireMessage bytes with a 32-byte fingerprint prefix + // For simplicity: first 32 hex chars = recipient fp, rest = message + if data.len() > 64 { + let header = String::from_utf8_lossy(&data[..64]).to_string(); + let to_fp = normalize_fp(&header); + let message = &data[64..]; + + // Try push to connected client first + if !state_clone.push_to_client(&to_fp, message).await { + // Queue in DB + let key = format!("queue:{}:{}", to_fp, uuid::Uuid::new_v4()); + let _ = state_clone.db.messages.insert(key.as_bytes(), message); + } + + tracing::debug!("WS {}: routed message to {}", fp_clone2, to_fp); + } + } + Message::Text(text) => { + // JSON format: {"to": "fp", "message": [bytes]} + if let Ok(parsed) = serde_json::from_str::(&text) { + let to = parsed.get("to").and_then(|v| v.as_str()).unwrap_or(""); + let to_fp = normalize_fp(to); + if let Some(msg_arr) = parsed.get("message").and_then(|v| v.as_array()) { + let message: Vec = msg_arr.iter() + .filter_map(|v| v.as_u64().map(|n| n as u8)) + .collect(); + + if !state_clone.push_to_client(&to_fp, &message).await { + let key = format!("queue:{}:{}", to_fp, uuid::Uuid::new_v4()); + let _ = state_clone.db.messages.insert(key.as_bytes(), message); + } + + // Renew alias TTL + crate::routes::messages::renew_alias_ttl( + &state_clone.db.aliases, &fp_clone2, + ); + + tracing::debug!("WS {}: routed JSON message to {}", fp_clone2, to_fp); + } + } + } + Message::Close(_) => break, + _ => {} + } + } + }); + + // Wait for either task to finish + tokio::select! { + _ = &mut push_task => { + recv_task.abort(); + } + _ = &mut recv_task => { + push_task.abort(); + } + } + + // Unregister + // We can't easily get the sender ref here, so just clean up by fingerprint + // In production, use a unique connection ID + let mut conns = state.connections.lock().await; + if let Some(senders) = conns.get_mut(&fingerprint) { + senders.retain(|s| !s.is_closed()); + if senders.is_empty() { + conns.remove(&fingerprint); + } + } + tracing::info!("WS {} disconnected", fingerprint); +} diff --git a/warzone/crates/warzone-server/src/state.rs b/warzone/crates/warzone-server/src/state.rs index 881e6f5..67dea12 100644 --- a/warzone/crates/warzone-server/src/state.rs +++ b/warzone/crates/warzone-server/src/state.rs @@ -1,15 +1,65 @@ +use std::collections::HashMap; use std::sync::Arc; +use tokio::sync::{broadcast, Mutex, mpsc}; use crate::db::Database; +/// Per-connection sender: messages are pushed here for instant delivery. +pub type WsSender = mpsc::UnboundedSender>; + +/// Connected clients: fingerprint → list of WS senders (multiple devices). +pub type Connections = Arc>>>; + #[derive(Clone)] pub struct AppState { pub db: Arc, + pub connections: Connections, } impl AppState { pub fn new(data_dir: &str) -> anyhow::Result { let db = Database::open(data_dir)?; - Ok(AppState { db: Arc::new(db) }) + Ok(AppState { + db: Arc::new(db), + connections: Arc::new(Mutex::new(HashMap::new())), + }) + } + + /// Try to push a message to a connected client. Returns true if delivered. + pub async fn push_to_client(&self, fingerprint: &str, message: &[u8]) -> bool { + let conns = self.connections.lock().await; + if let Some(senders) = conns.get(fingerprint) { + let mut delivered = false; + for sender in senders { + if sender.send(message.to_vec()).is_ok() { + delivered = true; + } + } + delivered + } else { + false + } + } + + /// Register a WS connection for a fingerprint. + pub async fn register_ws(&self, fingerprint: &str) -> mpsc::UnboundedReceiver> { + let (tx, rx) = mpsc::unbounded_channel(); + let mut conns = self.connections.lock().await; + conns.entry(fingerprint.to_string()).or_default().push(tx); + tracing::info!("WS registered for {} ({} total connections)", fingerprint, + conns.values().map(|v| v.len()).sum::()); + rx + } + + /// Unregister a WS connection. + pub async fn unregister_ws(&self, fingerprint: &str, sender: &WsSender) { + let mut conns = self.connections.lock().await; + if let Some(senders) = conns.get_mut(fingerprint) { + senders.retain(|s| !s.same_channel(sender)); + if senders.is_empty() { + conns.remove(fingerprint); + } + } + tracing::info!("WS unregistered for {}", fingerprint); } }