use axum::{ extract::{Path, State}, routing::{get, post}, Json, Router, }; use serde::{Deserialize, Serialize}; use crate::state::AppState; pub fn routes() -> Router { Router::new() .route("/keys/register", post(register_keys)) .route("/keys/replenish", post(replenish_otpks)) .route("/keys/list", get(list_keys)) .route("/keys/:fingerprint", get(get_bundle)) .route("/keys/:fingerprint/otpk-count", get(otpk_count)) .route("/keys/:fingerprint/devices", get(list_devices)) } /// Debug endpoint: list all registered fingerprints. async fn list_keys(State(state): State) -> Json { let keys: Vec = state .db .keys .iter() .filter_map(|item| { item.ok() .and_then(|(k, _)| String::from_utf8(k.to_vec()).ok()) }) .collect(); tracing::info!("Listed {} registered keys", keys.len()); Json(serde_json::json!({ "keys": keys, "count": keys.len() })) } /// Normalize fingerprint: strip colons, lowercase. fn normalize_fp(fp: &str) -> String { fp.chars() .filter(|c| c.is_ascii_hexdigit()) .collect::() .to_lowercase() } #[derive(Deserialize)] struct RegisterRequest { fingerprint: String, #[serde(default)] device_id: Option, bundle: Vec, } #[derive(Serialize)] struct RegisterResponse { ok: bool, } async fn register_keys( _auth: crate::auth_middleware::AuthFingerprint, State(state): State, Json(req): Json, ) -> Json { let fp = normalize_fp(&req.fingerprint); let device_id = req.device_id.unwrap_or_else(|| "default".to_string()); // Store bundle keyed by fingerprint (primary, used for lookup) let _ = state.db.keys.insert(fp.as_bytes(), req.bundle.clone()); // Also store per-device: device:: → bundle let device_key = format!("device:{}:{}", fp, device_id); let _ = state.db.keys.insert(device_key.as_bytes(), req.bundle); tracing::info!("Registered bundle for {} (device: {})", fp, device_id); Json(RegisterResponse { ok: true }) } async fn get_bundle( State(state): State, Path(fingerprint): Path, ) -> Result, axum::http::StatusCode> { let key = normalize_fp(&fingerprint); tracing::info!("get_bundle: raw path='{}', normalized='{}'", fingerprint, key); // Debug: list what's in the DB let all_keys: Vec = state.db.keys.iter() .filter_map(|r| r.ok().and_then(|(k, _)| String::from_utf8(k.to_vec()).ok())) .collect(); tracing::info!("get_bundle: DB contains {} keys: {:?}", all_keys.len(), all_keys); match state.db.keys.get(key.as_bytes()) { Ok(Some(data)) => { tracing::info!("get_bundle: FOUND {} bytes for {}", data.len(), key); Ok(Json(serde_json::json!({ "fingerprint": fingerprint, "bundle": base64::Engine::encode(&base64::engine::general_purpose::STANDARD, &data), }))) } Ok(None) => { tracing::warn!("get_bundle: NOT FOUND for key '{}'", key); Err(axum::http::StatusCode::NOT_FOUND) } Err(e) => { tracing::error!("get_bundle: DB error: {}", e); Err(axum::http::StatusCode::INTERNAL_SERVER_ERROR) } } } /// Check how many one-time pre-keys remain for a fingerprint. async fn otpk_count( State(state): State, Path(fingerprint): Path, ) -> Json { let fp = normalize_fp(&fingerprint); let prefix = format!("otpk:{}:", fp); let count = state.db.keys.scan_prefix(prefix.as_bytes()).count(); Json(serde_json::json!({ "fingerprint": fp, "otpk_count": count })) } #[derive(Deserialize)] struct ReplenishRequest { fingerprint: String, /// One-time pre-keys: list of {id, public_key_hex} otpks: Vec, } #[derive(Deserialize)] struct OtpkEntry { id: u32, public_key: String, // hex-encoded 32-byte X25519 public key } /// Upload additional one-time pre-keys. async fn replenish_otpks( _auth: crate::auth_middleware::AuthFingerprint, State(state): State, Json(req): Json, ) -> Json { let fp = normalize_fp(&req.fingerprint); let mut stored = 0; for otpk in &req.otpks { let key = format!("otpk:{}:{}", fp, otpk.id); let _ = state.db.keys.insert(key.as_bytes(), otpk.public_key.as_bytes()); stored += 1; } let prefix = format!("otpk:{}:", fp); let total = state.db.keys.scan_prefix(prefix.as_bytes()).count(); tracing::info!("Replenished {} OTPKs for {} (total: {})", stored, fp, total); Json(serde_json::json!({ "ok": true, "stored": stored, "total": total })) } /// List all registered devices for a fingerprint. async fn list_devices( State(state): State, Path(fingerprint): Path, ) -> Json { let fp = normalize_fp(&fingerprint); let prefix = format!("device:{}:", fp); let devices: Vec = state.db.keys.scan_prefix(prefix.as_bytes()) .filter_map(|item| { item.ok().and_then(|(k, _)| { let key_str = String::from_utf8_lossy(&k).to_string(); // key format: device:: key_str.rsplit(':').next().map(|s| s.to_string()) }) }) .collect(); Json(serde_json::json!({ "fingerprint": fp, "devices": devices, "count": devices.len() })) }