Call state reload on restart:
- Loads Ringing/Active calls from sled into active_calls on startup
- Automatically expires calls older than 24h

TUI sender ETH cache prefill:
- prefill_eth_cache() resolves all known contacts when poll_loop starts
- The first message from a known contact now shows its ETH address immediately

Server integration tests (10 new):
- push_to_client offline/online
- register_ws + connection cap (5 max)
- is_online + device_count
- kick_device + revoke_all_except
- deliver_or_queue offline/online
- call state lifecycle
- list_devices

155 tests passing (was 135)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
415 lines
14 KiB
Rust
415 lines
14 KiB
Rust
use std::collections::{HashMap, HashSet, VecDeque};
|
|
use std::sync::Arc;
|
|
use tokio::sync::{Mutex, mpsc};
|
|
|
|
use crate::db::Database;
|
|
|
|
/// Upper bound on simultaneously open WebSocket connections (devices) that a
/// single fingerprint may hold; further registrations are refused.
const MAX_WS_PER_FINGERPRINT: usize = 5;

/// Number of most-recent message IDs remembered by the dedup tracker before
/// the oldest entries start being evicted.
const DEDUP_CAPACITY: usize = 10_000;
|
|
|
|
/// Per-connection sender: messages are pushed here for instant delivery.
|
|
pub type WsSender = mpsc::UnboundedSender<Vec<u8>>;
|
|
|
|
/// Metadata for a single connected device.
|
|
#[derive(Clone)]
|
|
pub struct DeviceConnection {
|
|
pub device_id: String,
|
|
pub sender: WsSender,
|
|
pub connected_at: i64,
|
|
pub token: Option<String>,
|
|
}
|
|
|
|
/// Connected clients: fingerprint → list of device connections (multiple devices).
|
|
pub type Connections = Arc<Mutex<HashMap<String, Vec<DeviceConnection>>>>;
|
|
|
|
/// Bounded dedup tracker: FIFO eviction when capacity is exceeded.
|
|
#[derive(Clone)]
|
|
pub struct DedupTracker {
|
|
seen: Arc<std::sync::Mutex<HashSet<String>>>,
|
|
order: Arc<std::sync::Mutex<VecDeque<String>>>,
|
|
}
|
|
|
|
impl DedupTracker {
|
|
pub fn new() -> Self {
|
|
DedupTracker {
|
|
seen: Arc::new(std::sync::Mutex::new(HashSet::with_capacity(DEDUP_CAPACITY))),
|
|
order: Arc::new(std::sync::Mutex::new(VecDeque::with_capacity(DEDUP_CAPACITY))),
|
|
}
|
|
}
|
|
|
|
/// Returns `true` if this ID was already seen (i.e. it is a duplicate).
|
|
/// If new, inserts it and evicts the oldest if over capacity.
|
|
pub fn check_and_insert(&self, id: &str) -> bool {
|
|
let mut seen = self.seen.lock().unwrap();
|
|
if seen.contains(id) {
|
|
return true; // duplicate
|
|
}
|
|
let mut order = self.order.lock().unwrap();
|
|
if seen.len() >= DEDUP_CAPACITY {
|
|
if let Some(oldest) = order.pop_front() {
|
|
seen.remove(&oldest);
|
|
}
|
|
}
|
|
seen.insert(id.to_string());
|
|
order.push_back(id.to_string());
|
|
false // not a duplicate
|
|
}
|
|
}
|
|
|
|
/// Call lifecycle status.
|
|
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
|
|
pub enum CallStatus {
|
|
Ringing,
|
|
Active,
|
|
Ended,
|
|
}
|
|
|
|
/// Server-side state for an active or recently ended call.
|
|
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
|
|
pub struct CallState {
|
|
pub call_id: String,
|
|
pub caller_fp: String,
|
|
pub callee_fp: String,
|
|
pub group_name: Option<String>,
|
|
pub room_id: Option<String>,
|
|
pub status: CallStatus,
|
|
pub created_at: i64,
|
|
pub answered_at: Option<i64>,
|
|
pub ended_at: Option<i64>,
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct AppState {
|
|
pub db: Arc<Database>,
|
|
pub connections: Connections,
|
|
pub dedup: DedupTracker,
|
|
pub active_calls: Arc<Mutex<HashMap<String, CallState>>>,
|
|
pub federation: Option<crate::federation::FederationHandle>,
|
|
pub bots_enabled: bool,
|
|
}
|
|
|
|
impl AppState {
|
|
pub fn new(data_dir: &str) -> anyhow::Result<Self> {
|
|
let db = Database::open(data_dir)?;
|
|
Ok(AppState {
|
|
db: Arc::new(db),
|
|
connections: Arc::new(Mutex::new(HashMap::new())),
|
|
dedup: DedupTracker::new(),
|
|
active_calls: Arc::new(Mutex::new(HashMap::new())),
|
|
federation: None,
|
|
bots_enabled: false,
|
|
})
|
|
}
|
|
|
|
/// Try to push a message to a connected client. Returns true if delivered.
|
|
pub async fn push_to_client(&self, fingerprint: &str, message: &[u8]) -> bool {
|
|
let conns = self.connections.lock().await;
|
|
if let Some(devices) = conns.get(fingerprint) {
|
|
let mut delivered = false;
|
|
for device in devices {
|
|
if device.sender.send(message.to_vec()).is_ok() {
|
|
delivered = true;
|
|
}
|
|
}
|
|
delivered
|
|
} else {
|
|
false
|
|
}
|
|
}
|
|
|
|
/// Register a WS connection for a fingerprint.
|
|
///
|
|
/// Returns `None` if the per-fingerprint connection cap has been reached.
|
|
/// On success, returns the assigned device ID and a receiver for push messages.
|
|
pub async fn register_ws(&self, fingerprint: &str, token: Option<String>) -> Option<(String, mpsc::UnboundedReceiver<Vec<u8>>)> {
|
|
let (tx, rx) = mpsc::unbounded_channel();
|
|
let device_id = uuid::Uuid::new_v4().to_string()[..8].to_string();
|
|
let mut conns = self.connections.lock().await;
|
|
let entry = conns.entry(fingerprint.to_string()).or_default();
|
|
|
|
// Clean up closed connections first
|
|
entry.retain(|d| !d.sender.is_closed());
|
|
|
|
if entry.len() >= MAX_WS_PER_FINGERPRINT {
|
|
tracing::warn!(
|
|
"WS connection cap reached for {} ({} connections)",
|
|
fingerprint,
|
|
entry.len()
|
|
);
|
|
return None;
|
|
}
|
|
|
|
entry.push(DeviceConnection {
|
|
device_id: device_id.clone(),
|
|
sender: tx,
|
|
connected_at: chrono::Utc::now().timestamp(),
|
|
token,
|
|
});
|
|
tracing::info!(
|
|
"WS registered for {} device={} ({} total)",
|
|
fingerprint,
|
|
device_id,
|
|
conns.values().map(|v| v.len()).sum::<usize>()
|
|
);
|
|
Some((device_id, rx))
|
|
}
|
|
|
|
/// Unregister a WS connection.
|
|
#[allow(dead_code)]
|
|
pub async fn unregister_ws(&self, fingerprint: &str, sender: &WsSender) {
|
|
let mut conns = self.connections.lock().await;
|
|
if let Some(devices) = conns.get_mut(fingerprint) {
|
|
devices.retain(|d| !d.sender.same_channel(sender));
|
|
if devices.is_empty() {
|
|
conns.remove(fingerprint);
|
|
}
|
|
}
|
|
tracing::info!("WS unregistered for {}", fingerprint);
|
|
}
|
|
|
|
/// Try to deliver a message: local push → federation forward → DB queue.
|
|
/// Returns true if delivered instantly (local or remote).
|
|
pub async fn deliver_or_queue(&self, to_fp: &str, message: &[u8]) -> bool {
|
|
// BotFather: intercept messages to @botfather
|
|
if self.bots_enabled && to_fp == "00000000000000000b0ffa00e000000f" {
|
|
// Extract sender from message
|
|
if let Ok(msg) = serde_json::from_slice::<serde_json::Value>(message) {
|
|
let from = msg.get("from").and_then(|v| v.as_str()).unwrap_or("");
|
|
if !from.is_empty() {
|
|
if crate::botfather::handle_botfather_message(self, from, message).await {
|
|
return true;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// 1. Try local WebSocket push
|
|
if self.push_to_client(to_fp, message).await {
|
|
return true;
|
|
}
|
|
|
|
// 2. Try federation forward
|
|
if let Some(ref federation) = self.federation {
|
|
if federation.is_remote(to_fp).await {
|
|
if federation.forward_message(to_fp, message).await {
|
|
return true;
|
|
}
|
|
}
|
|
}
|
|
|
|
// 3. Queue in local DB
|
|
let key = format!("queue:{}:{}", to_fp, uuid::Uuid::new_v4());
|
|
let _ = self.db.messages.insert(key.as_bytes(), message);
|
|
|
|
// 4. Try bot webhook delivery (async, does not block the caller)
|
|
{
|
|
let state = self.clone();
|
|
let fp = to_fp.to_string();
|
|
let queue_key = key.clone();
|
|
let msg = message.to_vec();
|
|
tokio::spawn(async move {
|
|
if crate::routes::bot::try_bot_webhook(&state, &fp, &msg).await {
|
|
// Webhook accepted -- remove from offline queue
|
|
let _ = state.db.messages.remove(queue_key.as_bytes());
|
|
}
|
|
});
|
|
}
|
|
|
|
false
|
|
}
|
|
|
|
/// Check if a fingerprint has any active WS connections.
|
|
pub async fn is_online(&self, fingerprint: &str) -> bool {
|
|
let conns = self.connections.lock().await;
|
|
conns.get(fingerprint).map(|d| !d.is_empty()).unwrap_or(false)
|
|
}
|
|
|
|
/// Count active WS connections for a fingerprint (multi-device).
|
|
pub async fn device_count(&self, fingerprint: &str) -> usize {
|
|
let conns = self.connections.lock().await;
|
|
conns.get(fingerprint).map(|d| d.len()).unwrap_or(0)
|
|
}
|
|
|
|
/// List devices for a fingerprint with metadata.
|
|
pub async fn list_devices(&self, fingerprint: &str) -> Vec<(String, i64)> {
|
|
let conns = self.connections.lock().await;
|
|
conns.get(fingerprint)
|
|
.map(|devices| devices.iter().map(|d| (d.device_id.clone(), d.connected_at)).collect())
|
|
.unwrap_or_default()
|
|
}
|
|
|
|
/// Kick a specific device by ID. Returns true if found and kicked.
|
|
pub async fn kick_device(&self, fingerprint: &str, device_id: &str) -> bool {
|
|
let mut conns = self.connections.lock().await;
|
|
if let Some(devices) = conns.get_mut(fingerprint) {
|
|
let before = devices.len();
|
|
devices.retain(|d| d.device_id != device_id);
|
|
let kicked = devices.len() < before;
|
|
if devices.is_empty() {
|
|
conns.remove(fingerprint);
|
|
}
|
|
kicked
|
|
} else {
|
|
false
|
|
}
|
|
}
|
|
|
|
/// Revoke all connections for a fingerprint except one device_id.
|
|
pub async fn revoke_all_except(&self, fingerprint: &str, keep_device_id: &str) -> usize {
|
|
let mut conns = self.connections.lock().await;
|
|
if let Some(devices) = conns.get_mut(fingerprint) {
|
|
let before = devices.len();
|
|
devices.retain(|d| d.device_id == keep_device_id);
|
|
let removed = before - devices.len();
|
|
if devices.is_empty() {
|
|
conns.remove(fingerprint);
|
|
}
|
|
removed
|
|
} else {
|
|
0
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates an `AppState` on a fresh temporary directory.
    ///
    /// The returned `TempDir` guard deletes the directory when dropped, so
    /// tests bind it to a named variable (`_dir`, never `_`). This keeps the
    /// directory alive while the database opened there is still in use.
    fn test_state() -> (AppState, tempfile::TempDir) {
        let dir = tempfile::tempdir().unwrap();
        let state = AppState::new(dir.path().to_str().unwrap()).unwrap();
        (state, dir)
    }

    #[tokio::test]
    async fn push_to_client_returns_false_when_offline() {
        let (state, _dir) = test_state();
        assert!(!state.push_to_client("abc123", b"hello").await);
    }

    #[tokio::test]
    async fn register_ws_and_push() {
        let (state, _dir) = test_state();
        let (_, mut rx) = state.register_ws("test_fp", None).await.unwrap();

        assert!(state.push_to_client("test_fp", b"hello").await);
        let msg = rx.recv().await.unwrap();
        assert_eq!(msg, b"hello");
    }

    #[tokio::test]
    async fn ws_connection_cap() {
        let (state, _dir) = test_state();
        // Hold receivers so senders stay open (register_ws prunes closed senders).
        let mut _holders = Vec::new();
        for i in 0..MAX_WS_PER_FINGERPRINT {
            let res = state.register_ws("same_fp", None).await;
            assert!(res.is_some(), "connection {} should succeed", i);
            _holders.push(res.unwrap());
        }
        // One more than the cap must be rejected.
        assert!(state.register_ws("same_fp", None).await.is_none());
    }

    #[tokio::test]
    async fn is_online_and_device_count() {
        let (state, _dir) = test_state();
        assert!(!state.is_online("fp1").await);
        assert_eq!(state.device_count("fp1").await, 0);

        // Must hold receivers so the senders are not marked as closed.
        let _r1 = state.register_ws("fp1", None).await;
        assert!(state.is_online("fp1").await);
        assert_eq!(state.device_count("fp1").await, 1);

        let _r2 = state.register_ws("fp1", None).await;
        assert_eq!(state.device_count("fp1").await, 2);
    }

    #[tokio::test]
    async fn kick_device() {
        let (state, _dir) = test_state();
        // Keep the receiver so we kick a genuinely live connection.
        let (device_id, _rx) = state.register_ws("fp1", None).await.unwrap();

        assert!(state.kick_device("fp1", &device_id).await);
        assert!(!state.is_online("fp1").await);
        assert_eq!(state.device_count("fp1").await, 0);
        // A second kick finds nothing.
        assert!(!state.kick_device("fp1", &device_id).await);
    }

    #[tokio::test]
    async fn revoke_all_except() {
        let (state, _dir) = test_state();
        let (id1, _rx1) = state.register_ws("fp1", None).await.unwrap();
        let (_id2, _rx2) = state.register_ws("fp1", None).await.unwrap();
        let (_id3, _rx3) = state.register_ws("fp1", None).await.unwrap();

        let removed = state.revoke_all_except("fp1", &id1).await;
        assert_eq!(removed, 2);
        assert_eq!(state.device_count("fp1").await, 1);
    }

    #[tokio::test]
    async fn deliver_or_queue_offline() {
        let (state, _dir) = test_state();
        // No WS connected -- should queue
        let delivered = state.deliver_or_queue("offline_fp", b"test message").await;
        assert!(!delivered);

        // Check message was queued in DB. The trailing ':' keeps the prefix
        // from also matching fingerprints that merely start with "offline_fp".
        let prefix = "queue:offline_fp:";
        let count = state.db.messages.scan_prefix(prefix.as_bytes()).count();
        assert_eq!(count, 1);
    }

    #[tokio::test]
    async fn deliver_or_queue_online() {
        let (state, _dir) = test_state();
        let (_, mut rx) = state.register_ws("online_fp", None).await.unwrap();

        let delivered = state.deliver_or_queue("online_fp", b"instant").await;
        assert!(delivered);

        let msg = rx.recv().await.unwrap();
        assert_eq!(msg, b"instant");
    }

    #[tokio::test]
    async fn call_state_lifecycle() {
        let (state, _dir) = test_state();

        let call = CallState {
            call_id: "call-001".into(),
            caller_fp: "alice".into(),
            callee_fp: "bob".into(),
            group_name: None,
            room_id: None,
            status: CallStatus::Ringing,
            created_at: chrono::Utc::now().timestamp(),
            answered_at: None,
            ended_at: None,
        };

        state.active_calls.lock().await.insert("call-001".into(), call);
        assert_eq!(state.active_calls.lock().await.len(), 1);

        // End the call. Take it out in its own statement so the lock guard
        // is released before anything else runs.
        let removed = state.active_calls.lock().await.remove("call-001");
        let mut c = removed.expect("call-001 should be active");
        c.status = CallStatus::Ended;
        c.ended_at = Some(chrono::Utc::now().timestamp());
        state
            .db
            .calls
            .insert(b"call-001", serde_json::to_vec(&c).unwrap())
            .unwrap();

        assert_eq!(state.active_calls.lock().await.len(), 0);

        // The ended call must round-trip from storage.
        let raw = state
            .db
            .calls
            .get(b"call-001")
            .unwrap()
            .expect("ended call should be persisted");
        let stored: CallState = serde_json::from_slice(&raw).unwrap();
        assert!(matches!(stored.status, CallStatus::Ended));
        assert!(stored.ended_at.is_some());
    }

    #[tokio::test]
    async fn list_devices() {
        let (state, _dir) = test_state();
        let (id1, _rx1) = state.register_ws("fp1", None).await.unwrap();
        let (id2, _rx2) = state.register_ws("fp1", None).await.unwrap();

        let devices = state.list_devices("fp1").await;
        assert_eq!(devices.len(), 2);
        let ids: Vec<&str> = devices.iter().map(|(id, _)| id.as_str()).collect();
        assert!(ids.contains(&id1.as_str()));
        assert!(ids.contains(&id2.as_str()));
    }
}
|