v0.0.5: WebSocket real-time messaging
Server: - WS endpoint: /v1/ws/:fingerprint - Connection registry in AppState (fingerprint → WS senders) - On connect: flushes queued DB messages, then pushes in real-time - send_message: pushes to WS if connected, falls back to DB queue - Auto-cleanup on disconnect - WS accepts both binary and JSON text frames for sending Web client: - Replaces 2-second HTTP polling with persistent WebSocket - Auto-reconnects on disconnect (3-second backoff) - Sends via WS when connected, HTTP fallback - Messages arrive instantly (no polling delay) - "Real-time connection established" shown on connect HTTP polling still works: - CLI recv command uses HTTP (unchanged) - Web falls back to HTTP if WS fails - Mules/scripts can still use HTTP API Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
106
warzone/Cargo.lock
generated
106
warzone/Cargo.lock
generated
@@ -141,6 +141,7 @@ checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"axum-core",
|
||||
"base64",
|
||||
"bytes",
|
||||
"futures-util",
|
||||
"http",
|
||||
@@ -159,8 +160,10 @@ dependencies = [
|
||||
"serde_json",
|
||||
"serde_path_to_error",
|
||||
"serde_urlencoded",
|
||||
"sha1",
|
||||
"sync_wrapper",
|
||||
"tokio",
|
||||
"tokio-tungstenite",
|
||||
"tower 0.5.3",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
@@ -580,6 +583,12 @@ dependencies = [
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "data-encoding"
|
||||
version = "2.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d7a1e2f27636f116493b8b860f5546edb47c8d8f8ea73e1d2a20be88e28d1fea"
|
||||
|
||||
[[package]]
|
||||
name = "der"
|
||||
version = "0.7.10"
|
||||
@@ -748,6 +757,17 @@ version = "0.3.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d"
|
||||
|
||||
[[package]]
|
||||
name = "futures-macro"
|
||||
version = "0.3.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-sink"
|
||||
version = "0.3.32"
|
||||
@@ -767,6 +787,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-macro",
|
||||
"futures-sink",
|
||||
"futures-task",
|
||||
"pin-project-lite",
|
||||
"slab",
|
||||
@@ -1942,6 +1964,17 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sha1"
|
||||
version = "0.10.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cpufeatures",
|
||||
"digest",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sha2"
|
||||
version = "0.10.9"
|
||||
@@ -2167,13 +2200,33 @@ dependencies = [
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "1.0.69"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52"
|
||||
dependencies = [
|
||||
"thiserror-impl 1.0.69",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "2.0.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4"
|
||||
dependencies = [
|
||||
"thiserror-impl",
|
||||
"thiserror-impl 2.0.18",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror-impl"
|
||||
version = "1.0.69"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2269,6 +2322,18 @@ dependencies = [
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-tungstenite"
|
||||
version = "0.24.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "edc5f74e248dc973e0dbb7b74c7e0d6fcc301c694ff50049504004ef4d0cdcd9"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"log",
|
||||
"tokio",
|
||||
"tungstenite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-util"
|
||||
version = "0.7.18"
|
||||
@@ -2424,6 +2489,24 @@ version = "0.2.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
|
||||
|
||||
[[package]]
|
||||
name = "tungstenite"
|
||||
version = "0.24.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "18e5b8366ee7a95b16d32197d0b2604b43a0be89dc5fac9f8e96ccafbaedda8a"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"data-encoding",
|
||||
"http",
|
||||
"httparse",
|
||||
"log",
|
||||
"rand",
|
||||
"sha1",
|
||||
"thiserror 1.0.69",
|
||||
"utf-8",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typenum"
|
||||
version = "1.19.0"
|
||||
@@ -2502,6 +2585,12 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "utf-8"
|
||||
version = "0.7.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
|
||||
|
||||
[[package]]
|
||||
name = "utf8_iter"
|
||||
version = "1.0.4"
|
||||
@@ -2555,7 +2644,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "warzone-client"
|
||||
version = "0.0.4"
|
||||
version = "0.0.5"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"argon2",
|
||||
@@ -2584,7 +2673,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "warzone-mule"
|
||||
version = "0.0.4"
|
||||
version = "0.0.5"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"clap",
|
||||
@@ -2593,7 +2682,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "warzone-protocol"
|
||||
version = "0.0.4"
|
||||
version = "0.0.5"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"bincode",
|
||||
@@ -2608,7 +2697,7 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sha2",
|
||||
"thiserror",
|
||||
"thiserror 2.0.18",
|
||||
"uuid",
|
||||
"x25519-dalek",
|
||||
"zeroize",
|
||||
@@ -2616,7 +2705,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "warzone-server"
|
||||
version = "0.0.4"
|
||||
version = "0.0.5"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"axum",
|
||||
@@ -2625,12 +2714,13 @@ dependencies = [
|
||||
"chrono",
|
||||
"clap",
|
||||
"ed25519-dalek",
|
||||
"futures-util",
|
||||
"hex",
|
||||
"rand",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sled",
|
||||
"thiserror",
|
||||
"thiserror 2.0.18",
|
||||
"tokio",
|
||||
"tower 0.4.13",
|
||||
"tower-http 0.5.2",
|
||||
@@ -2642,7 +2732,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "warzone-wasm"
|
||||
version = "0.0.4"
|
||||
version = "0.0.5"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"bincode",
|
||||
|
||||
@@ -9,7 +9,7 @@ members = [
|
||||
]
|
||||
|
||||
[workspace.package]
|
||||
version = "0.0.4"
|
||||
version = "0.0.5"
|
||||
edition = "2021"
|
||||
license = "MIT"
|
||||
rust-version = "1.75"
|
||||
@@ -37,7 +37,7 @@ bincode = "1"
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
|
||||
# Server
|
||||
axum = "0.7"
|
||||
axum = { version = "0.7", features = ["ws"] }
|
||||
tower = "0.4"
|
||||
tower-http = { version = "0.5", features = ["cors", "trace"] }
|
||||
|
||||
|
||||
@@ -22,5 +22,6 @@ chrono.workspace = true
|
||||
hex.workspace = true
|
||||
base64.workspace = true
|
||||
rand.workspace = true
|
||||
futures-util = "0.3"
|
||||
ed25519-dalek.workspace = true
|
||||
bincode.workspace = true
|
||||
|
||||
@@ -9,7 +9,7 @@ use crate::errors::AppResult;
|
||||
use crate::state::AppState;
|
||||
|
||||
/// Touch the alias TTL for a fingerprint (renew on authenticated action).
|
||||
fn renew_alias_ttl(db: &sled::Tree, fp: &str) {
|
||||
pub fn renew_alias_ttl(db: &sled::Tree, fp: &str) {
|
||||
let alias_key = format!("fp:{}", fp);
|
||||
if let Ok(Some(alias_bytes)) = db.get(alias_key.as_bytes()) {
|
||||
let alias = String::from_utf8_lossy(&alias_bytes).to_string();
|
||||
@@ -54,9 +54,16 @@ async fn send_message(
|
||||
Json(req): Json<SendRequest>,
|
||||
) -> AppResult<Json<serde_json::Value>> {
|
||||
let to = normalize_fp(&req.to);
|
||||
let key = format!("queue:{}:{}", to, uuid::Uuid::new_v4());
|
||||
tracing::info!("Queuing message for {} ({} bytes)", to, req.message.len());
|
||||
state.db.messages.insert(key.as_bytes(), req.message)?;
|
||||
|
||||
// Try WebSocket push first (instant delivery)
|
||||
if state.push_to_client(&to, &req.message).await {
|
||||
tracing::info!("Pushed message to {} via WS ({} bytes)", to, req.message.len());
|
||||
} else {
|
||||
// Queue in DB (offline delivery)
|
||||
let key = format!("queue:{}:{}", to, uuid::Uuid::new_v4());
|
||||
tracing::info!("Queuing message for {} ({} bytes)", to, req.message.len());
|
||||
state.db.messages.insert(key.as_bytes(), req.message)?;
|
||||
}
|
||||
|
||||
// Renew sender's alias TTL (sending = authenticated action)
|
||||
if let Some(ref from) = req.from {
|
||||
|
||||
@@ -3,8 +3,9 @@ pub mod auth;
|
||||
mod groups;
|
||||
mod health;
|
||||
mod keys;
|
||||
mod messages;
|
||||
pub mod messages;
|
||||
mod web;
|
||||
mod ws;
|
||||
|
||||
use axum::Router;
|
||||
|
||||
@@ -18,6 +19,7 @@ pub fn router() -> Router<AppState> {
|
||||
.merge(groups::routes())
|
||||
.merge(aliases::routes())
|
||||
.merge(auth::routes())
|
||||
.merge(ws::routes())
|
||||
}
|
||||
|
||||
/// Web UI router (served at root, outside /v1)
|
||||
|
||||
@@ -157,9 +157,10 @@ let mySeedHex = '';
|
||||
let sessions = {}; // peerFP -> { session: WasmSession, data: base64 }
|
||||
let peerBundles = {}; // peerFP -> bundle bytes
|
||||
let pollTimer = null;
|
||||
let ws = null; // WebSocket connection
|
||||
let wasmReady = false;
|
||||
|
||||
const VERSION = '0.0.4';
|
||||
const VERSION = '0.0.5';
|
||||
let DEBUG = true; // toggle with /debug command
|
||||
|
||||
function dbg(...args) {
|
||||
@@ -281,101 +282,122 @@ async function sendEncrypted(peerFP, plaintext) {
|
||||
));
|
||||
|
||||
dbg('Sending wire message, size:', wireBytes.length);
|
||||
await fetch(SERVER + '/v1/messages/send', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
|
||||
// Prefer WebSocket, fall back to HTTP
|
||||
if (ws && ws.readyState === WebSocket.OPEN) {
|
||||
ws.send(JSON.stringify({
|
||||
to: fp,
|
||||
from: normFP(myFingerprint),
|
||||
message: Array.from(wireBytes)
|
||||
})
|
||||
});
|
||||
dbg('Message sent');
|
||||
}));
|
||||
dbg('Sent via WebSocket');
|
||||
} else {
|
||||
await fetch(SERVER + '/v1/messages/send', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
to: fp,
|
||||
from: normFP(myFingerprint),
|
||||
message: Array.from(wireBytes)
|
||||
})
|
||||
});
|
||||
dbg('Sent via HTTP (WS not connected)');
|
||||
}
|
||||
}
|
||||
|
||||
async function pollMessages() {
|
||||
if (!wasmReady) return;
|
||||
function connectWebSocket() {
|
||||
const fp = normFP(myFingerprint);
|
||||
const proto = location.protocol === 'https:' ? 'wss:' : 'ws:';
|
||||
const wsUrl = proto + '//' + location.host + '/v1/ws/' + fp;
|
||||
dbg('Connecting WebSocket:', wsUrl);
|
||||
|
||||
ws = new WebSocket(wsUrl);
|
||||
ws.binaryType = 'arraybuffer';
|
||||
|
||||
ws.onopen = () => {
|
||||
dbg('WebSocket connected');
|
||||
addSys('Real-time connection established');
|
||||
};
|
||||
|
||||
ws.onmessage = async (event) => {
|
||||
const bytes = new Uint8Array(event.data);
|
||||
dbg('WS received', bytes.length, 'bytes');
|
||||
await handleIncomingMessage(bytes);
|
||||
};
|
||||
|
||||
ws.onclose = () => {
|
||||
dbg('WebSocket closed, reconnecting in 3s...');
|
||||
addSys('Connection lost, reconnecting...');
|
||||
setTimeout(connectWebSocket, 3000);
|
||||
};
|
||||
|
||||
ws.onerror = (e) => {
|
||||
dbg('WebSocket error:', e);
|
||||
};
|
||||
}
|
||||
|
||||
async function handleIncomingMessage(bytes) {
|
||||
dbg('Processing message,', bytes.length, 'bytes, sessions:', Object.keys(sessions));
|
||||
|
||||
// First try: KeyExchange (no existing session needed)
|
||||
let decrypted = false;
|
||||
try {
|
||||
const fp = normFP(myFingerprint);
|
||||
const resp = await fetch(SERVER + '/v1/messages/poll/' + fp);
|
||||
if (!resp.ok) return;
|
||||
const msgs = await resp.json();
|
||||
const resultStr = decrypt_wire_message(mySeedHex, mySpkSecretHex, bytes, null);
|
||||
const result = JSON.parse(resultStr);
|
||||
dbg('Decrypted (KeyExchange) from:', result.sender);
|
||||
|
||||
dbg('Poll got', msgs.length, 'messages, sessions:', Object.keys(sessions));
|
||||
const senderFP = normFP(result.sender);
|
||||
sessions[senderFP] = { data: result.session_data };
|
||||
localStorage.setItem('wz-sessions', JSON.stringify(
|
||||
Object.fromEntries(Object.entries(sessions).map(([k,v]) => [k, v.data]))
|
||||
));
|
||||
|
||||
for (let i = 0; i < msgs.length; i++) {
|
||||
const b64 = msgs[i];
|
||||
let fromLabel = result.sender.slice(0, 19);
|
||||
try {
|
||||
const ar = await fetch(SERVER + '/v1/alias/whois/' + senderFP);
|
||||
const ad = await ar.json();
|
||||
if (ad.alias) fromLabel = '@' + ad.alias;
|
||||
} catch(e) {}
|
||||
|
||||
addMsg(fromLabel, result.text, false);
|
||||
decrypted = true;
|
||||
} catch(e) {
|
||||
dbg('KeyExchange failed:', e.message || e);
|
||||
}
|
||||
|
||||
// Second try: existing sessions
|
||||
if (!decrypted) {
|
||||
for (const [senderFP, sessData] of Object.entries(sessions)) {
|
||||
try {
|
||||
const bytes = Uint8Array.from(atob(b64), c => c.charCodeAt(0));
|
||||
dbg('Msg', i, ':', bytes.length, 'bytes, first 4:', Array.from(bytes.slice(0, 4)));
|
||||
const resultStr = decrypt_wire_message(mySeedHex, mySpkSecretHex, bytes, sessData.data);
|
||||
const result = JSON.parse(resultStr);
|
||||
dbg('Decrypted with session', senderFP);
|
||||
|
||||
// First try: KeyExchange (no existing session needed)
|
||||
let decrypted = false;
|
||||
sessions[senderFP] = { data: result.session_data };
|
||||
localStorage.setItem('wz-sessions', JSON.stringify(
|
||||
Object.fromEntries(Object.entries(sessions).map(([k,v]) => [k, v.data]))
|
||||
));
|
||||
|
||||
let fromLabel = result.sender.slice(0, 19);
|
||||
try {
|
||||
dbg('Trying decrypt as KeyExchange (no session)...');
|
||||
const resultStr = decrypt_wire_message(mySeedHex, mySpkSecretHex, bytes, null);
|
||||
const result = JSON.parse(resultStr);
|
||||
dbg('Decrypted!', result.new_session ? 'new session' : 'existing', 'from:', result.sender);
|
||||
const ar = await fetch(SERVER + '/v1/alias/whois/' + normFP(result.sender));
|
||||
const ad = await ar.json();
|
||||
if (ad.alias) fromLabel = '@' + ad.alias;
|
||||
} catch(e2) {}
|
||||
|
||||
const senderFP = normFP(result.sender);
|
||||
sessions[senderFP] = { data: result.session_data };
|
||||
localStorage.setItem('wz-sessions', JSON.stringify(
|
||||
Object.fromEntries(Object.entries(sessions).map(([k,v]) => [k, v.data]))
|
||||
));
|
||||
|
||||
let fromLabel = result.sender.slice(0, 19);
|
||||
try {
|
||||
const ar = await fetch(SERVER + '/v1/alias/whois/' + senderFP);
|
||||
const ad = await ar.json();
|
||||
if (ad.alias) fromLabel = '@' + ad.alias;
|
||||
} catch(e) {}
|
||||
|
||||
addMsg(fromLabel, result.text, false);
|
||||
decrypted = true;
|
||||
} catch(e) {
|
||||
dbg('KeyExchange decrypt failed:', e.message || e);
|
||||
}
|
||||
|
||||
// Second try: existing sessions
|
||||
if (!decrypted) {
|
||||
for (const [senderFP, sessData] of Object.entries(sessions)) {
|
||||
try {
|
||||
dbg('Trying session for', senderFP);
|
||||
const resultStr = decrypt_wire_message(mySeedHex, mySpkSecretHex, bytes, sessData.data);
|
||||
const result = JSON.parse(resultStr);
|
||||
dbg('Decrypted with session', senderFP, ':', result.text.slice(0, 30));
|
||||
|
||||
sessions[senderFP] = { data: result.session_data };
|
||||
localStorage.setItem('wz-sessions', JSON.stringify(
|
||||
Object.fromEntries(Object.entries(sessions).map(([k,v]) => [k, v.data]))
|
||||
));
|
||||
|
||||
let fromLabel = result.sender.slice(0, 19);
|
||||
try {
|
||||
const ar = await fetch(SERVER + '/v1/alias/whois/' + normFP(result.sender));
|
||||
const ad = await ar.json();
|
||||
if (ad.alias) fromLabel = '@' + ad.alias;
|
||||
} catch(e2) {}
|
||||
|
||||
addMsg(fromLabel, result.text, false);
|
||||
decrypted = true;
|
||||
break;
|
||||
} catch(e2) {
|
||||
dbg('Session', senderFP, 'failed:', e2.message || e2);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!decrypted) {
|
||||
dbg('ALL decrypt attempts failed for msg', i);
|
||||
addSys('[message could not be decrypted]');
|
||||
}
|
||||
} catch(e) {
|
||||
dbg('Message processing error:', e);
|
||||
addSys('[failed to process message]');
|
||||
addMsg(fromLabel, result.text, false);
|
||||
decrypted = true;
|
||||
break;
|
||||
} catch(e2) {
|
||||
dbg('Session', senderFP, 'failed:', e2.message || e2);
|
||||
}
|
||||
}
|
||||
} catch(e) { /* server offline */ }
|
||||
}
|
||||
|
||||
if (!decrypted) {
|
||||
dbg('ALL decrypt attempts failed');
|
||||
addSys('[message could not be decrypted]');
|
||||
}
|
||||
}
|
||||
|
||||
// Load saved sessions
|
||||
@@ -467,7 +489,7 @@ async function enterChat() {
|
||||
const savedPeer = localStorage.getItem('wz-peer');
|
||||
if (savedPeer) $peerInput.value = savedPeer;
|
||||
|
||||
pollTimer = setInterval(pollMessages, 2000);
|
||||
connectWebSocket();
|
||||
$input.focus();
|
||||
}
|
||||
|
||||
|
||||
154
warzone/crates/warzone-server/src/routes/ws.rs
Normal file
154
warzone/crates/warzone-server/src/routes/ws.rs
Normal file
@@ -0,0 +1,154 @@
|
||||
//! WebSocket endpoint for real-time message delivery.
|
||||
//!
|
||||
//! Protocol:
|
||||
//! 1. Client connects to /v1/ws/:fingerprint
|
||||
//! 2. Server sends any queued messages (from DB)
|
||||
//! 3. Server pushes new messages in real-time
|
||||
//! 4. Client sends messages as binary WireMessage frames
|
||||
//! 5. Server routes to recipient's WS or queues in DB
|
||||
|
||||
use axum::{
|
||||
extract::{
|
||||
ws::{Message, WebSocket},
|
||||
Path, State, WebSocketUpgrade,
|
||||
},
|
||||
response::IntoResponse,
|
||||
routing::get,
|
||||
Router,
|
||||
};
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
|
||||
use crate::state::AppState;
|
||||
|
||||
pub fn routes() -> Router<AppState> {
|
||||
Router::new().route("/ws/:fingerprint", get(ws_handler))
|
||||
}
|
||||
|
||||
fn normalize_fp(fp: &str) -> String {
|
||||
fp.chars()
|
||||
.filter(|c| c.is_ascii_hexdigit())
|
||||
.collect::<String>()
|
||||
.to_lowercase()
|
||||
}
|
||||
|
||||
async fn ws_handler(
|
||||
ws: WebSocketUpgrade,
|
||||
State(state): State<AppState>,
|
||||
Path(fingerprint): Path<String>,
|
||||
) -> impl IntoResponse {
|
||||
let fp = normalize_fp(&fingerprint);
|
||||
tracing::info!("WS upgrade request from {}", fp);
|
||||
ws.on_upgrade(move |socket| handle_socket(socket, state, fp))
|
||||
}
|
||||
|
||||
async fn handle_socket(socket: WebSocket, state: AppState, fingerprint: String) {
|
||||
let (mut ws_tx, mut ws_rx) = socket.split();
|
||||
|
||||
// Register for push delivery
|
||||
let mut push_rx = state.register_ws(&fingerprint).await;
|
||||
|
||||
// Send any queued messages from DB
|
||||
let prefix = format!("queue:{}", fingerprint);
|
||||
let mut keys_to_delete = Vec::new();
|
||||
for item in state.db.messages.scan_prefix(prefix.as_bytes()) {
|
||||
if let Ok((key, value)) = item {
|
||||
if ws_tx.send(Message::Binary(value.to_vec().into())).await.is_ok() {
|
||||
keys_to_delete.push(key);
|
||||
}
|
||||
}
|
||||
}
|
||||
for key in &keys_to_delete {
|
||||
let _ = state.db.messages.remove(key);
|
||||
}
|
||||
if !keys_to_delete.is_empty() {
|
||||
tracing::info!("WS {}: flushed {} queued messages", fingerprint, keys_to_delete.len());
|
||||
}
|
||||
|
||||
// Spawn task to forward push messages to WS
|
||||
let fp_clone = fingerprint.clone();
|
||||
let mut push_task = tokio::spawn(async move {
|
||||
while let Some(msg) = push_rx.recv().await {
|
||||
if ws_tx.send(Message::Binary(msg.into())).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
ws_tx
|
||||
});
|
||||
|
||||
// Handle incoming messages from client
|
||||
let state_clone = state.clone();
|
||||
let fp_clone2 = fingerprint.clone();
|
||||
let mut recv_task = tokio::spawn(async move {
|
||||
while let Some(Ok(msg)) = ws_rx.next().await {
|
||||
match msg {
|
||||
Message::Binary(data) => {
|
||||
// Parse as a simple { to: "fp", message: bytes } JSON
|
||||
// Or just raw WireMessage bytes with a 32-byte fingerprint prefix
|
||||
// For simplicity: first 32 hex chars = recipient fp, rest = message
|
||||
if data.len() > 64 {
|
||||
let header = String::from_utf8_lossy(&data[..64]).to_string();
|
||||
let to_fp = normalize_fp(&header);
|
||||
let message = &data[64..];
|
||||
|
||||
// Try push to connected client first
|
||||
if !state_clone.push_to_client(&to_fp, message).await {
|
||||
// Queue in DB
|
||||
let key = format!("queue:{}:{}", to_fp, uuid::Uuid::new_v4());
|
||||
let _ = state_clone.db.messages.insert(key.as_bytes(), message);
|
||||
}
|
||||
|
||||
tracing::debug!("WS {}: routed message to {}", fp_clone2, to_fp);
|
||||
}
|
||||
}
|
||||
Message::Text(text) => {
|
||||
// JSON format: {"to": "fp", "message": [bytes]}
|
||||
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&text) {
|
||||
let to = parsed.get("to").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let to_fp = normalize_fp(to);
|
||||
if let Some(msg_arr) = parsed.get("message").and_then(|v| v.as_array()) {
|
||||
let message: Vec<u8> = msg_arr.iter()
|
||||
.filter_map(|v| v.as_u64().map(|n| n as u8))
|
||||
.collect();
|
||||
|
||||
if !state_clone.push_to_client(&to_fp, &message).await {
|
||||
let key = format!("queue:{}:{}", to_fp, uuid::Uuid::new_v4());
|
||||
let _ = state_clone.db.messages.insert(key.as_bytes(), message);
|
||||
}
|
||||
|
||||
// Renew alias TTL
|
||||
crate::routes::messages::renew_alias_ttl(
|
||||
&state_clone.db.aliases, &fp_clone2,
|
||||
);
|
||||
|
||||
tracing::debug!("WS {}: routed JSON message to {}", fp_clone2, to_fp);
|
||||
}
|
||||
}
|
||||
}
|
||||
Message::Close(_) => break,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Wait for either task to finish
|
||||
tokio::select! {
|
||||
_ = &mut push_task => {
|
||||
recv_task.abort();
|
||||
}
|
||||
_ = &mut recv_task => {
|
||||
push_task.abort();
|
||||
}
|
||||
}
|
||||
|
||||
// Unregister
|
||||
// We can't easily get the sender ref here, so just clean up by fingerprint
|
||||
// In production, use a unique connection ID
|
||||
let mut conns = state.connections.lock().await;
|
||||
if let Some(senders) = conns.get_mut(&fingerprint) {
|
||||
senders.retain(|s| !s.is_closed());
|
||||
if senders.is_empty() {
|
||||
conns.remove(&fingerprint);
|
||||
}
|
||||
}
|
||||
tracing::info!("WS {} disconnected", fingerprint);
|
||||
}
|
||||
@@ -1,15 +1,65 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{broadcast, Mutex, mpsc};
|
||||
|
||||
use crate::db::Database;
|
||||
|
||||
/// Per-connection sender: messages are pushed here for instant delivery.
|
||||
pub type WsSender = mpsc::UnboundedSender<Vec<u8>>;
|
||||
|
||||
/// Connected clients: fingerprint → list of WS senders (multiple devices).
|
||||
pub type Connections = Arc<Mutex<HashMap<String, Vec<WsSender>>>>;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
pub db: Arc<Database>,
|
||||
pub connections: Connections,
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
pub fn new(data_dir: &str) -> anyhow::Result<Self> {
|
||||
let db = Database::open(data_dir)?;
|
||||
Ok(AppState { db: Arc::new(db) })
|
||||
Ok(AppState {
|
||||
db: Arc::new(db),
|
||||
connections: Arc::new(Mutex::new(HashMap::new())),
|
||||
})
|
||||
}
|
||||
|
||||
/// Try to push a message to a connected client. Returns true if delivered.
|
||||
pub async fn push_to_client(&self, fingerprint: &str, message: &[u8]) -> bool {
|
||||
let conns = self.connections.lock().await;
|
||||
if let Some(senders) = conns.get(fingerprint) {
|
||||
let mut delivered = false;
|
||||
for sender in senders {
|
||||
if sender.send(message.to_vec()).is_ok() {
|
||||
delivered = true;
|
||||
}
|
||||
}
|
||||
delivered
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a WS connection for a fingerprint.
|
||||
pub async fn register_ws(&self, fingerprint: &str) -> mpsc::UnboundedReceiver<Vec<u8>> {
|
||||
let (tx, rx) = mpsc::unbounded_channel();
|
||||
let mut conns = self.connections.lock().await;
|
||||
conns.entry(fingerprint.to_string()).or_default().push(tx);
|
||||
tracing::info!("WS registered for {} ({} total connections)", fingerprint,
|
||||
conns.values().map(|v| v.len()).sum::<usize>());
|
||||
rx
|
||||
}
|
||||
|
||||
/// Unregister a WS connection.
|
||||
pub async fn unregister_ws(&self, fingerprint: &str, sender: &WsSender) {
|
||||
let mut conns = self.connections.lock().await;
|
||||
if let Some(senders) = conns.get_mut(fingerprint) {
|
||||
senders.retain(|s| !s.same_channel(sender));
|
||||
if senders.is_empty() {
|
||||
conns.remove(fingerprint);
|
||||
}
|
||||
}
|
||||
tracing::info!("WS unregistered for {}", fingerprint);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user