Protocol (sender_keys.rs):
- SenderKey: symmetric key with chain ratchet (forward secrecy per chain)
- generate(), rotate(), encrypt(), decrypt()
- SenderKeyDistribution: share key via 1:1 encrypted channel
- SenderKeyMessage: encrypted group message (O(1) instead of O(N))
- Chain key ratchets forward on each message (HKDF)
- Generation counter for key rotation tracking
- 4 tests: basic, multi-message, rotation, old-key rejection

WireMessage:
- GroupSenderKey variant: encrypted group message
- SenderKeyDistribution variant: key sharing

Server: dedup handles new variants.
CLI TUI + recv: stub handlers for new message types.
23/23 protocol tests pass.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
191 lines
7.3 KiB
Rust
191 lines
7.3 KiB
Rust
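//! WebSocket endpoint for real-time message delivery.
//!
//! Protocol:
//! 1. Client connects to /v1/ws/:fingerprint
//! 2. Server sends any messages queued for that fingerprint (from the DB)
//! 3. Server pushes new messages in real time
//! 4. Client sends messages either as binary frames (a 64-byte recipient
//!    fingerprint header followed by bincode-serialized WireMessage bytes)
//!    or as JSON text frames of the form {"to": "<fp>", "message": [bytes]}
//! 5. Server routes each message to the recipient's WS, or queues it in the DB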
|
|
|
|
use axum::{
|
|
extract::{
|
|
ws::{Message, WebSocket},
|
|
Path, State, WebSocketUpgrade,
|
|
},
|
|
response::IntoResponse,
|
|
routing::get,
|
|
Router,
|
|
};
|
|
use futures_util::{SinkExt, StreamExt};
|
|
use warzone_protocol::message::WireMessage;
|
|
|
|
use crate::state::AppState;
|
|
|
|
/// Try to extract the message ID from raw bincode-serialized WireMessage bytes.
|
|
fn extract_message_id(data: &[u8]) -> Option<String> {
|
|
if let Ok(wire) = bincode::deserialize::<WireMessage>(data) {
|
|
match wire {
|
|
WireMessage::KeyExchange { id, .. } => Some(id),
|
|
WireMessage::Message { id, .. } => Some(id),
|
|
WireMessage::FileHeader { id, .. } => Some(id),
|
|
WireMessage::FileChunk { id, .. } => Some(id),
|
|
WireMessage::Receipt { message_id, .. } => Some(message_id),
|
|
WireMessage::GroupSenderKey { id, .. } => Some(id),
|
|
WireMessage::SenderKeyDistribution { sender_fingerprint, group_name, .. } => {
|
|
Some(format!("skd:{}:{}", sender_fingerprint, group_name))
|
|
}
|
|
}
|
|
} else {
|
|
None
|
|
}
|
|
}
|
|
|
|
pub fn routes() -> Router<AppState> {
|
|
Router::new().route("/ws/:fingerprint", get(ws_handler))
|
|
}
|
|
|
|
/// Canonicalizes a fingerprint string.
///
/// Keeps only ASCII hex digits (separators such as `:`, spaces or dashes are
/// discarded, as is any other character) and lowercases them.
fn normalize_fp(fp: &str) -> String {
    let mut normalized = String::with_capacity(fp.len());
    for ch in fp.chars() {
        if ch.is_ascii_hexdigit() {
            normalized.push(ch.to_ascii_lowercase());
        }
    }
    normalized
}
|
|
|
|
async fn ws_handler(
|
|
ws: WebSocketUpgrade,
|
|
State(state): State<AppState>,
|
|
Path(fingerprint): Path<String>,
|
|
) -> impl IntoResponse {
|
|
let fp = normalize_fp(&fingerprint);
|
|
tracing::info!("WS upgrade request from {}", fp);
|
|
ws.on_upgrade(move |socket| handle_socket(socket, state, fp))
|
|
}
|
|
|
|
async fn handle_socket(socket: WebSocket, state: AppState, fingerprint: String) {
|
|
let (mut ws_tx, mut ws_rx) = socket.split();
|
|
|
|
// Register for push delivery
|
|
let mut push_rx = state.register_ws(&fingerprint).await;
|
|
|
|
// Send any queued messages from DB
|
|
let prefix = format!("queue:{}", fingerprint);
|
|
let mut keys_to_delete = Vec::new();
|
|
for item in state.db.messages.scan_prefix(prefix.as_bytes()) {
|
|
if let Ok((key, value)) = item {
|
|
if ws_tx.send(Message::Binary(value.to_vec().into())).await.is_ok() {
|
|
keys_to_delete.push(key);
|
|
}
|
|
}
|
|
}
|
|
for key in &keys_to_delete {
|
|
let _ = state.db.messages.remove(key);
|
|
}
|
|
if !keys_to_delete.is_empty() {
|
|
tracing::info!("WS {}: flushed {} queued messages", fingerprint, keys_to_delete.len());
|
|
}
|
|
|
|
// Spawn task to forward push messages to WS
|
|
let _fp_clone = fingerprint.clone();
|
|
let mut push_task = tokio::spawn(async move {
|
|
while let Some(msg) = push_rx.recv().await {
|
|
if ws_tx.send(Message::Binary(msg.into())).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
ws_tx
|
|
});
|
|
|
|
// Handle incoming messages from client
|
|
let state_clone = state.clone();
|
|
let fp_clone2 = fingerprint.clone();
|
|
let mut recv_task = tokio::spawn(async move {
|
|
while let Some(Ok(msg)) = ws_rx.next().await {
|
|
match msg {
|
|
Message::Binary(data) => {
|
|
// Parse as a simple { to: "fp", message: bytes } JSON
|
|
// Or just raw WireMessage bytes with a 32-byte fingerprint prefix
|
|
// For simplicity: first 32 hex chars = recipient fp, rest = message
|
|
if data.len() > 64 {
|
|
let header = String::from_utf8_lossy(&data[..64]).to_string();
|
|
let to_fp = normalize_fp(&header);
|
|
let message = &data[64..];
|
|
|
|
// Dedup: skip if we already processed this message ID
|
|
if let Some(msg_id) = extract_message_id(message) {
|
|
if state_clone.dedup.check_and_insert(&msg_id) {
|
|
tracing::debug!("WS dedup: dropping duplicate binary message {}", msg_id);
|
|
continue;
|
|
}
|
|
}
|
|
|
|
// Try push to connected client first
|
|
if !state_clone.push_to_client(&to_fp, message).await {
|
|
// Queue in DB
|
|
let key = format!("queue:{}:{}", to_fp, uuid::Uuid::new_v4());
|
|
let _ = state_clone.db.messages.insert(key.as_bytes(), message);
|
|
}
|
|
|
|
tracing::debug!("WS {}: routed message to {}", fp_clone2, to_fp);
|
|
}
|
|
}
|
|
Message::Text(text) => {
|
|
// JSON format: {"to": "fp", "message": [bytes]}
|
|
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&text) {
|
|
let to = parsed.get("to").and_then(|v| v.as_str()).unwrap_or("");
|
|
let to_fp = normalize_fp(to);
|
|
if let Some(msg_arr) = parsed.get("message").and_then(|v| v.as_array()) {
|
|
let message: Vec<u8> = msg_arr.iter()
|
|
.filter_map(|v| v.as_u64().map(|n| n as u8))
|
|
.collect();
|
|
|
|
// Dedup: skip if we already processed this message ID
|
|
if let Some(msg_id) = extract_message_id(&message) {
|
|
if state_clone.dedup.check_and_insert(&msg_id) {
|
|
tracing::debug!("WS dedup: dropping duplicate JSON message {}", msg_id);
|
|
continue;
|
|
}
|
|
}
|
|
|
|
if !state_clone.push_to_client(&to_fp, &message).await {
|
|
let key = format!("queue:{}:{}", to_fp, uuid::Uuid::new_v4());
|
|
let _ = state_clone.db.messages.insert(key.as_bytes(), message);
|
|
}
|
|
|
|
// Renew alias TTL
|
|
crate::routes::messages::renew_alias_ttl(
|
|
&state_clone.db.aliases, &fp_clone2,
|
|
);
|
|
|
|
tracing::debug!("WS {}: routed JSON message to {}", fp_clone2, to_fp);
|
|
}
|
|
}
|
|
}
|
|
Message::Close(_) => break,
|
|
_ => {}
|
|
}
|
|
}
|
|
});
|
|
|
|
// Wait for either task to finish
|
|
tokio::select! {
|
|
_ = &mut push_task => {
|
|
recv_task.abort();
|
|
}
|
|
_ = &mut recv_task => {
|
|
push_task.abort();
|
|
}
|
|
}
|
|
|
|
// Unregister
|
|
// We can't easily get the sender ref here, so just clean up by fingerprint
|
|
// In production, use a unique connection ID
|
|
let mut conns = state.connections.lock().await;
|
|
if let Some(senders) = conns.get_mut(&fingerprint) {
|
|
senders.retain(|s| !s.is_closed());
|
|
if senders.is_empty() {
|
|
conns.remove(&fingerprint);
|
|
}
|
|
}
|
|
tracing::info!("WS {} disconnected", fingerprint);
|
|
}
|