Files
featherChat/warzone/crates/warzone-server/src/routes/ws.rs
Siavash Sameni 86da52acc4 v0.0.13: Sender Keys for efficient group encryption
Protocol (sender_keys.rs):
- SenderKey: symmetric key with chain ratchet (forward secrecy per chain)
- generate(), rotate(), encrypt(), decrypt()
- SenderKeyDistribution: share key via 1:1 encrypted channel
- SenderKeyMessage: encrypted group message (O(1) instead of O(N))
- Chain key ratchets forward on each message (HKDF)
- Generation counter for key rotation tracking
- 4 tests: basic, multi-message, rotation, old-key rejection

WireMessage:
- GroupSenderKey variant: encrypted group message
- SenderKeyDistribution variant: key sharing

Server: dedup handles new variants.
CLI TUI + recv: stub handlers for new message types.
23/23 protocol tests pass.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-27 13:23:10 +04:00

191 lines
7.3 KiB
Rust

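//! WebSocket endpoint for real-time message delivery.
//!
//! Protocol:
//! 1. Client connects to /v1/ws/:fingerprint
//! 2. Server sends any messages queued for that fingerprint in the DB
//! 3. Server pushes new messages in real time
//! 4. Client sends messages either as binary frames (64 hex characters naming the
//!    recipient fingerprint, followed by the bincode-serialized WireMessage) or as
//!    JSON text frames of the form {"to": "<fp>", "message": [<byte>, ...]}
//! 5. Server pushes each message to the recipient's live WS connection, or queues
//!    it in the DB when it cannot be pushed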
use axum::{
extract::{
ws::{Message, WebSocket},
Path, State, WebSocketUpgrade,
},
response::IntoResponse,
routing::get,
Router,
};
use futures_util::{SinkExt, StreamExt};
use warzone_protocol::message::WireMessage;
use crate::state::AppState;
/// Derives the deduplication id of a bincode-serialized [`WireMessage`].
///
/// Variants that carry an `id` use it directly, a `Receipt` uses the
/// `message_id` it acknowledges, and a `SenderKeyDistribution` gets the
/// synthetic id `skd:<sender_fingerprint>:<group_name>`.
///
/// Returns `None` when `data` does not deserialize as a `WireMessage`.
///
/// NOTE(review): the `skd:` id does not change when a sender redistributes a
/// rotated key for the same group, so such a redistribution may be treated as a
/// duplicate depending on how long the dedup store remembers ids — confirm.
fn extract_message_id(data: &[u8]) -> Option<String> {
    let wire: WireMessage = bincode::deserialize(data).ok()?;
    let id = match wire {
        WireMessage::KeyExchange { id, .. }
        | WireMessage::Message { id, .. }
        | WireMessage::FileHeader { id, .. }
        | WireMessage::FileChunk { id, .. }
        | WireMessage::GroupSenderKey { id, .. } => id,
        WireMessage::Receipt { message_id, .. } => message_id,
        WireMessage::SenderKeyDistribution {
            sender_fingerprint,
            group_name,
            ..
        } => format!("skd:{}:{}", sender_fingerprint, group_name),
    };
    Some(id)
}
/// Builds the router exposing the WebSocket upgrade endpoint `/ws/:fingerprint`.
pub fn routes() -> Router<AppState> {
    let upgrade = get(ws_handler);
    Router::new().route("/ws/:fingerprint", upgrade)
}
/// Canonicalizes a fingerprint string.
///
/// Every character that is not an ASCII hex digit (separators, whitespace,
/// anything else) is dropped, and the remaining digits are lowercased.
/// Input with no hex digits yields an empty string.
fn normalize_fp(fp: &str) -> String {
    let mut out = String::with_capacity(fp.len());
    for ch in fp.chars() {
        if ch.is_ascii_hexdigit() {
            out.push(ch.to_ascii_lowercase());
        }
    }
    out
}
/// Handles `GET /ws/:fingerprint`: normalizes the path fingerprint, logs the
/// request, and upgrades the connection to a WebSocket served by
/// `handle_socket`.
async fn ws_handler(
    ws: WebSocketUpgrade,
    State(state): State<AppState>,
    Path(fingerprint): Path<String>,
) -> impl IntoResponse {
    let normalized = normalize_fp(&fingerprint);
    tracing::info!("WS upgrade request from {}", normalized);
    ws.on_upgrade(move |socket| handle_socket(socket, state, normalized))
}
/// Runs one client WebSocket session for `fingerprint` until either side stops.
///
/// 1. Registers the connection for live push delivery.
/// 2. Sends every message queued in the DB under `queue:<fingerprint>:` and
///    deletes the ones that were written to the socket.
/// 3. Spawns one task that forwards pushed messages to the socket and another
///    that routes frames sent by the client (see [`split_binary_frame`] and
///    [`parse_json_frame`]) to their recipients.
/// 4. When either task ends, stops the other, waits for it to be torn down, and
///    prunes closed push senders registered for `fingerprint`.
async fn handle_socket(socket: WebSocket, state: AppState, fingerprint: String) {
    // `normalize_fp` yields "" when the path held no hex digits; such a
    // connection names no mailbox, so close it instead of registering under "".
    if fingerprint.is_empty() {
        tracing::warn!("WS: rejecting connection with empty fingerprint");
        return;
    }

    let (mut ws_tx, mut ws_rx) = socket.split();

    // Register for push delivery before flushing the DB queue.
    let mut push_rx = state.register_ws(&fingerprint).await;

    // Flush queued messages. Keys are `queue:<fp>:<uuid>`; the trailing ':' in
    // the prefix keeps a fingerprint that is a prefix of another one (e.g. "ab"
    // vs "abcd") from draining the other recipient's queue.
    let prefix = format!("queue:{}:", fingerprint);
    let mut keys_to_delete = Vec::new();
    for item in state.db.messages.scan_prefix(prefix.as_bytes()) {
        let (key, value) = match item {
            Ok(entry) => entry,
            Err(_) => continue,
        };
        if ws_tx.send(Message::Binary(value.to_vec().into())).await.is_err() {
            // The socket is gone: leave the remaining messages queued for the
            // next connection rather than attempting further sends.
            break;
        }
        keys_to_delete.push(key);
    }
    for key in &keys_to_delete {
        // Best effort: a failed delete only means the message is re-sent on the
        // next connection.
        let _ = state.db.messages.remove(key);
    }
    if !keys_to_delete.is_empty() {
        tracing::info!("WS {}: flushed {} queued messages", fingerprint, keys_to_delete.len());
    }

    // Forward push messages to the socket until the channel or the socket closes.
    let mut push_task = tokio::spawn(async move {
        while let Some(msg) = push_rx.recv().await {
            if ws_tx.send(Message::Binary(msg.into())).await.is_err() {
                break;
            }
        }
    });

    // Route frames sent by the client.
    let recv_state = state.clone();
    let from_fp = fingerprint.clone();
    let mut recv_task = tokio::spawn(async move {
        while let Some(Ok(msg)) = ws_rx.next().await {
            match msg {
                Message::Binary(data) => match split_binary_frame(&data) {
                    Some((to_fp, message)) => {
                        // NOTE(review): unlike the JSON path, binary frames do not
                        // renew the sender's alias TTL — confirm this is intended.
                        if route_frame(&recv_state, &to_fp, message, "binary").await {
                            tracing::debug!("WS {}: routed message to {}", from_fp, to_fp);
                        }
                    }
                    None => {
                        tracing::debug!("WS {}: ignoring malformed binary frame", from_fp);
                    }
                },
                Message::Text(text) => match parse_json_frame(&text) {
                    Some((to_fp, message)) => {
                        if route_frame(&recv_state, &to_fp, &message, "JSON").await {
                            // Renew alias TTL
                            crate::routes::messages::renew_alias_ttl(
                                &recv_state.db.aliases,
                                &from_fp,
                            );
                            tracing::debug!("WS {}: routed JSON message to {}", from_fp, to_fp);
                        }
                    }
                    None => {
                        tracing::debug!("WS {}: ignoring malformed JSON frame", from_fp);
                    }
                },
                Message::Close(_) => break,
                _ => {}
            }
        }
    });

    // When one task finishes, stop the other and wait for it to actually end.
    // `abort()` only requests cancellation; awaiting the handle guarantees the
    // task (and, for the push task, the `push_rx` it owns) has been dropped
    // before the `is_closed()` pruning below, so this connection's sender is not
    // left behind as a live-looking entry.
    tokio::select! {
        _ = &mut push_task => {
            recv_task.abort();
            let _ = recv_task.await;
        }
        _ = &mut recv_task => {
            push_task.abort();
            let _ = push_task.await;
        }
    }

    // Unregister: senders are not individually identifiable here, so prune every
    // closed sender for this fingerprint. A per-connection id would allow
    // removing exactly this connection's entry.
    let mut conns = state.connections.lock().await;
    if let Some(senders) = conns.get_mut(&fingerprint) {
        senders.retain(|s| !s.is_closed());
        if senders.is_empty() {
            conns.remove(&fingerprint);
        }
    }
    tracing::info!("WS {} disconnected", fingerprint);
}

/// Splits a binary client frame into `(recipient_fp, wire_message_bytes)`.
///
/// Frame layout: the first 64 bytes are the recipient fingerprint as hex text,
/// and the rest is the bincode-serialized `WireMessage`. Returns `None` when
/// nothing follows the header or the header contains no hex digits.
fn split_binary_frame(data: &[u8]) -> Option<(String, &[u8])> {
    const HEADER_LEN: usize = 64;
    if data.len() <= HEADER_LEN {
        return None;
    }
    let (header, payload) = data.split_at(HEADER_LEN);
    let to_fp = normalize_fp(&String::from_utf8_lossy(header));
    if to_fp.is_empty() {
        return None;
    }
    Some((to_fp, payload))
}

/// Parses a JSON text frame `{"to": "<fp>", "message": [<byte>, ...]}`.
///
/// Returns `None` when the JSON is malformed, `message` is missing or is not an
/// array made entirely of integers in `0..=255` (so no byte is silently
/// truncated or dropped), or `to` normalizes to an empty fingerprint.
fn parse_json_frame(text: &str) -> Option<(String, Vec<u8>)> {
    let parsed: serde_json::Value = serde_json::from_str(text).ok()?;
    let to = parsed.get("to").and_then(|v| v.as_str()).unwrap_or("");
    let to_fp = normalize_fp(to);
    if to_fp.is_empty() {
        return None;
    }
    let message = parsed
        .get("message")?
        .as_array()?
        .iter()
        .map(|v| v.as_u64().and_then(|n| u8::try_from(n).ok()))
        .collect::<Option<Vec<u8>>>()?;
    Some((to_fp, message))
}

/// Deduplicates `message` and delivers it to `to_fp`: it is pushed to a live
/// connection when `push_to_client` succeeds, and otherwise queued in the DB
/// under `queue:<to_fp>:<uuid>`.
///
/// `kind` only labels log lines. Returns `false` if the message was dropped as
/// a duplicate, `true` otherwise.
async fn route_frame(state: &AppState, to_fp: &str, message: &[u8], kind: &str) -> bool {
    if let Some(msg_id) = extract_message_id(message) {
        // Scope the dedup key by recipient. The same WireMessage id can
        // legitimately go to several recipients (e.g. a sender-key distribution,
        // whose id is `skd:<sender>:<group>` whoever receives it); only a repeat
        // to the *same* recipient is a duplicate.
        let dedup_key = format!("{}:{}", to_fp, msg_id);
        if state.dedup.check_and_insert(&dedup_key) {
            tracing::debug!("WS dedup: dropping duplicate {} message {}", kind, msg_id);
            return false;
        }
    }
    if !state.push_to_client(to_fp, message).await {
        let key = format!("queue:{}:{}", to_fp, uuid::Uuid::new_v4());
        if let Err(e) = state.db.messages.insert(key.as_bytes(), message) {
            tracing::warn!("WS: failed to queue message for {}: {}", to_fp, e);
        }
    }
    true
}