//! Federation: two-server message relay via persistent WebSocket. //! //! Each server maintains a WS connection to its peer. Presence updates //! and message forwards flow over this single connection. Reconnects //! automatically on failure. use std::collections::HashSet; use std::sync::Arc; use tokio::sync::Mutex; /// Federation configuration loaded from JSON. #[derive(Clone, Debug, serde::Deserialize)] pub struct FederationConfig { pub server_id: String, pub shared_secret: String, pub peer: PeerConfig, } #[derive(Clone, Debug, serde::Deserialize)] pub struct PeerConfig { pub id: String, pub url: String, } /// Load federation config from a JSON file. pub fn load_config(path: &str) -> anyhow::Result { let data = std::fs::read_to_string(path) .map_err(|e| anyhow::anyhow!("failed to read federation config '{}': {}", path, e))?; let config: FederationConfig = serde_json::from_str(&data) .map_err(|e| anyhow::anyhow!("invalid federation config: {}", e))?; Ok(config) } /// Remote presence: which fingerprints are on the peer server. #[derive(Clone, Debug)] pub struct RemotePresence { pub peer_id: String, pub fingerprints: HashSet, pub last_updated: i64, pub connected: bool, } impl RemotePresence { pub fn new(peer_id: String) -> Self { RemotePresence { peer_id, fingerprints: HashSet::new(), last_updated: 0, connected: false, } } pub fn contains(&self, fp: &str) -> bool { self.connected && self.fingerprints.contains(fp) } } /// Sender for outgoing federation messages over the WS. pub type FederationSender = Arc>>>; /// Handle for communicating with the federation peer. #[derive(Clone)] pub struct FederationHandle { pub config: FederationConfig, pub remote_presence: Arc>, /// Channel to send messages over the outgoing WS to the peer. pub outgoing: FederationSender, /// HTTP client for one-shot requests (key fetch, etc.) pub client: reqwest::Client, } impl FederationHandle { pub fn new(config: FederationConfig) -> Self { let remote_presence = Arc::new(Mutex::new(RemotePresence::new( config.peer.id.clone(), ))); let client = reqwest::Client::builder() .timeout(std::time::Duration::from_secs(5)) .build() .expect("failed to build HTTP client"); FederationHandle { config, remote_presence, outgoing: Arc::new(Mutex::new(None)), client, } } /// Check if a fingerprint is known to be on the peer server. pub async fn is_remote(&self, fp: &str) -> bool { let rp = self.remote_presence.lock().await; rp.contains(fp) } /// Forward a message to the peer server via the persistent WS. pub async fn forward_message(&self, to_fp: &str, message: &[u8]) -> bool { let msg = serde_json::json!({ "type": "forward", "to": to_fp, "message": base64::Engine::encode(&base64::engine::general_purpose::STANDARD, message), "from_server": self.config.server_id, }); self.send_json(msg).await } /// Fetch a pre-key bundle from the peer server (HTTP GET fallback). /// Used when a local key lookup fails and the fingerprint is on the remote. pub async fn fetch_remote_bundle(&self, fingerprint: &str) -> Option> { let url = format!("{}/v1/keys/{}", self.config.peer.url, fingerprint); let resp = self.client.get(&url).send().await.ok()?; if !resp.status().is_success() { return None; } let data: serde_json::Value = resp.json().await.ok()?; let bundle_b64 = data.get("bundle")?.as_str()?; base64::Engine::decode(&base64::engine::general_purpose::STANDARD, bundle_b64).ok() } /// Resolve an alias on the peer server. /// Returns Some(fingerprint) if the peer knows this alias. pub async fn resolve_remote_alias(&self, alias: &str) -> Option { let url = format!("{}/v1/alias/resolve/{}", self.config.peer.url, alias); let resp = self.client.get(&url).send().await.ok()?; if !resp.status().is_success() { return None; } let data: serde_json::Value = resp.json().await.ok()?; // Check for error (alias not found on peer) if data.get("error").is_some() { return None; } data.get("fingerprint").and_then(|v| v.as_str()).map(String::from) } /// Check if an alias is already taken on the peer server. /// Returns true if the alias exists on the peer (taken). pub async fn is_alias_taken_remote(&self, alias: &str) -> bool { self.resolve_remote_alias(alias).await.is_some() } /// Push local presence to peer via the persistent WS. pub async fn push_presence(&self, fingerprints: Vec) -> bool { let msg = serde_json::json!({ "type": "presence", "server_id": self.config.server_id, "fingerprints": fingerprints, }); self.send_json(msg).await } /// Send a JSON message over the outgoing WS channel. async fn send_json(&self, msg: serde_json::Value) -> bool { let guard = self.outgoing.lock().await; if let Some(ref tx) = *guard { let json_str = serde_json::to_string(&msg).unwrap_or_default(); tx.send(json_str).is_ok() } else { false } } } /// Background task: connect to peer's WS endpoint, send auth, then loop. /// Handles reconnection on failure. pub async fn outgoing_ws_loop( handle: FederationHandle, state: crate::state::AppState, ) { let ws_url = handle.config.peer.url .replace("http://", "ws://") .replace("https://", "wss://"); let ws_url = format!("{}/v1/federation/ws", ws_url); loop { tracing::info!("Federation: connecting to peer {} at {}", handle.config.peer.id, ws_url); match tokio_tungstenite::connect_async(&ws_url).await { Ok((ws_stream, _)) => { tracing::info!("Federation: connected to peer {}", handle.config.peer.id); use futures_util::{SinkExt, StreamExt}; let (mut ws_tx, mut ws_rx) = ws_stream.split(); // Send auth as first message let auth_msg = serde_json::json!({ "type": "auth", "secret": handle.config.shared_secret, "server_id": handle.config.server_id, }); if ws_tx.send(tokio_tungstenite::tungstenite::Message::Text( serde_json::to_string(&auth_msg).unwrap_or_default() )).await.is_err() { tracing::warn!("Federation: failed to send auth to peer"); tokio::time::sleep(std::time::Duration::from_secs(3)).await; continue; } // Set up outgoing channel let (out_tx, mut out_rx) = tokio::sync::mpsc::unbounded_channel::(); { let mut guard = handle.outgoing.lock().await; *guard = Some(out_tx); } { let mut rp = handle.remote_presence.lock().await; rp.connected = true; } // Send initial presence let fps: Vec = { let conns = state.connections.lock().await; conns.keys().cloned().collect() }; let _ = handle.push_presence(fps).await; // Spawn task to forward outgoing channel + periodic ping to WS let send_task = tokio::spawn(async move { let mut ping_interval = tokio::time::interval(std::time::Duration::from_secs(15)); loop { tokio::select! { msg = out_rx.recv() => { match msg { Some(text) => { if ws_tx.send(tokio_tungstenite::tungstenite::Message::Text(text)).await.is_err() { break; } } None => break, } } _ = ping_interval.tick() => { if ws_tx.send(tokio_tungstenite::tungstenite::Message::Ping(vec![])).await.is_err() { break; } } } } }); // Spawn task to periodically re-push presence let presence_handle = handle.clone(); let presence_conns = state.connections.clone(); let presence_task = tokio::spawn(async move { let mut interval = tokio::time::interval(std::time::Duration::from_secs(10)); loop { interval.tick().await; let fps: Vec = { let conns = presence_conns.lock().await; conns.keys().cloned().collect() }; if !presence_handle.push_presence(fps).await { break; } } }); // Read incoming messages from peer while let Some(Ok(msg)) = ws_rx.next().await { match msg { tokio_tungstenite::tungstenite::Message::Text(text) => { handle_incoming_federation_msg(&text, &handle, &state).await; } tokio_tungstenite::tungstenite::Message::Pong(_) => {} // keepalive response tokio_tungstenite::tungstenite::Message::Close(_) => break, _ => {} } } // Connection lost send_task.abort(); presence_task.abort(); { let mut guard = handle.outgoing.lock().await; *guard = None; } { let mut rp = handle.remote_presence.lock().await; rp.connected = false; rp.fingerprints.clear(); } tracing::warn!("Federation: lost connection to peer {}, reconnecting...", handle.config.peer.id); } Err(e) => { tracing::warn!("Federation: failed to connect to peer {}: {}", handle.config.peer.id, e); } } tokio::time::sleep(std::time::Duration::from_secs(3)).await; } } /// Process a single incoming JSON message from the federated peer WS. async fn handle_incoming_federation_msg( text: &str, handle: &FederationHandle, state: &crate::state::AppState, ) { let parsed: serde_json::Value = match serde_json::from_str(text) { Ok(v) => v, Err(_) => return, }; let msg_type = parsed.get("type").and_then(|v| v.as_str()).unwrap_or(""); match msg_type { "presence" => { let fingerprints: Vec = parsed.get("fingerprints") .and_then(|v| v.as_array()) .map(|arr| arr.iter().filter_map(|v| v.as_str().map(String::from)).collect()) .unwrap_or_default(); let server_id = parsed.get("server_id").and_then(|v| v.as_str()).unwrap_or("?"); let count = fingerprints.len(); let mut rp = handle.remote_presence.lock().await; rp.fingerprints = fingerprints.into_iter().collect(); rp.last_updated = chrono::Utc::now().timestamp(); tracing::debug!("Federation: received {} fingerprints from {}", count, server_id); } "forward" => { let to = parsed.get("to").and_then(|v| v.as_str()).unwrap_or(""); let message_b64 = parsed.get("message").and_then(|v| v.as_str()).unwrap_or(""); let from_server = parsed.get("from_server").and_then(|v| v.as_str()).unwrap_or("?"); if let Ok(message) = base64::Engine::decode(&base64::engine::general_purpose::STANDARD, message_b64) { let delivered = state.push_to_client(to, &message).await; if !delivered { let key = format!("queue:{}:{}", to, uuid::Uuid::new_v4()); let _ = state.db.messages.insert(key.as_bytes(), message.as_slice()); tracing::info!("Federation: queued message from {} for offline {}", from_server, to); } else { tracing::debug!("Federation: delivered message from {} to {}", from_server, to); } } } _ => { tracing::debug!("Federation: unknown message type '{}'", msg_type); } } }