refactor: federation uses persistent WS instead of HTTP polling

- Server-to-server communication via WebSocket at /v1/federation/ws
- Auth as first WS frame (shared secret), presence + forwards over same connection
- Auto-reconnect every 3s on disconnect, instant presence push on connect
- Replaces HTTP REST polling (no more 5s intervals, lower latency)
- Removed dead HMAC helpers (auth is now direct secret comparison over WS)
- Simplified ARCHITECTURE.md mermaid diagrams for Gitea rendering

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Siavash Sameni
2026-03-28 16:56:13 +04:00
parent 3e0889e5dc
commit f8eaf30bb4
7 changed files with 364 additions and 306 deletions

View File

@@ -27,3 +27,4 @@ ed25519-dalek.workspace = true
bincode.workspace = true
sha2.workspace = true
reqwest = { workspace = true, features = ["rustls-tls", "json"] }
tokio-tungstenite.workspace = true

View File

@@ -1,12 +1,12 @@
//! Federation: two-server message relay with shared-secret authentication.
//! Federation: two-server message relay via persistent WebSocket.
//!
//! Each server periodically announces its connected clients to the peer.
//! When a message is destined for a remote client, it's forwarded via HTTP.
//! Each server maintains a WS connection to its peer. Presence updates
//! and message forwards flow over this single connection. Reconnects
//! automatically on failure.
use std::collections::HashSet;
use std::sync::Arc;
use tokio::sync::Mutex;
use sha2::{Sha256, Digest};
/// Federation configuration loaded from JSON.
#[derive(Clone, Debug, serde::Deserialize)]
@@ -14,8 +14,6 @@ pub struct FederationConfig {
pub server_id: String,
pub shared_secret: String,
pub peer: PeerConfig,
#[serde(default = "default_interval")]
pub presence_interval_secs: u64,
}
#[derive(Clone, Debug, serde::Deserialize)]
@@ -24,9 +22,7 @@ pub struct PeerConfig {
pub url: String,
}
fn default_interval() -> u64 { 5 }
/// Load federation config from a JSON file. Returns None if path is empty.
/// Load federation config from a JSON file.
pub fn load_config(path: &str) -> anyhow::Result<FederationConfig> {
let data = std::fs::read_to_string(path)
.map_err(|e| anyhow::anyhow!("failed to read federation config '{}': {}", path, e))?;
@@ -38,175 +34,227 @@ pub fn load_config(path: &str) -> anyhow::Result<FederationConfig> {
/// Remote presence: which fingerprints are on the peer server.
#[derive(Clone, Debug)]
pub struct RemotePresence {
pub peer_url: String,
pub peer_id: String,
pub fingerprints: HashSet<String>,
pub last_updated: i64,
pub connected: bool,
}
impl RemotePresence {
pub fn new(peer_url: String, peer_id: String) -> Self {
pub fn new(peer_id: String) -> Self {
RemotePresence {
peer_url,
peer_id,
fingerprints: HashSet::new(),
last_updated: 0,
connected: false,
}
}
/// Check if a fingerprint is on the remote server.
pub fn contains(&self, fp: &str) -> bool {
self.fingerprints.contains(fp)
}
/// Is the peer still alive? (heard from within 3 intervals)
pub fn is_alive(&self, interval_secs: u64) -> bool {
let now = chrono::Utc::now().timestamp();
now - self.last_updated < (interval_secs as i64 * 3)
self.connected && self.fingerprints.contains(fp)
}
}
/// Sender for outgoing federation messages over the WS.
///
/// A shared, async-mutex-guarded slot holding the sending half of an
/// unbounded channel of `String` messages. The slot is `None` when no
/// channel is installed, in which case `send_json` reports failure.
pub type FederationSender = Arc<Mutex<Option<tokio::sync::mpsc::UnboundedSender<String>>>>;
/// Handle for communicating with the federation peer.
#[derive(Clone)]
pub struct FederationHandle {
pub config: FederationConfig,
pub client: reqwest::Client,
pub remote_presence: Arc<Mutex<RemotePresence>>,
/// Channel to send messages over the outgoing WS to the peer.
pub outgoing: FederationSender,
}
impl FederationHandle {
pub fn new(config: FederationConfig) -> Self {
let remote_presence = Arc::new(Mutex::new(RemotePresence::new(
config.peer.url.clone(),
config.peer.id.clone(),
)));
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(5))
.build()
.expect("failed to build HTTP client");
FederationHandle { config, client, remote_presence }
FederationHandle {
config,
remote_presence,
outgoing: Arc::new(Mutex::new(None)),
}
}
/// Check if a fingerprint is known to be on the peer server.
pub async fn is_remote(&self, fp: &str) -> bool {
let rp = self.remote_presence.lock().await;
rp.is_alive(self.config.presence_interval_secs) && rp.contains(fp)
rp.contains(fp)
}
/// Forward a message to the peer server for delivery.
/// Returns true if the peer accepted it.
/// Forward a message to the peer server via the persistent WS.
pub async fn forward_message(&self, to_fp: &str, message: &[u8]) -> bool {
let url = format!("{}/v1/federation/forward", self.config.peer.url);
let body = serde_json::json!({
let msg = serde_json::json!({
"type": "forward",
"to": to_fp,
"message": base64::Engine::encode(&base64::engine::general_purpose::STANDARD, message),
"from_server": self.config.server_id,
});
let body_str = serde_json::to_string(&body).unwrap_or_default();
let token = compute_token(&self.config.shared_secret, body_str.as_bytes());
match self.client.post(&url)
.header("X-Federation-Token", &token)
.header("Content-Type", "application/json")
.body(body_str)
.send()
.await
{
Ok(resp) if resp.status().is_success() => {
tracing::debug!("Federation: forwarded message to {} for {}", self.config.peer.id, to_fp);
true
}
Ok(resp) => {
tracing::warn!("Federation: peer {} rejected forward: {}", self.config.peer.id, resp.status());
false
}
Err(e) => {
tracing::warn!("Federation: failed to forward to {}: {}", self.config.peer.id, e);
false
}
}
self.send_json(msg).await
}
/// Send our local presence to the peer.
pub async fn announce_presence(&self, fingerprints: Vec<String>) -> bool {
let url = format!("{}/v1/federation/presence", self.config.peer.url);
let body = serde_json::json!({
/// Push local presence to peer via the persistent WS.
pub async fn push_presence(&self, fingerprints: Vec<String>) -> bool {
let msg = serde_json::json!({
"type": "presence",
"server_id": self.config.server_id,
"fingerprints": fingerprints,
"timestamp": chrono::Utc::now().timestamp(),
});
let body_str = serde_json::to_string(&body).unwrap_or_default();
let token = compute_token(&self.config.shared_secret, body_str.as_bytes());
self.send_json(msg).await
}
match self.client.post(&url)
.header("X-Federation-Token", &token)
.header("Content-Type", "application/json")
.body(body_str)
.send()
.await
{
Ok(resp) if resp.status().is_success() => true,
Ok(resp) => {
tracing::warn!("Federation: presence announce to {} failed: {}", self.config.peer.id, resp.status());
false
}
Err(e) => {
tracing::warn!("Federation: presence announce to {} error: {}", self.config.peer.id, e);
false
}
/// Send a JSON message over the outgoing WS channel.
async fn send_json(&self, msg: serde_json::Value) -> bool {
let guard = self.outgoing.lock().await;
if let Some(ref tx) = *guard {
let json_str = serde_json::to_string(&msg).unwrap_or_default();
tx.send(json_str).is_ok()
} else {
false
}
}
}
/// Background task: periodically sync presence with peer.
pub async fn presence_sync_loop(
/// Background task: connect to peer's WS endpoint, send auth, then loop.
/// Handles reconnection on failure.
pub async fn outgoing_ws_loop(
handle: FederationHandle,
connections: crate::state::Connections,
state: crate::state::AppState,
) {
let interval = std::time::Duration::from_secs(handle.config.presence_interval_secs);
tracing::info!(
"Federation: presence sync started (peer={}, interval={}s)",
handle.config.peer.id, handle.config.presence_interval_secs
);
let ws_url = handle.config.peer.url
.replace("http://", "ws://")
.replace("https://", "wss://");
let ws_url = format!("{}/v1/federation/ws", ws_url);
loop {
// Collect local fingerprints
let fps: Vec<String> = {
let conns = connections.lock().await;
conns.keys().cloned().collect()
};
tracing::info!("Federation: connecting to peer {} at {}", handle.config.peer.id, ws_url);
// Announce to peer
let ok = handle.announce_presence(fps.clone()).await;
if ok {
tracing::debug!("Federation: announced {} fingerprints to {}", fps.len(), handle.config.peer.id);
}
match tokio_tungstenite::connect_async(&ws_url).await {
Ok((ws_stream, _)) => {
tracing::info!("Federation: connected to peer {}", handle.config.peer.id);
// Clear stale remote presence if peer hasn't responded
{
let mut rp = handle.remote_presence.lock().await;
if !rp.is_alive(handle.config.presence_interval_secs) && !rp.fingerprints.is_empty() {
tracing::warn!("Federation: peer {} stale — clearing remote presence ({} fps)",
handle.config.peer.id, rp.fingerprints.len());
rp.fingerprints.clear();
use futures_util::{SinkExt, StreamExt};
let (mut ws_tx, mut ws_rx) = ws_stream.split();
// Send auth as first message
let auth_msg = serde_json::json!({
"type": "auth",
"secret": handle.config.shared_secret,
"server_id": handle.config.server_id,
});
if ws_tx.send(tokio_tungstenite::tungstenite::Message::Text(
serde_json::to_string(&auth_msg).unwrap_or_default()
)).await.is_err() {
tracing::warn!("Federation: failed to send auth to peer");
tokio::time::sleep(std::time::Duration::from_secs(3)).await;
continue;
}
// Set up outgoing channel
let (out_tx, mut out_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
{
let mut guard = handle.outgoing.lock().await;
*guard = Some(out_tx);
}
{
let mut rp = handle.remote_presence.lock().await;
rp.connected = true;
}
// Send initial presence
let fps: Vec<String> = {
let conns = state.connections.lock().await;
conns.keys().cloned().collect()
};
let _ = handle.push_presence(fps).await;
// Spawn task to forward outgoing channel to WS
let send_task = tokio::spawn(async move {
while let Some(msg) = out_rx.recv().await {
if ws_tx.send(tokio_tungstenite::tungstenite::Message::Text(msg)).await.is_err() {
break;
}
}
});
// Read incoming messages from peer
while let Some(Ok(msg)) = ws_rx.next().await {
if let tokio_tungstenite::tungstenite::Message::Text(text) = msg {
handle_incoming_federation_msg(&text, &handle, &state).await;
}
}
// Connection lost
send_task.abort();
{
let mut guard = handle.outgoing.lock().await;
*guard = None;
}
{
let mut rp = handle.remote_presence.lock().await;
rp.connected = false;
rp.fingerprints.clear();
}
tracing::warn!("Federation: lost connection to peer {}, reconnecting...", handle.config.peer.id);
}
Err(e) => {
tracing::warn!("Federation: failed to connect to peer {}: {}", handle.config.peer.id, e);
}
}
tokio::time::sleep(interval).await;
tokio::time::sleep(std::time::Duration::from_secs(3)).await;
}
}
/// Compute an auth token: SHA-256(secret || body). Simple HMAC-like construction.
pub fn compute_token(secret: &str, body: &[u8]) -> String {
let mut hasher = Sha256::new();
hasher.update(secret.as_bytes());
hasher.update(body);
hex::encode(hasher.finalize())
/// Process a single incoming JSON message from the federated peer WS.
async fn handle_incoming_federation_msg(
text: &str,
handle: &FederationHandle,
state: &crate::state::AppState,
) {
let parsed: serde_json::Value = match serde_json::from_str(text) {
Ok(v) => v,
Err(_) => return,
};
let msg_type = parsed.get("type").and_then(|v| v.as_str()).unwrap_or("");
match msg_type {
"presence" => {
let fingerprints: Vec<String> = parsed.get("fingerprints")
.and_then(|v| v.as_array())
.map(|arr| arr.iter().filter_map(|v| v.as_str().map(String::from)).collect())
.unwrap_or_default();
let server_id = parsed.get("server_id").and_then(|v| v.as_str()).unwrap_or("?");
let count = fingerprints.len();
let mut rp = handle.remote_presence.lock().await;
rp.fingerprints = fingerprints.into_iter().collect();
rp.last_updated = chrono::Utc::now().timestamp();
tracing::debug!("Federation: received {} fingerprints from {}", count, server_id);
}
"forward" => {
let to = parsed.get("to").and_then(|v| v.as_str()).unwrap_or("");
let message_b64 = parsed.get("message").and_then(|v| v.as_str()).unwrap_or("");
let from_server = parsed.get("from_server").and_then(|v| v.as_str()).unwrap_or("?");
if let Ok(message) = base64::Engine::decode(&base64::engine::general_purpose::STANDARD, message_b64) {
let delivered = state.push_to_client(to, &message).await;
if !delivered {
let key = format!("queue:{}:{}", to, uuid::Uuid::new_v4());
let _ = state.db.messages.insert(key.as_bytes(), message.as_slice());
tracing::info!("Federation: queued message from {} for offline {}", from_server, to);
} else {
tracing::debug!("Federation: delivered message from {} to {}", from_server, to);
}
}
}
_ => {
tracing::debug!("Federation: unknown message type '{}'", msg_type);
}
}
}
/// Verify an auth token.
///
/// Recomputes the expected token for `body` with `compute_token` and
/// compares it to `token`. Returns `true` only when both have the same
/// length and every byte matches.
pub fn verify_token(secret: &str, body: &[u8], token: &str) -> bool {
    let expected = compute_token(secret, body);
    if expected.len() != token.len() {
        return false;
    }
    // Accumulate the XOR of every byte pair instead of using `Iterator::all`,
    // which stops at the first mismatch and so leaks how long the matching
    // prefix is through timing.
    let diff = expected
        .bytes()
        .zip(token.bytes())
        .fold(0u8, |acc, (a, b)| acc | (a ^ b));
    diff == 0
}

View File

@@ -49,12 +49,12 @@ async fn main() -> anyhow::Result<()> {
state.federation = Some(handle);
}
// Spawn federation presence sync if enabled
if let Some(ref federation) = state.federation {
let handle = federation.clone();
let connections = state.connections.clone();
// Spawn federation outgoing WS connection if enabled
if let Some(ref fed) = state.federation {
let handle = fed.clone();
let fed_state = state.clone();
tokio::spawn(async move {
federation::presence_sync_loop(handle, connections).await;
federation::outgoing_ws_loop(handle, fed_state).await;
});
}

View File

@@ -1,124 +1,143 @@
//! Federation route handlers: receive presence updates and forwarded messages from peer server.
//! Federation route handlers: WS endpoint for peer servers + status.
use axum::{
body::Bytes,
extract::State,
http::{HeaderMap, StatusCode},
extract::{State, WebSocketUpgrade, ws::{Message, WebSocket}},
response::IntoResponse,
routing::post,
routing::get,
Json, Router,
};
use futures_util::{SinkExt, StreamExt};
use crate::state::AppState;
pub fn routes() -> Router<AppState> {
Router::new()
.route("/federation/presence", post(receive_presence))
.route("/federation/forward", post(receive_forward))
.route("/federation/status", axum::routing::get(federation_status))
.route("/federation/ws", get(federation_ws_handler))
.route("/federation/status", get(federation_status))
}
/// Extract and validate the federation token from headers.
fn validate_request(state: &AppState, headers: &HeaderMap, body: &[u8]) -> Result<(), (StatusCode, String)> {
let federation = state.federation.as_ref()
.ok_or((StatusCode::SERVICE_UNAVAILABLE, "federation not configured".to_string()))?;
let token = headers.get("x-federation-token")
.and_then(|v| v.to_str().ok())
.ok_or((StatusCode::UNAUTHORIZED, "missing X-Federation-Token header".to_string()))?;
if !crate::federation::verify_token(&federation.config.shared_secret, body, token) {
return Err((StatusCode::UNAUTHORIZED, "invalid federation token".to_string()));
}
Ok(())
}
/// Receive presence announcement from peer.
/// POST /v1/federation/presence
/// Body: { "server_id": "...", "fingerprints": [...], "timestamp": ... }
async fn receive_presence(
/// WebSocket endpoint for incoming peer server connections.
async fn federation_ws_handler(
ws: WebSocketUpgrade,
State(state): State<AppState>,
headers: HeaderMap,
body: Bytes,
) -> impl IntoResponse {
if let Err((status, msg)) = validate_request(&state, &headers, &body) {
return (status, Json(serde_json::json!({ "error": msg }))).into_response();
}
ws.on_upgrade(move |socket| handle_peer_ws(socket, state))
}
let parsed: serde_json::Value = match serde_json::from_slice(&body) {
Ok(v) => v,
Err(e) => return (StatusCode::BAD_REQUEST, Json(serde_json::json!({ "error": format!("invalid JSON: {}", e) }))).into_response(),
/// Handle an incoming federation WS connection from the peer server.
async fn handle_peer_ws(socket: WebSocket, state: AppState) {
let (mut ws_tx, mut ws_rx) = socket.split();
// First message must be auth
let secret = match state.federation {
Some(ref f) => f.config.shared_secret.clone(),
None => {
tracing::warn!("Federation: WS connection rejected -- federation not configured");
return;
}
};
let fingerprints: Vec<String> = parsed.get("fingerprints")
.and_then(|v| v.as_array())
.map(|arr| arr.iter().filter_map(|v| v.as_str().map(String::from)).collect())
.unwrap_or_default();
// Wait for auth message (5 second timeout)
let auth_msg = tokio::time::timeout(
std::time::Duration::from_secs(5),
ws_rx.next(),
).await;
let server_id = parsed.get("server_id").and_then(|v| v.as_str()).unwrap_or("unknown");
let peer_id = match auth_msg {
Ok(Some(Ok(Message::Text(text)))) => {
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&text) {
let msg_type = parsed.get("type").and_then(|v| v.as_str()).unwrap_or("");
let msg_secret = parsed.get("secret").and_then(|v| v.as_str()).unwrap_or("");
let server_id = parsed.get("server_id").and_then(|v| v.as_str()).unwrap_or("unknown");
if msg_type != "auth" || msg_secret != secret {
tracing::warn!("Federation: WS auth failed from {}", server_id);
return;
}
tracing::info!("Federation: peer {} authenticated via WS", server_id);
server_id.to_string()
} else {
tracing::warn!("Federation: invalid auth JSON");
return;
}
}
_ => {
tracing::warn!("Federation: no auth message received within timeout");
return;
}
};
// Process incoming messages from the authenticated peer
while let Some(Ok(msg)) = ws_rx.next().await {
if let Message::Text(text) = msg {
let parsed: serde_json::Value = match serde_json::from_str(&text) {
Ok(v) => v,
Err(_) => continue,
};
let msg_type = parsed.get("type").and_then(|v| v.as_str()).unwrap_or("");
match msg_type {
"presence" => {
let fingerprints: Vec<String> = parsed.get("fingerprints")
.and_then(|v| v.as_array())
.map(|arr| arr.iter().filter_map(|v| v.as_str().map(String::from)).collect())
.unwrap_or_default();
let count = fingerprints.len();
if let Some(ref federation) = state.federation {
let mut rp = federation.remote_presence.lock().await;
rp.fingerprints = fingerprints.into_iter().collect();
rp.last_updated = chrono::Utc::now().timestamp();
rp.connected = true;
}
tracing::debug!("Federation WS: {} announced {} fingerprints", peer_id, count);
// Send our presence back
if let Some(ref federation) = state.federation {
let fps: Vec<String> = {
let conns = state.connections.lock().await;
conns.keys().cloned().collect()
};
let reply = serde_json::json!({
"type": "presence",
"server_id": federation.config.server_id,
"fingerprints": fps,
});
let _ = ws_tx.send(Message::Text(serde_json::to_string(&reply).unwrap_or_default())).await;
}
}
"forward" => {
let to = parsed.get("to").and_then(|v| v.as_str()).unwrap_or("");
let message_b64 = parsed.get("message").and_then(|v| v.as_str()).unwrap_or("");
let from_server = parsed.get("from_server").and_then(|v| v.as_str()).unwrap_or("?");
if let Ok(message) = base64::Engine::decode(&base64::engine::general_purpose::STANDARD, message_b64) {
let delivered = state.push_to_client(to, &message).await;
if !delivered {
let key = format!("queue:{}:{}", to, uuid::Uuid::new_v4());
let _ = state.db.messages.insert(key.as_bytes(), message.as_slice());
tracing::info!("Federation WS: queued from {} for offline {}", from_server, to);
} else {
tracing::debug!("Federation WS: delivered from {} to {}", from_server, to);
}
}
}
_ => {}
}
}
}
// Peer disconnected
if let Some(ref federation) = state.federation {
let mut rp = federation.remote_presence.lock().await;
let count = fingerprints.len();
rp.fingerprints = fingerprints.into_iter().collect();
rp.last_updated = chrono::Utc::now().timestamp();
tracing::debug!("Federation: received {} fingerprints from {}", count, server_id);
rp.connected = false;
rp.fingerprints.clear();
}
(StatusCode::OK, Json(serde_json::json!({ "ok": true }))).into_response()
}
/// Receive a forwarded message from peer.
/// POST /v1/federation/forward
/// Body: { "to": "fingerprint", "message": "base64...", "from_server": "..." }
async fn receive_forward(
State(state): State<AppState>,
headers: HeaderMap,
body: Bytes,
) -> impl IntoResponse {
if let Err((status, msg)) = validate_request(&state, &headers, &body) {
return (status, Json(serde_json::json!({ "error": msg }))).into_response();
}
let parsed: serde_json::Value = match serde_json::from_slice(&body) {
Ok(v) => v,
Err(e) => return (StatusCode::BAD_REQUEST, Json(serde_json::json!({ "error": format!("invalid JSON: {}", e) }))).into_response(),
};
let to = match parsed.get("to").and_then(|v| v.as_str()) {
Some(fp) => fp.to_string(),
None => return (StatusCode::BAD_REQUEST, Json(serde_json::json!({ "error": "missing 'to' field" }))).into_response(),
};
let message_b64 = match parsed.get("message").and_then(|v| v.as_str()) {
Some(m) => m.to_string(),
None => return (StatusCode::BAD_REQUEST, Json(serde_json::json!({ "error": "missing 'message' field" }))).into_response(),
};
let message = match base64::Engine::decode(&base64::engine::general_purpose::STANDARD, &message_b64) {
Ok(m) => m,
Err(e) => return (StatusCode::BAD_REQUEST, Json(serde_json::json!({ "error": format!("invalid base64: {}", e) }))).into_response(),
};
let from_server = parsed.get("from_server").and_then(|v| v.as_str()).unwrap_or("unknown");
// Try to deliver locally
let delivered = state.push_to_client(&to, &message).await;
if !delivered {
// Queue for later pickup
let key = format!("queue:{}:{}", to, uuid::Uuid::new_v4());
let _ = state.db.messages.insert(key.as_bytes(), message.as_slice());
tracing::info!("Federation: queued forwarded message from {} for offline user {}", from_server, to);
} else {
tracing::info!("Federation: delivered forwarded message from {} to {}", from_server, to);
}
(StatusCode::OK, Json(serde_json::json!({ "ok": true, "delivered": delivered }))).into_response()
tracing::info!("Federation WS: peer {} disconnected", peer_id);
}
/// Federation health status.
/// GET /v1/federation/status
async fn federation_status(
State(state): State<AppState>,
) -> Json<serde_json::Value> {
@@ -130,15 +149,13 @@ async fn federation_status(
"server_id": federation.config.server_id,
"peer_id": federation.config.peer.id,
"peer_url": federation.config.peer.url,
"peer_alive": rp.is_alive(federation.config.presence_interval_secs),
"peer_connected": rp.connected,
"remote_clients": rp.fingerprints.len(),
"last_sync": rp.last_updated,
}))
}
None => {
Json(serde_json::json!({
"enabled": false,
}))
Json(serde_json::json!({ "enabled": false }))
}
}
}