//! WebSocket endpoint for real-time message delivery. //! //! Protocol: //! 1. Client connects to /v1/ws/:fingerprint //! 2. Server sends any queued messages (from DB) //! 3. Server pushes new messages in real-time //! 4. Client sends messages as binary WireMessage frames //! 5. Server routes to recipient's WS or queues in DB use axum::{ extract::{ ws::{Message, WebSocket}, Path, State, WebSocketUpgrade, }, response::IntoResponse, routing::get, Router, }; use futures_util::{SinkExt, StreamExt}; use crate::state::AppState; pub fn routes() -> Router { Router::new().route("/ws/:fingerprint", get(ws_handler)) } fn normalize_fp(fp: &str) -> String { fp.chars() .filter(|c| c.is_ascii_hexdigit()) .collect::() .to_lowercase() } async fn ws_handler( ws: WebSocketUpgrade, State(state): State, Path(fingerprint): Path, ) -> impl IntoResponse { let fp = normalize_fp(&fingerprint); tracing::info!("WS upgrade request from {}", fp); ws.on_upgrade(move |socket| handle_socket(socket, state, fp)) } async fn handle_socket(socket: WebSocket, state: AppState, fingerprint: String) { let (mut ws_tx, mut ws_rx) = socket.split(); // Register for push delivery let mut push_rx = state.register_ws(&fingerprint).await; // Send any queued messages from DB let prefix = format!("queue:{}", fingerprint); let mut keys_to_delete = Vec::new(); for item in state.db.messages.scan_prefix(prefix.as_bytes()) { if let Ok((key, value)) = item { if ws_tx.send(Message::Binary(value.to_vec().into())).await.is_ok() { keys_to_delete.push(key); } } } for key in &keys_to_delete { let _ = state.db.messages.remove(key); } if !keys_to_delete.is_empty() { tracing::info!("WS {}: flushed {} queued messages", fingerprint, keys_to_delete.len()); } // Spawn task to forward push messages to WS let fp_clone = fingerprint.clone(); let mut push_task = tokio::spawn(async move { while let Some(msg) = push_rx.recv().await { if ws_tx.send(Message::Binary(msg.into())).await.is_err() { break; } } ws_tx }); // Handle incoming messages from client let state_clone = state.clone(); let fp_clone2 = fingerprint.clone(); let mut recv_task = tokio::spawn(async move { while let Some(Ok(msg)) = ws_rx.next().await { match msg { Message::Binary(data) => { // Parse as a simple { to: "fp", message: bytes } JSON // Or just raw WireMessage bytes with a 32-byte fingerprint prefix // For simplicity: first 32 hex chars = recipient fp, rest = message if data.len() > 64 { let header = String::from_utf8_lossy(&data[..64]).to_string(); let to_fp = normalize_fp(&header); let message = &data[64..]; // Try push to connected client first if !state_clone.push_to_client(&to_fp, message).await { // Queue in DB let key = format!("queue:{}:{}", to_fp, uuid::Uuid::new_v4()); let _ = state_clone.db.messages.insert(key.as_bytes(), message); } tracing::debug!("WS {}: routed message to {}", fp_clone2, to_fp); } } Message::Text(text) => { // JSON format: {"to": "fp", "message": [bytes]} if let Ok(parsed) = serde_json::from_str::(&text) { let to = parsed.get("to").and_then(|v| v.as_str()).unwrap_or(""); let to_fp = normalize_fp(to); if let Some(msg_arr) = parsed.get("message").and_then(|v| v.as_array()) { let message: Vec = msg_arr.iter() .filter_map(|v| v.as_u64().map(|n| n as u8)) .collect(); if !state_clone.push_to_client(&to_fp, &message).await { let key = format!("queue:{}:{}", to_fp, uuid::Uuid::new_v4()); let _ = state_clone.db.messages.insert(key.as_bytes(), message); } // Renew alias TTL crate::routes::messages::renew_alias_ttl( &state_clone.db.aliases, &fp_clone2, ); tracing::debug!("WS {}: routed JSON message to {}", fp_clone2, to_fp); } } } Message::Close(_) => break, _ => {} } } }); // Wait for either task to finish tokio::select! { _ = &mut push_task => { recv_task.abort(); } _ = &mut recv_task => { push_task.abort(); } } // Unregister // We can't easily get the sender ref here, so just clean up by fingerprint // In production, use a unique connection ID let mut conns = state.connections.lock().await; if let Some(senders) = conns.get_mut(&fingerprint) { senders.retain(|s| !s.is_closed()); if senders.is_empty() { conns.remove(&fingerprint); } } tracing::info!("WS {} disconnected", fingerprint); }