refactor: federation uses persistent WS instead of HTTP polling
- Server-to-server communication via WebSocket at /v1/federation/ws - Auth as first WS frame (shared secret), presence + forwards over same connection - Auto-reconnect every 3s on disconnect, instant presence push on connect - Replaces HTTP REST polling (no more 5s intervals, lower latency) - Removed dead HMAC helpers (auth is now direct secret comparison over WS) - Simplified ARCHITECTURE.md mermaid diagrams for Gitea rendering Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1,124 +1,143 @@
|
||||
//! Federation route handlers: receive presence updates and forwarded messages from peer server.
|
||||
//! Federation route handlers: WS endpoint for peer servers + status.
|
||||
|
||||
use axum::{
|
||||
body::Bytes,
|
||||
extract::State,
|
||||
http::{HeaderMap, StatusCode},
|
||||
extract::{State, WebSocketUpgrade, ws::{Message, WebSocket}},
|
||||
response::IntoResponse,
|
||||
routing::post,
|
||||
routing::get,
|
||||
Json, Router,
|
||||
};
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
|
||||
use crate::state::AppState;
|
||||
|
||||
pub fn routes() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/federation/presence", post(receive_presence))
|
||||
.route("/federation/forward", post(receive_forward))
|
||||
.route("/federation/status", axum::routing::get(federation_status))
|
||||
.route("/federation/ws", get(federation_ws_handler))
|
||||
.route("/federation/status", get(federation_status))
|
||||
}
|
||||
|
||||
/// Extract and validate the federation token from headers.
|
||||
fn validate_request(state: &AppState, headers: &HeaderMap, body: &[u8]) -> Result<(), (StatusCode, String)> {
|
||||
let federation = state.federation.as_ref()
|
||||
.ok_or((StatusCode::SERVICE_UNAVAILABLE, "federation not configured".to_string()))?;
|
||||
|
||||
let token = headers.get("x-federation-token")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.ok_or((StatusCode::UNAUTHORIZED, "missing X-Federation-Token header".to_string()))?;
|
||||
|
||||
if !crate::federation::verify_token(&federation.config.shared_secret, body, token) {
|
||||
return Err((StatusCode::UNAUTHORIZED, "invalid federation token".to_string()));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Receive presence announcement from peer.
|
||||
/// POST /v1/federation/presence
|
||||
/// Body: { "server_id": "...", "fingerprints": [...], "timestamp": ... }
|
||||
async fn receive_presence(
|
||||
/// WebSocket endpoint for incoming peer server connections.
|
||||
async fn federation_ws_handler(
|
||||
ws: WebSocketUpgrade,
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
body: Bytes,
|
||||
) -> impl IntoResponse {
|
||||
if let Err((status, msg)) = validate_request(&state, &headers, &body) {
|
||||
return (status, Json(serde_json::json!({ "error": msg }))).into_response();
|
||||
}
|
||||
ws.on_upgrade(move |socket| handle_peer_ws(socket, state))
|
||||
}
|
||||
|
||||
let parsed: serde_json::Value = match serde_json::from_slice(&body) {
|
||||
Ok(v) => v,
|
||||
Err(e) => return (StatusCode::BAD_REQUEST, Json(serde_json::json!({ "error": format!("invalid JSON: {}", e) }))).into_response(),
|
||||
/// Handle an incoming federation WS connection from the peer server.
|
||||
async fn handle_peer_ws(socket: WebSocket, state: AppState) {
|
||||
let (mut ws_tx, mut ws_rx) = socket.split();
|
||||
|
||||
// First message must be auth
|
||||
let secret = match state.federation {
|
||||
Some(ref f) => f.config.shared_secret.clone(),
|
||||
None => {
|
||||
tracing::warn!("Federation: WS connection rejected -- federation not configured");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let fingerprints: Vec<String> = parsed.get("fingerprints")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| arr.iter().filter_map(|v| v.as_str().map(String::from)).collect())
|
||||
.unwrap_or_default();
|
||||
// Wait for auth message (5 second timeout)
|
||||
let auth_msg = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(5),
|
||||
ws_rx.next(),
|
||||
).await;
|
||||
|
||||
let server_id = parsed.get("server_id").and_then(|v| v.as_str()).unwrap_or("unknown");
|
||||
let peer_id = match auth_msg {
|
||||
Ok(Some(Ok(Message::Text(text)))) => {
|
||||
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&text) {
|
||||
let msg_type = parsed.get("type").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let msg_secret = parsed.get("secret").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let server_id = parsed.get("server_id").and_then(|v| v.as_str()).unwrap_or("unknown");
|
||||
|
||||
if msg_type != "auth" || msg_secret != secret {
|
||||
tracing::warn!("Federation: WS auth failed from {}", server_id);
|
||||
return;
|
||||
}
|
||||
tracing::info!("Federation: peer {} authenticated via WS", server_id);
|
||||
server_id.to_string()
|
||||
} else {
|
||||
tracing::warn!("Federation: invalid auth JSON");
|
||||
return;
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
tracing::warn!("Federation: no auth message received within timeout");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Process incoming messages from the authenticated peer
|
||||
while let Some(Ok(msg)) = ws_rx.next().await {
|
||||
if let Message::Text(text) = msg {
|
||||
let parsed: serde_json::Value = match serde_json::from_str(&text) {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
let msg_type = parsed.get("type").and_then(|v| v.as_str()).unwrap_or("");
|
||||
|
||||
match msg_type {
|
||||
"presence" => {
|
||||
let fingerprints: Vec<String> = parsed.get("fingerprints")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| arr.iter().filter_map(|v| v.as_str().map(String::from)).collect())
|
||||
.unwrap_or_default();
|
||||
let count = fingerprints.len();
|
||||
|
||||
if let Some(ref federation) = state.federation {
|
||||
let mut rp = federation.remote_presence.lock().await;
|
||||
rp.fingerprints = fingerprints.into_iter().collect();
|
||||
rp.last_updated = chrono::Utc::now().timestamp();
|
||||
rp.connected = true;
|
||||
}
|
||||
tracing::debug!("Federation WS: {} announced {} fingerprints", peer_id, count);
|
||||
|
||||
// Send our presence back
|
||||
if let Some(ref federation) = state.federation {
|
||||
let fps: Vec<String> = {
|
||||
let conns = state.connections.lock().await;
|
||||
conns.keys().cloned().collect()
|
||||
};
|
||||
let reply = serde_json::json!({
|
||||
"type": "presence",
|
||||
"server_id": federation.config.server_id,
|
||||
"fingerprints": fps,
|
||||
});
|
||||
let _ = ws_tx.send(Message::Text(serde_json::to_string(&reply).unwrap_or_default())).await;
|
||||
}
|
||||
}
|
||||
"forward" => {
|
||||
let to = parsed.get("to").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let message_b64 = parsed.get("message").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let from_server = parsed.get("from_server").and_then(|v| v.as_str()).unwrap_or("?");
|
||||
|
||||
if let Ok(message) = base64::Engine::decode(&base64::engine::general_purpose::STANDARD, message_b64) {
|
||||
let delivered = state.push_to_client(to, &message).await;
|
||||
if !delivered {
|
||||
let key = format!("queue:{}:{}", to, uuid::Uuid::new_v4());
|
||||
let _ = state.db.messages.insert(key.as_bytes(), message.as_slice());
|
||||
tracing::info!("Federation WS: queued from {} for offline {}", from_server, to);
|
||||
} else {
|
||||
tracing::debug!("Federation WS: delivered from {} to {}", from_server, to);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Peer disconnected
|
||||
if let Some(ref federation) = state.federation {
|
||||
let mut rp = federation.remote_presence.lock().await;
|
||||
let count = fingerprints.len();
|
||||
rp.fingerprints = fingerprints.into_iter().collect();
|
||||
rp.last_updated = chrono::Utc::now().timestamp();
|
||||
tracing::debug!("Federation: received {} fingerprints from {}", count, server_id);
|
||||
rp.connected = false;
|
||||
rp.fingerprints.clear();
|
||||
}
|
||||
|
||||
(StatusCode::OK, Json(serde_json::json!({ "ok": true }))).into_response()
|
||||
}
|
||||
|
||||
/// Receive a forwarded message from peer.
|
||||
/// POST /v1/federation/forward
|
||||
/// Body: { "to": "fingerprint", "message": "base64...", "from_server": "..." }
|
||||
async fn receive_forward(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
body: Bytes,
|
||||
) -> impl IntoResponse {
|
||||
if let Err((status, msg)) = validate_request(&state, &headers, &body) {
|
||||
return (status, Json(serde_json::json!({ "error": msg }))).into_response();
|
||||
}
|
||||
|
||||
let parsed: serde_json::Value = match serde_json::from_slice(&body) {
|
||||
Ok(v) => v,
|
||||
Err(e) => return (StatusCode::BAD_REQUEST, Json(serde_json::json!({ "error": format!("invalid JSON: {}", e) }))).into_response(),
|
||||
};
|
||||
|
||||
let to = match parsed.get("to").and_then(|v| v.as_str()) {
|
||||
Some(fp) => fp.to_string(),
|
||||
None => return (StatusCode::BAD_REQUEST, Json(serde_json::json!({ "error": "missing 'to' field" }))).into_response(),
|
||||
};
|
||||
|
||||
let message_b64 = match parsed.get("message").and_then(|v| v.as_str()) {
|
||||
Some(m) => m.to_string(),
|
||||
None => return (StatusCode::BAD_REQUEST, Json(serde_json::json!({ "error": "missing 'message' field" }))).into_response(),
|
||||
};
|
||||
|
||||
let message = match base64::Engine::decode(&base64::engine::general_purpose::STANDARD, &message_b64) {
|
||||
Ok(m) => m,
|
||||
Err(e) => return (StatusCode::BAD_REQUEST, Json(serde_json::json!({ "error": format!("invalid base64: {}", e) }))).into_response(),
|
||||
};
|
||||
|
||||
let from_server = parsed.get("from_server").and_then(|v| v.as_str()).unwrap_or("unknown");
|
||||
|
||||
// Try to deliver locally
|
||||
let delivered = state.push_to_client(&to, &message).await;
|
||||
if !delivered {
|
||||
// Queue for later pickup
|
||||
let key = format!("queue:{}:{}", to, uuid::Uuid::new_v4());
|
||||
let _ = state.db.messages.insert(key.as_bytes(), message.as_slice());
|
||||
tracing::info!("Federation: queued forwarded message from {} for offline user {}", from_server, to);
|
||||
} else {
|
||||
tracing::info!("Federation: delivered forwarded message from {} to {}", from_server, to);
|
||||
}
|
||||
|
||||
(StatusCode::OK, Json(serde_json::json!({ "ok": true, "delivered": delivered }))).into_response()
|
||||
tracing::info!("Federation WS: peer {} disconnected", peer_id);
|
||||
}
|
||||
|
||||
/// Federation health status.
|
||||
/// GET /v1/federation/status
|
||||
async fn federation_status(
|
||||
State(state): State<AppState>,
|
||||
) -> Json<serde_json::Value> {
|
||||
@@ -130,15 +149,13 @@ async fn federation_status(
|
||||
"server_id": federation.config.server_id,
|
||||
"peer_id": federation.config.peer.id,
|
||||
"peer_url": federation.config.peer.url,
|
||||
"peer_alive": rp.is_alive(federation.config.presence_interval_secs),
|
||||
"peer_connected": rp.connected,
|
||||
"remote_clients": rp.fingerprints.len(),
|
||||
"last_sync": rp.last_updated,
|
||||
}))
|
||||
}
|
||||
None => {
|
||||
Json(serde_json::json!({
|
||||
"enabled": false,
|
||||
}))
|
||||
Json(serde_json::json!({ "enabled": false }))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user