Merge branch 'experimental-ui'

Covers T1–T6 task series plus audit remediations:
- Full video pipeline: AV1/H264/H265 codec factory, VideoScorer, simulcast,
  keyframe cache, PLI suppression, NACK, VideoReassembler
- E2E AEAD: EncryptingTransport wraps all media; nonce from MediaHeader.seq
- Camera capture (getUserMedia) + remote video strip (canvas)
- Android Tauri audio pipeline: Oboe config, threading, spawn_blocking fixes
- Relay: audio scorer, video scorer, response policy, conformance, federation
- Protocol: SignalMessage version byte, AV1 codec negotiation, quality profiles
- 825 passing tests across 41 suites

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Siavash Sameni
2026-05-25 15:30:45 +04:00
343 changed files with 51589 additions and 5804 deletions

14
.gitleaks.toml Normal file
View File

@@ -0,0 +1,14 @@
[extend]
useDefault = true
[[allowlists]]
description = "Pre-existing historical findings already on fj/main and github/main. The two PASTE_AUTH tokens in scripts/build.sh and scripts/build-linux-notify.sh are real — rotate if those endpoints still authenticate; this allowlist only silences the pre-push hook, it does not remove the exposure."
commits = [
# wzp-crypto module doc: false positive on "SHA-256(Ed25519 pub)[:16]"
"51e893590c1b9fa49e9f6ae5c96c26deb58f353b",
# build.sh PASTE_AUTH (paste.tbs.amn.gg)
"bd6733b2e5d76b5259020f1c30a5223a9773b6aa",
# build-linux-notify Authorization header (paste.dk.manko.yoga)
"6d776097c83bc6fbe3f3565e080513d8af93b550",
"7751439e2bca9eacf2c30929c8124a4eb6136df2",
]

1178
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -11,6 +11,7 @@ members = [
"crates/wzp-web", "crates/wzp-web",
"crates/wzp-android", "crates/wzp-android",
"crates/wzp-native", "crates/wzp-native",
"crates/wzp-video",
"desktop/src-tauri", "desktop/src-tauri",
] ]

1
android.sh Normal file
View File

@@ -0,0 +1 @@
./scripts/android-build-async.sh --init

View File

@@ -28,6 +28,7 @@ libc = "0.2"
jni = { version = "0.21", default-features = false } jni = { version = "0.21", default-features = false }
rand = { workspace = true } rand = { workspace = true }
rustls = { version = "0.23", default-features = false, features = ["ring"] } rustls = { version = "0.23", default-features = false, features = ["ring"] }
[target.'cfg(target_os = "android")'.dependencies]
tracing-android = "0.2" tracing-android = "0.2"
[build-dependencies] [build-dependencies]

View File

@@ -65,9 +65,8 @@ fn main() {
} else { } else {
"aarch64-linux-android" "aarch64-linux-android"
}; };
let lib_dir = format!( let lib_dir =
"{ndk}/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/lib/{arch}" format!("{ndk}/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/lib/{arch}");
);
println!("cargo:rustc-link-search=native={lib_dir}"); println!("cargo:rustc-link-search=native={lib_dir}");
// Copy libc++_shared.so to the jniLibs directory // Copy libc++_shared.so to the jniLibs directory
@@ -82,9 +81,7 @@ fn main() {
}; };
// Try to copy to the Gradle jniLibs directory // Try to copy to the Gradle jniLibs directory
let manifest = std::env::var("CARGO_MANIFEST_DIR").unwrap_or_default(); let manifest = std::env::var("CARGO_MANIFEST_DIR").unwrap_or_default();
let jni_dir = format!( let jni_dir = format!("{manifest}/../../android/app/src/main/jniLibs/{jni_abi}");
"{manifest}/../../android/app/src/main/jniLibs/{jni_abi}"
);
if let Ok(_) = std::fs::create_dir_all(&jni_dir) { if let Ok(_) = std::fs::create_dir_all(&jni_dir) {
let _ = std::fs::copy(&shared_so, format!("{jni_dir}/libc++_shared.so")); let _ = std::fs::copy(&shared_so, format!("{jni_dir}/libc++_shared.so"));
println!("cargo:warning=Copied libc++_shared.so to {jni_dir}"); println!("cargo:warning=Copied libc++_shared.so to {jni_dir}");
@@ -127,7 +124,12 @@ fn fetch_oboe() -> Option<PathBuf> {
let out_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap()); let out_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap());
let oboe_dir = out_dir.join("oboe"); let oboe_dir = out_dir.join("oboe");
if oboe_dir.join("include").join("oboe").join("Oboe.h").exists() { if oboe_dir
.join("include")
.join("oboe")
.join("Oboe.h")
.exists()
{
return Some(oboe_dir); return Some(oboe_dir);
} }
@@ -143,7 +145,12 @@ fn fetch_oboe() -> Option<PathBuf> {
match status { match status {
Ok(s) if s.success() => { Ok(s) if s.success() => {
if oboe_dir.join("include").join("oboe").join("Oboe.h").exists() { if oboe_dir
.join("include")
.join("oboe")
.join("Oboe.h")
.exists()
{
Some(oboe_dir) Some(oboe_dir)
} else { } else {
None None

View File

@@ -326,7 +326,10 @@ pub fn pin_to_big_core() {
&set, &set,
); );
if ret != 0 { if ret != 0 {
warn!("sched_setaffinity failed: {}", std::io::Error::last_os_error()); warn!(
"sched_setaffinity failed: {}",
std::io::Error::last_os_error()
);
} else { } else {
info!(start, num_cpus, "pinned to big cores"); info!(start, num_cpus, "pinned to big cores");
} }

View File

@@ -77,7 +77,8 @@ impl AudioRing {
} }
} }
self.write_pos.store(w.wrapping_add(count), Ordering::Release); self.write_pos
.store(w.wrapping_add(count), Ordering::Release);
count count
} }
@@ -112,7 +113,8 @@ impl AudioRing {
out[i] = unsafe { *self.buf.as_ptr().add((r + i) & RING_MASK) }; out[i] = unsafe { *self.buf.as_ptr().add((r + i) & RING_MASK) };
} }
self.read_pos.store(r.wrapping_add(count), Ordering::Release); self.read_pos
.store(r.wrapping_add(count), Ordering::Release);
count count
} }

View File

@@ -22,7 +22,8 @@ use wzp_crypto::{KeyExchange, WarzoneKeyExchange};
use wzp_fec::{RaptorQFecDecoder, RaptorQFecEncoder}; use wzp_fec::{RaptorQFecDecoder, RaptorQFecEncoder};
use wzp_proto::{ use wzp_proto::{
AdaptiveQualityController, AudioDecoder, AudioEncoder, CodecId, FecDecoder, FecEncoder, AdaptiveQualityController, AudioDecoder, AudioEncoder, CodecId, FecDecoder, FecEncoder,
MediaHeader, MediaPacket, MediaTransport, QualityController, QualityProfile, SignalMessage, MediaHeader, MediaPacket, MediaTransport, MediaType, QualityController, QualityProfile,
SignalMessage, default_signal_version,
}; };
use crate::audio_ring::AudioRing; use crate::audio_ring::AudioRing;
@@ -46,7 +47,11 @@ const PROFILES: [QualityProfile; 6] = [
]; ];
fn profile_to_index(p: &QualityProfile) -> u8 { fn profile_to_index(p: &QualityProfile) -> u8 {
PROFILES.iter().position(|pp| pp.codec == p.codec).map(|i| i as u8).unwrap_or(3) PROFILES
.iter()
.position(|pp| pp.codec == p.codec)
.map(|i| i as u8)
.unwrap_or(3)
} }
fn index_to_profile(idx: u8) -> Option<QualityProfile> { fn index_to_profile(idx: u8) -> Option<QualityProfile> {
@@ -149,9 +154,10 @@ impl WzpEngine {
.enable_all() .enable_all()
.build()?; .build()?;
let relay_addr: SocketAddr = config.relay_addr.parse().map_err(|e| { let relay_addr: SocketAddr = config
anyhow::anyhow!("invalid relay address '{}': {e}", config.relay_addr) .relay_addr
})?; .parse()
.map_err(|e| anyhow::anyhow!("invalid relay address '{}': {e}", config.relay_addr))?;
let room = config.room.clone(); let room = config.room.clone();
let identity_seed = config.identity_seed; let identity_seed = config.identity_seed;
@@ -165,7 +171,16 @@ impl WzpEngine {
let state_clone = state.clone(); let state_clone = state.clone();
runtime.block_on(async move { runtime.block_on(async move {
if let Err(e) = run_call(relay_addr, &room, &identity_seed, profile, auto_profile, alias.as_deref(), state_clone).await if let Err(e) = run_call(
relay_addr,
&room,
&identity_seed,
profile,
auto_profile,
alias.as_deref(),
state_clone,
)
.await
{ {
error!("call failed: {e}"); error!("call failed: {e}");
} }
@@ -233,16 +248,21 @@ impl WzpEngine {
let server_fp = conn let server_fp = conn
.peer_identity() .peer_identity()
.and_then(|id| id.downcast::<Vec<rustls::pki_types::CertificateDer>>().ok()) .and_then(|id| id.downcast::<Vec<rustls::pki_types::CertificateDer>>().ok())
.and_then(|certs| certs.first().map(|c| { .and_then(|certs| {
certs.first().map(|c| {
use std::hash::{Hash, Hasher}; use std::hash::{Hash, Hasher};
let mut h = std::collections::hash_map::DefaultHasher::new(); let mut h = std::collections::hash_map::DefaultHasher::new();
c.as_ref().hash(&mut h); c.as_ref().hash(&mut h);
format!("{:016x}", h.finish()) format!("{:016x}", h.finish())
})) })
})
.unwrap_or_default(); .unwrap_or_default();
conn.close(0u32.into(), b"ping"); conn.close(0u32.into(), b"ping");
Ok::<_, anyhow::Error>(format!(r#"{{"rtt_ms":{},"server_fingerprint":"{}"}}"#, rtt_ms, server_fp)) Ok::<_, anyhow::Error>(format!(
r#"{{"rtt_ms":{},"server_fingerprint":"{}"}}"#,
rtt_ms, server_fp
))
}); });
// Shutdown runtime cleanly with timeout // Shutdown runtime cleanly with timeout
@@ -301,11 +321,12 @@ impl WzpEngine {
// Auth if token provided // Auth if token provided
if let Some(ref tok) = token { if let Some(ref tok) = token {
let _ = transport.send_signal(&SignalMessage::AuthToken { token: tok.clone() }).await; let _ = transport.send_signal(&SignalMessage::AuthToken { version: default_signal_version(), token: tok.clone() }).await;
} }
// Register presence // Register presence
let _ = transport.send_signal(&SignalMessage::RegisterPresence { let _ = transport.send_signal(&SignalMessage::RegisterPresence {
version: default_signal_version(),
identity_pub, identity_pub,
signature: vec![], signature: vec![],
alias: alias.clone(), alias: alias.clone(),
@@ -330,7 +351,7 @@ impl WzpEngine {
break; break;
} }
match transport.recv_signal().await { match transport.recv_signal().await {
Ok(Some(SignalMessage::CallRinging { call_id })) => { Ok(Some(SignalMessage::CallRinging { call_id, ..})) => {
info!(call_id = %call_id, "signal: ringing"); info!(call_id = %call_id, "signal: ringing");
let mut stats = signal_state.stats.lock().unwrap(); let mut stats = signal_state.stats.lock().unwrap();
stats.state = crate::stats::CallState::Ringing; stats.state = crate::stats::CallState::Ringing;
@@ -392,7 +413,11 @@ impl WzpEngine {
} }
/// Answer an incoming direct call. /// Answer an incoming direct call.
pub fn answer_call(&self, call_id: &str, mode: wzp_proto::CallAcceptMode) -> Result<(), anyhow::Error> { pub fn answer_call(
&self,
call_id: &str,
mode: wzp_proto::CallAcceptMode,
) -> Result<(), anyhow::Error> {
let _ = self.state.command_tx.send(EngineCommand::AnswerCall { let _ = self.state.command_tx.send(EngineCommand::AnswerCall {
call_id: call_id.to_string(), call_id: call_id.to_string(),
accept_mode: mode, accept_mode: mode,
@@ -412,7 +437,9 @@ impl WzpEngine {
/// Stores the type atomically; the recv task polls it on each packet. /// Stores the type atomically; the recv task polls it on each packet.
pub fn on_network_changed(&self, network_type: u8, bandwidth_kbps: u32) { pub fn on_network_changed(&self, network_type: u8, bandwidth_kbps: u32) {
info!(network_type, bandwidth_kbps, "on_network_changed"); info!(network_type, bandwidth_kbps, "on_network_changed");
self.state.pending_network_type.store(network_type, Ordering::Release); self.state
.pending_network_type
.store(network_type, Ordering::Release);
} }
pub fn get_stats(&self) -> CallStats { pub fn get_stats(&self) -> CallStats {
@@ -496,6 +523,7 @@ async fn run_call(
let signature = kx.sign(&sign_data); let signature = kx.sign(&sign_data);
let offer = SignalMessage::CallOffer { let offer = SignalMessage::CallOffer {
version: default_signal_version(),
identity_pub, identity_pub,
ephemeral_pub, ephemeral_pub,
signature, signature,
@@ -508,6 +536,9 @@ async fn run_call(
QualityProfile::CATASTROPHIC, QualityProfile::CATASTROPHIC,
], ],
alias: alias.map(|s| s.to_string()), alias: alias.map(|s| s.to_string()),
protocol_version: 2,
supported_versions: vec![2],
video_codecs: vec![],
}; };
transport.send_signal(&offer).await?; transport.send_signal(&offer).await?;
info!("CallOffer sent, waiting for CallAnswer..."); info!("CallOffer sent, waiting for CallAnswer...");
@@ -518,12 +549,16 @@ async fn run_call(
.ok_or_else(|| anyhow::anyhow!("connection closed before CallAnswer"))?; .ok_or_else(|| anyhow::anyhow!("connection closed before CallAnswer"))?;
let (relay_ephemeral_pub, chosen_profile) = match answer { let (relay_ephemeral_pub, chosen_profile) = match answer {
SignalMessage::CallAnswer { ephemeral_pub, chosen_profile, .. } => (ephemeral_pub, chosen_profile), SignalMessage::CallAnswer {
ephemeral_pub,
chosen_profile,
..
} => (ephemeral_pub, chosen_profile),
other => { other => {
return Err(anyhow::anyhow!( return Err(anyhow::anyhow!(
"expected CallAnswer, got {:?}", "expected CallAnswer, got {:?}",
std::mem::discriminant(&other) std::mem::discriminant(&other)
)) ));
} }
}; };
@@ -574,7 +609,7 @@ async fn run_call(
stats.auto_mode = auto_profile; stats.auto_mode = auto_profile;
} }
let seq = AtomicU16::new(0); let seq = AtomicU32::new(0);
let ts = AtomicU32::new(0); let ts = AtomicU32::new(0);
let transport_recv = transport.clone(); let transport_recv = transport.clone();
@@ -700,17 +735,15 @@ async fn run_call(
let source_pkt = MediaPacket { let source_pkt = MediaPacket {
header: MediaHeader { header: MediaHeader {
version: 0, version: MediaHeader::VERSION,
is_repair: false, flags: 0,
media_type: MediaType::Audio,
codec_id: current_profile.codec, codec_id: current_profile.codec,
has_quality_report: false, stream_id: 0,
fec_ratio_encoded: hdr_fec_ratio, fec_ratio: hdr_fec_ratio,
seq: s, seq: s,
timestamp: t, timestamp: t,
fec_block: hdr_fec_block, fec_block: ((hdr_fec_symbol as u16) << 8) | (hdr_fec_block as u16),
fec_symbol: hdr_fec_symbol,
reserved: 0,
csrc_count: 0,
}, },
payload: Bytes::copy_from_slice(encoded), payload: Bytes::copy_from_slice(encoded),
quality_report: None, quality_report: None,
@@ -725,9 +758,7 @@ async fn run_call(
if send_errors <= 3 || last_send_error_log.elapsed().as_secs() >= 1 { if send_errors <= 3 || last_send_error_log.elapsed().as_secs() >= 1 {
warn!( warn!(
seq = s, seq = s,
send_errors, send_errors, frames_dropped, "send_media error (dropping packet): {e}"
frames_dropped,
"send_media error (dropping packet): {e}"
); );
last_send_error_log = Instant::now(); last_send_error_log = Instant::now();
} }
@@ -756,19 +787,17 @@ async fn run_call(
let rs = seq.fetch_add(1, Ordering::Relaxed); let rs = seq.fetch_add(1, Ordering::Relaxed);
let repair_pkt = MediaPacket { let repair_pkt = MediaPacket {
header: MediaHeader { header: MediaHeader {
version: 0, version: MediaHeader::VERSION,
is_repair: true, flags: MediaHeader::FLAG_REPAIR,
media_type: MediaType::Audio,
codec_id: current_profile.codec, codec_id: current_profile.codec,
has_quality_report: false, stream_id: 0,
fec_ratio_encoded: MediaHeader::encode_fec_ratio( fec_ratio: MediaHeader::encode_fec_ratio(
current_profile.fec_ratio, current_profile.fec_ratio,
), ),
seq: rs, seq: rs,
timestamp: t, timestamp: t,
fec_block: block_id, fec_block: (sym_idx << 8) | (block_id as u16),
fec_symbol: sym_idx,
reserved: 0,
csrc_count: 0,
}, },
payload: Bytes::from(repair_data), payload: Bytes::from(repair_data),
quality_report: None, quality_report: None,
@@ -820,7 +849,11 @@ async fn run_call(
avg_total_us = avg(t_agc_us + t_opus_us + t_fec_us + t_send_us), avg_total_us = avg(t_agc_us + t_opus_us + t_fec_us + t_send_us),
"send stats" "send stats"
); );
t_agc_us = 0; t_opus_us = 0; t_fec_us = 0; t_send_us = 0; t_frames = 0; t_agc_us = 0;
t_opus_us = 0;
t_fec_us = 0;
t_send_us = 0;
t_frames = 0;
last_stats_log = Instant::now(); last_stats_log = Instant::now();
} }
} }
@@ -849,14 +882,11 @@ async fn run_call(
// when a packet arrives with seq > expected_seq, the frames in // when a packet arrives with seq > expected_seq, the frames in
// between are missing and we attempt to reconstruct them via // between are missing and we attempt to reconstruct them via
// DRED before decoding the newly-arrived packet. // DRED before decoding the newly-arrived packet.
let mut dred_decoder = let mut dred_decoder = DredDecoderHandle::new().expect("opus_dred_decoder_create failed");
DredDecoderHandle::new().expect("opus_dred_decoder_create failed"); let mut dred_parse_scratch = DredState::new().expect("opus_dred_alloc failed (scratch)");
let mut dred_parse_scratch = let mut last_good_dred = DredState::new().expect("opus_dred_alloc failed (good state)");
DredState::new().expect("opus_dred_alloc failed (scratch)"); let mut last_good_dred_seq: Option<u32> = None;
let mut last_good_dred = let mut expected_seq: Option<u32> = None;
DredState::new().expect("opus_dred_alloc failed (good state)");
let mut last_good_dred_seq: Option<u16> = None;
let mut expected_seq: Option<u16> = None;
let mut dred_reconstructions: u64 = 0; let mut dred_reconstructions: u64 = 0;
let mut classical_plc_invocations: u64 = 0; let mut classical_plc_invocations: u64 = 0;
@@ -877,14 +907,16 @@ async fn run_call(
warn!( warn!(
recv_gap_ms, recv_gap_ms,
seq = pkt.header.seq, seq = pkt.header.seq,
is_repair = pkt.header.is_repair, is_repair = pkt.header.is_repair(),
"large recv gap — possible network stall" "large recv gap — possible network stall"
); );
} }
// Check for network transport change from ConnectivityManager // Check for network transport change from ConnectivityManager
{ {
let net = state.pending_network_type.swap(PROFILE_NO_CHANGE, Ordering::Acquire); let net = state
.pending_network_type
.swap(PROFILE_NO_CHANGE, Ordering::Acquire);
if net != PROFILE_NO_CHANGE { if net != PROFILE_NO_CHANGE {
use wzp_proto::NetworkContext; use wzp_proto::NetworkContext;
let ctx = match net { let ctx = match net {
@@ -916,9 +948,9 @@ async fn run_call(
} }
} }
let is_repair = pkt.header.is_repair; let is_repair = pkt.header.is_repair();
let pkt_block = pkt.header.fec_block; let pkt_block = pkt.header.fec_block;
let pkt_symbol = pkt.header.fec_symbol; let pkt_symbol = (pkt.header.fec_block >> 8) as u16;
let pkt_is_opus = pkt.header.codec_id.is_opus(); let pkt_is_opus = pkt.header.codec_id.is_opus();
// Phase 2: Opus packets bypass RaptorQ entirely — DRED // Phase 2: Opus packets bypass RaptorQ entirely — DRED
@@ -927,12 +959,7 @@ async fn run_call(
// would accumulate block_id=0 duplicates that never // would accumulate block_id=0 duplicates that never
// decode. Codec2 packets still feed RaptorQ. // decode. Codec2 packets still feed RaptorQ.
if !pkt_is_opus { if !pkt_is_opus {
let _ = fec_dec.add_symbol( let _ = fec_dec.add_symbol(pkt_block, pkt_symbol, is_repair, &pkt.payload);
pkt_block,
pkt_symbol,
is_repair,
&pkt.payload,
);
} }
// Source packets: decode directly // Source packets: decode directly
@@ -951,8 +978,12 @@ async fn run_call(
fec_ratio: 0.5, fec_ratio: 0.5,
frame_duration_ms: 20, frame_duration_ms: 20,
frames_per_block: 5, frames_per_block: 5,
..QualityProfile::GOOD
},
other => QualityProfile {
codec: other,
..QualityProfile::GOOD
}, },
other => QualityProfile { codec: other, ..QualityProfile::GOOD },
}; };
info!(from = ?decoder.codec_id(), to = ?pkt.header.codec_id, "recv: switching decoder"); info!(from = ?decoder.codec_id(), to = ?pkt.header.codec_id, "recv: switching decoder");
let _ = decoder.set_profile(switch_profile); let _ = decoder.set_profile(switch_profile);
@@ -984,10 +1015,7 @@ async fn run_call(
// Update DRED state from the current packet. // Update DRED state from the current packet.
match dred_decoder.parse_into(&mut dred_parse_scratch, &pkt.payload) { match dred_decoder.parse_into(&mut dred_parse_scratch, &pkt.payload) {
Ok(available) if available > 0 => { Ok(available) if available > 0 => {
std::mem::swap( std::mem::swap(&mut dred_parse_scratch, &mut last_good_dred);
&mut dred_parse_scratch,
&mut last_good_dred,
);
last_good_dred_seq = Some(pkt.header.seq); last_good_dred_seq = Some(pkt.header.seq);
} }
Ok(_) => { Ok(_) => {
@@ -999,15 +1027,14 @@ async fn run_call(
} }
// Detect and fill gap from last-expected to this packet. // Detect and fill gap from last-expected to this packet.
const MAX_GAP_FRAMES: u16 = 16; const MAX_GAP_FRAMES: u32 = 16;
if let Some(expected) = expected_seq { if let Some(expected) = expected_seq {
let gap = pkt.header.seq.wrapping_sub(expected); let gap = pkt.header.seq.wrapping_sub(expected);
if gap > 0 && gap <= MAX_GAP_FRAMES { if gap > 0 && gap <= MAX_GAP_FRAMES {
let current_profile_frame_samples = let current_profile_frame_samples =
(48_000 * profile.frame_duration_ms as i32) / 1000; (48_000 * profile.frame_duration_ms as i32) / 1000;
let available = last_good_dred.samples_available(); let available = last_good_dred.samples_available();
let pcm_slice_len = let pcm_slice_len = current_profile_frame_samples as usize;
current_profile_frame_samples as usize;
for gap_idx in 0..gap { for gap_idx in 0..gap {
let missing_seq = expected.wrapping_add(gap_idx); let missing_seq = expected.wrapping_add(gap_idx);
@@ -1026,9 +1053,8 @@ async fn run_call(
None => -1, None => -1,
}; };
let reconstructed = if offset_samples > 0 let reconstructed =
&& offset_samples <= available if offset_samples > 0 && offset_samples <= available {
{
decoder decoder
.reconstruct_from_dred( .reconstruct_from_dred(
&last_good_dred, &last_good_dred,
@@ -1042,12 +1068,9 @@ async fn run_call(
match reconstructed { match reconstructed {
Some(samples) => { Some(samples) => {
playout_agc.process_frame( playout_agc
&mut decode_buf[..samples], .process_frame(&mut decode_buf[..samples]);
); state.playout_ring.write(&decode_buf[..samples]);
state
.playout_ring
.write(&decode_buf[..samples]);
dred_reconstructions += 1; dred_reconstructions += 1;
frames_decoded += 1; frames_decoded += 1;
} }
@@ -1144,7 +1167,10 @@ async fn run_call(
} }
} }
Ok(None) => { Ok(None) => {
info!(frames_decoded, fec_recovered, "relay disconnected (stream ended)"); info!(
frames_decoded,
fec_recovered, "relay disconnected (stream ended)"
);
break; break;
} }
Err(e) => { Err(e) => {
@@ -1162,7 +1188,10 @@ async fn run_call(
} }
} }
} }
info!(frames_decoded, fec_recovered, recv_errors, "recv task ended"); info!(
frames_decoded,
fec_recovered, recv_errors, "recv task ended"
);
}; };
// Stats task — polls path quality + quinn RTT every 500ms // Stats task — polls path quality + quinn RTT every 500ms
@@ -1195,7 +1224,11 @@ async fn run_call(
let signal_task = async { let signal_task = async {
loop { loop {
match transport_signal.recv_signal().await { match transport_signal.recv_signal().await {
Ok(Some(SignalMessage::RoomUpdate { count, participants })) => { Ok(Some(SignalMessage::RoomUpdate {
count,
participants,
..
})) => {
info!(count, "RoomUpdate received"); info!(count, "RoomUpdate received");
let members: Vec<crate::stats::RoomMember> = participants let members: Vec<crate::stats::RoomMember> = participants
.iter() .iter()
@@ -1209,7 +1242,11 @@ async fn run_call(
stats.room_participant_count = count; stats.room_participant_count = count;
stats.room_participants = members; stats.room_participants = members;
} }
Ok(Some(SignalMessage::QualityDirective { recommended_profile, reason })) => { Ok(Some(SignalMessage::QualityDirective {
recommended_profile,
reason,
..
})) => {
let idx = profile_to_index(&recommended_profile); let idx = profile_to_index(&recommended_profile);
info!( info!(
codec = ?recommended_profile.codec, codec = ?recommended_profile.codec,
@@ -1247,7 +1284,9 @@ async fn run_call(
match tokio::time::timeout( match tokio::time::timeout(
std::time::Duration::from_millis(500), std::time::Duration::from_millis(500),
transport.connection().closed(), transport.connection().closed(),
).await { )
.await
{
Ok(_) => info!("QUIC connection closed cleanly"), Ok(_) => info!("QUIC connection closed cleanly"),
Err(_) => info!("QUIC close timed out (relay may not have ack'd)"), Err(_) => info!("QUIC close timed out (relay may not have ack'd)"),
} }

View File

@@ -3,9 +3,9 @@
use std::panic; use std::panic;
use std::sync::Once; use std::sync::Once;
use jni::JNIEnv;
use jni::objects::{JClass, JObject, JString}; use jni::objects::{JClass, JObject, JString};
use jni::sys::{jboolean, jint, jlong, jstring}; use jni::sys::{jboolean, jint, jlong, jstring};
use jni::JNIEnv;
use tracing::{error, info}; use tracing::{error, info};
use wzp_proto::QualityProfile; use wzp_proto::QualityProfile;
@@ -29,11 +29,13 @@ fn profile_from_int(value: jint) -> QualityProfile {
0 => QualityProfile::GOOD, // Opus 24k 0 => QualityProfile::GOOD, // Opus 24k
1 => QualityProfile::DEGRADED, // Opus 6k 1 => QualityProfile::DEGRADED, // Opus 6k
2 => QualityProfile::CATASTROPHIC, // Codec2 1.2k 2 => QualityProfile::CATASTROPHIC, // Codec2 1.2k
3 => QualityProfile { // Codec2 3.2k 3 => QualityProfile {
// Codec2 3.2k
codec: wzp_proto::CodecId::Codec2_3200, codec: wzp_proto::CodecId::Codec2_3200,
fec_ratio: 0.5, fec_ratio: 0.5,
frame_duration_ms: 20, frame_duration_ms: 20,
frames_per_block: 5, frames_per_block: 5,
..QualityProfile::GOOD
}, },
4 => QualityProfile::STUDIO_32K, // Opus 32k 4 => QualityProfile::STUDIO_32K, // Opus 32k
5 => QualityProfile::STUDIO_48K, // Opus 48k 5 => QualityProfile::STUDIO_48K, // Opus 48k
@@ -48,6 +50,8 @@ static INIT_LOGGING: Once = Once::new();
/// Safe to call multiple times — only the first call takes effect. /// Safe to call multiple times — only the first call takes effect.
fn init_logging() { fn init_logging() {
INIT_LOGGING.call_once(|| { INIT_LOGGING.call_once(|| {
#[cfg(target_os = "android")]
{
// Wrap in catch_unwind — sharded_slab allocation inside // Wrap in catch_unwind — sharded_slab allocation inside
// tracing_subscriber::registry() can crash on some Android // tracing_subscriber::registry() can crash on some Android
// devices if scudo malloc fails during early initialization. // devices if scudo malloc fails during early initialization.
@@ -67,6 +71,12 @@ fn init_logging() {
.try_init(); .try_init();
} }
}); });
}
#[cfg(not(target_os = "android"))]
{
// On non-Android targets tracing-android is unavailable.
let _ = tracing_subscriber::fmt::try_init();
}
}); });
} }
@@ -101,11 +111,26 @@ pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeStartCall(
profile_j: jint, profile_j: jint,
) -> jint { ) -> jint {
let result = panic::catch_unwind(panic::AssertUnwindSafe(|| { let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
let relay_addr: String = env.get_string(&relay_addr_j).map(|s| s.into()).unwrap_or_default(); let relay_addr: String = env
let room: String = env.get_string(&room_j).map(|s| s.into()).unwrap_or_default(); .get_string(&relay_addr_j)
let seed_hex: String = env.get_string(&seed_hex_j).map(|s| s.into()).unwrap_or_default(); .map(|s| s.into())
let token: String = env.get_string(&token_j).map(|s| s.into()).unwrap_or_default(); .unwrap_or_default();
let alias: String = env.get_string(&alias_j).map(|s| s.into()).unwrap_or_default(); let room: String = env
.get_string(&room_j)
.map(|s| s.into())
.unwrap_or_default();
let seed_hex: String = env
.get_string(&seed_hex_j)
.map(|s| s.into())
.unwrap_or_default();
let token: String = env
.get_string(&token_j)
.map(|s| s.into())
.unwrap_or_default();
let alias: String = env
.get_string(&alias_j)
.map(|s| s.into())
.unwrap_or_default();
let h = unsafe { handle_ref(handle) }; let h = unsafe { handle_ref(handle) };
@@ -128,7 +153,11 @@ pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeStartCall(
auto_profile: profile_j == PROFILE_AUTO, auto_profile: profile_j == PROFILE_AUTO,
relay_addr, relay_addr,
room, room,
auth_token: if token.is_empty() { Vec::new() } else { token.into_bytes() }, auth_token: if token.is_empty() {
Vec::new()
} else {
token.into_bytes()
},
identity_seed, identity_seed,
alias: if alias.is_empty() { None } else { Some(alias) }, alias: if alias.is_empty() { None } else { Some(alias) },
}; };
@@ -241,7 +270,8 @@ pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeOnNetworkChang
) { ) {
let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| { let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| {
let h = unsafe { handle_ref(handle) }; let h = unsafe { handle_ref(handle) };
h.engine.on_network_changed(network_type as u8, bandwidth_kbps as u32); h.engine
.on_network_changed(network_type as u8, bandwidth_kbps as u32);
})); }));
} }
@@ -307,13 +337,14 @@ pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeWriteAudioDire
) -> jint { ) -> jint {
let result = panic::catch_unwind(panic::AssertUnwindSafe(|| { let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
let h = unsafe { handle_ref(handle) }; let h = unsafe { handle_ref(handle) };
let ptr = env.get_direct_buffer_address(&buffer).unwrap_or(std::ptr::null_mut()); let ptr = env
.get_direct_buffer_address(&buffer)
.unwrap_or(std::ptr::null_mut());
if ptr.is_null() || sample_count <= 0 { if ptr.is_null() || sample_count <= 0 {
return 0; return 0;
} }
let samples = unsafe { let samples =
std::slice::from_raw_parts(ptr as *const i16, sample_count as usize) unsafe { std::slice::from_raw_parts(ptr as *const i16, sample_count as usize) };
};
h.engine.write_audio(samples) as jint h.engine.write_audio(samples) as jint
})); }));
result.unwrap_or(0) result.unwrap_or(0)
@@ -332,13 +363,14 @@ pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeReadAudioDirec
) -> jint { ) -> jint {
let result = panic::catch_unwind(panic::AssertUnwindSafe(|| { let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
let h = unsafe { handle_ref(handle) }; let h = unsafe { handle_ref(handle) };
let ptr = env.get_direct_buffer_address(&buffer).unwrap_or(std::ptr::null_mut()); let ptr = env
.get_direct_buffer_address(&buffer)
.unwrap_or(std::ptr::null_mut());
if ptr.is_null() || max_samples <= 0 { if ptr.is_null() || max_samples <= 0 {
return 0; return 0;
} }
let samples = unsafe { let samples =
std::slice::from_raw_parts_mut(ptr as *mut i16, max_samples as usize) unsafe { std::slice::from_raw_parts_mut(ptr as *mut i16, max_samples as usize) };
};
h.engine.read_audio(samples) as jint h.engine.read_audio(samples) as jint
})); }));
result.unwrap_or(0) result.unwrap_or(0)
@@ -367,7 +399,10 @@ pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativePingRelay<'a>(
) -> jstring { ) -> jstring {
let result = panic::catch_unwind(panic::AssertUnwindSafe(|| { let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
let h = unsafe { handle_ref(handle) }; let h = unsafe { handle_ref(handle) };
let relay: String = env.get_string(&relay_j).map(|s| s.into()).unwrap_or_default(); let relay: String = env
.get_string(&relay_j)
.map(|s| s.into())
.unwrap_or_default();
match h.engine.ping_relay(&relay) { match h.engine.ping_relay(&relay) {
Ok(json) => Some(json), Ok(json) => Some(json),
Err(_) => None, Err(_) => None,
@@ -399,10 +434,22 @@ pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeStartSignaling
) -> jint { ) -> jint {
let result = panic::catch_unwind(panic::AssertUnwindSafe(|| { let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
let h = unsafe { handle_ref(handle) }; let h = unsafe { handle_ref(handle) };
let relay_addr: String = env.get_string(&relay_addr_j).map(|s| s.into()).unwrap_or_default(); let relay_addr: String = env
let seed_hex: String = env.get_string(&seed_hex_j).map(|s| s.into()).unwrap_or_default(); .get_string(&relay_addr_j)
let token: String = env.get_string(&token_j).map(|s| s.into()).unwrap_or_default(); .map(|s| s.into())
let alias: String = env.get_string(&alias_j).map(|s| s.into()).unwrap_or_default(); .unwrap_or_default();
let seed_hex: String = env
.get_string(&seed_hex_j)
.map(|s| s.into())
.unwrap_or_default();
let token: String = env
.get_string(&token_j)
.map(|s| s.into())
.unwrap_or_default();
let alias: String = env
.get_string(&alias_j)
.map(|s| s.into())
.unwrap_or_default();
h.engine.start_signaling( h.engine.start_signaling(
&relay_addr, &relay_addr,
@@ -414,8 +461,14 @@ pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeStartSignaling
match result { match result {
Ok(Ok(())) => 0, Ok(Ok(())) => 0,
Ok(Err(e)) => { error!("start_signaling failed: {e}"); -1 } Ok(Err(e)) => {
Err(_) => { error!("start_signaling panicked"); -1 } error!("start_signaling failed: {e}");
-1
}
Err(_) => {
error!("start_signaling panicked");
-1
}
} }
} }
@@ -430,14 +483,23 @@ pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativePlaceCall<'a>(
) -> jint { ) -> jint {
let result = panic::catch_unwind(panic::AssertUnwindSafe(|| { let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
let h = unsafe { handle_ref(handle) }; let h = unsafe { handle_ref(handle) };
let target: String = env.get_string(&target_fp_j).map(|s| s.into()).unwrap_or_default(); let target: String = env
.get_string(&target_fp_j)
.map(|s| s.into())
.unwrap_or_default();
h.engine.place_call(&target) h.engine.place_call(&target)
})); }));
match result { match result {
Ok(Ok(())) => 0, Ok(Ok(())) => 0,
Ok(Err(e)) => { error!("place_call failed: {e}"); -1 } Ok(Err(e)) => {
Err(_) => { error!("place_call panicked"); -1 } error!("place_call failed: {e}");
-1
}
Err(_) => {
error!("place_call panicked");
-1
}
} }
} }
@@ -453,7 +515,10 @@ pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeAnswerCall<'a>
) -> jint { ) -> jint {
let result = panic::catch_unwind(panic::AssertUnwindSafe(|| { let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
let h = unsafe { handle_ref(handle) }; let h = unsafe { handle_ref(handle) };
let call_id: String = env.get_string(&call_id_j).map(|s| s.into()).unwrap_or_default(); let call_id: String = env
.get_string(&call_id_j)
.map(|s| s.into())
.unwrap_or_default();
let accept_mode = match mode { let accept_mode = match mode {
0 => wzp_proto::CallAcceptMode::Reject, 0 => wzp_proto::CallAcceptMode::Reject,
1 => wzp_proto::CallAcceptMode::AcceptTrusted, 1 => wzp_proto::CallAcceptMode::AcceptTrusted,
@@ -464,7 +529,13 @@ pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeAnswerCall<'a>
match result { match result {
Ok(Ok(())) => 0, Ok(Ok(())) => 0,
Ok(Err(e)) => { error!("answer_call failed: {e}"); -1 } Ok(Err(e)) => {
Err(_) => { error!("answer_call panicked"); -1 } error!("answer_call failed: {e}");
-1
}
Err(_) => {
error!("answer_call panicked");
-1
}
} }
} }

View File

@@ -26,6 +26,6 @@ pub mod audio_android;
pub mod audio_ring; pub mod audio_ring;
pub mod commands; pub mod commands;
pub mod engine; pub mod engine;
pub mod jni_bridge;
pub mod pipeline; pub mod pipeline;
pub mod stats; pub mod stats;
pub mod jni_bridge;

View File

@@ -9,8 +9,8 @@ use wzp_codec::{AdaptiveDecoder, AdaptiveEncoder, AutoGainControl, EchoCanceller
use wzp_fec::{RaptorQFecDecoder, RaptorQFecEncoder}; use wzp_fec::{RaptorQFecDecoder, RaptorQFecEncoder};
use wzp_proto::jitter::{JitterBuffer, PlayoutResult}; use wzp_proto::jitter::{JitterBuffer, PlayoutResult};
use wzp_proto::quality::AdaptiveQualityController; use wzp_proto::quality::AdaptiveQualityController;
use wzp_proto::traits::{AudioDecoder, AudioEncoder, FecDecoder, FecEncoder};
use wzp_proto::traits::QualityController; use wzp_proto::traits::QualityController;
use wzp_proto::traits::{AudioDecoder, AudioEncoder, FecDecoder, FecEncoder};
use wzp_proto::{MediaPacket, QualityProfile}; use wzp_proto::{MediaPacket, QualityProfile};
use crate::audio_android::FRAME_SAMPLES; use crate::audio_android::FRAME_SAMPLES;
@@ -58,14 +58,12 @@ pub struct Pipeline {
impl Pipeline { impl Pipeline {
/// Create a new pipeline configured for the given quality profile. /// Create a new pipeline configured for the given quality profile.
pub fn new(profile: QualityProfile) -> Result<Self, anyhow::Error> { pub fn new(profile: QualityProfile) -> Result<Self, anyhow::Error> {
let encoder = AdaptiveEncoder::new(profile) let encoder =
.map_err(|e| anyhow::anyhow!("encoder init: {e}"))?; AdaptiveEncoder::new(profile).map_err(|e| anyhow::anyhow!("encoder init: {e}"))?;
let decoder = AdaptiveDecoder::new(profile) let decoder =
.map_err(|e| anyhow::anyhow!("decoder init: {e}"))?; AdaptiveDecoder::new(profile).map_err(|e| anyhow::anyhow!("decoder init: {e}"))?;
let fec_encoder = let fec_encoder = RaptorQFecEncoder::with_defaults(profile.frames_per_block as usize);
RaptorQFecEncoder::with_defaults(profile.frames_per_block as usize); let fec_decoder = RaptorQFecDecoder::with_defaults(profile.frames_per_block as usize);
let fec_decoder =
RaptorQFecDecoder::with_defaults(profile.frames_per_block as usize);
let jitter_buffer = JitterBuffer::new(10, 250, 3); let jitter_buffer = JitterBuffer::new(10, 250, 3);
let quality_ctrl = AdaptiveQualityController::new(); let quality_ctrl = AdaptiveQualityController::new();
@@ -136,11 +134,11 @@ impl Pipeline {
pub fn feed_packet(&mut self, packet: MediaPacket) { pub fn feed_packet(&mut self, packet: MediaPacket) {
// Feed FEC symbols if present // Feed FEC symbols if present
let header = &packet.header; let header = &packet.header;
if header.fec_block != 0 || header.fec_symbol != 0 { if header.fec_block != 0 {
let is_repair = header.is_repair; let is_repair = header.is_repair();
if let Err(e) = self.fec_decoder.add_symbol( if let Err(e) = self.fec_decoder.add_symbol(
header.fec_block, header.fec_block,
header.fec_symbol, header.fec_block >> 8,
is_repair, is_repair,
&packet.payload, &packet.payload,
) { ) {
@@ -211,10 +209,7 @@ impl Pipeline {
/// ///
/// Returns a new profile if a tier transition occurred. /// Returns a new profile if a tier transition occurred.
#[allow(unused)] #[allow(unused)]
pub fn observe_quality( pub fn observe_quality(&mut self, report: &wzp_proto::QualityReport) -> Option<QualityProfile> {
&mut self,
report: &wzp_proto::QualityReport,
) -> Option<QualityProfile> {
let new_profile = self.quality_ctrl.observe(report); let new_profile = self.quality_ctrl.observe(report);
if let Some(ref profile) = new_profile { if let Some(ref profile) = new_profile {
if let Err(e) = self.encoder.set_profile(*profile) { if let Err(e) = self.encoder.set_profile(*profile) {

View File

@@ -12,6 +12,7 @@ wzp-codec = { workspace = true }
wzp-fec = { workspace = true } wzp-fec = { workspace = true }
wzp-crypto = { workspace = true } wzp-crypto = { workspace = true }
wzp-transport = { workspace = true } wzp-transport = { workspace = true }
wzp-video = { path = "../wzp-video" }
tokio = { workspace = true } tokio = { workspace = true }
tracing = { workspace = true } tracing = { workspace = true }
tracing-subscriber = { workspace = true } tracing-subscriber = { workspace = true }

View File

@@ -15,7 +15,7 @@ use std::time::{Duration, Instant};
use clap::Parser; use clap::Parser;
use tracing::info; use tracing::info;
use wzp_proto::{CodecId, MediaPacket, MediaTransport}; use wzp_proto::{CodecId, MediaPacket, MediaTransport, default_signal_version};
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// CLI // CLI
@@ -86,7 +86,7 @@ struct ParticipantStats {
/// Detected lost packets (sequence gaps) /// Detected lost packets (sequence gaps)
lost: u64, lost: u64,
/// Last seen sequence number /// Last seen sequence number
last_seq: u16, last_seq: u32,
/// Whether we've seen the first packet (for gap detection) /// Whether we've seen the first packet (for gap detection)
seq_initialized: bool, seq_initialized: bool,
/// EWMA jitter in ms /// EWMA jitter in ms
@@ -181,7 +181,7 @@ impl ParticipantStats {
/// distinguish streams by proximity of consecutive sequence numbers. /// distinguish streams by proximity of consecutive sequence numbers.
fn find_or_create_participant( fn find_or_create_participant(
participants: &mut Vec<ParticipantStats>, participants: &mut Vec<ParticipantStats>,
seq: u16, seq: u32,
codec: CodecId, codec: CodecId,
) -> usize { ) -> usize {
for (i, p) in participants.iter().enumerate() { for (i, p) in participants.iter().enumerate() {
@@ -304,7 +304,7 @@ struct TimelineEntry {
#[allow(dead_code)] #[allow(dead_code)]
codec: CodecId, codec: CodecId,
#[allow(dead_code)] #[allow(dead_code)]
seq: u16, seq: u32,
#[allow(dead_code)] #[allow(dead_code)]
payload_len: usize, payload_len: usize,
loss_pct: f64, loss_pct: f64,
@@ -333,8 +333,11 @@ async fn run_replay(path: &str, args: &Args) -> anyhow::Result<()> {
let mut timeline: Vec<TimelineEntry> = Vec::new(); let mut timeline: Vec<TimelineEntry> = Vec::new();
// Decrypt session from --key (optional) // Decrypt session from --key (optional)
let mut decrypt_session: Option<wzp_crypto::ChaChaSession> = args.key.as_ref().and_then(|hex| { let mut decrypt_session: Option<wzp_crypto::ChaChaSession> =
if hex.len() != 64 { return None; } args.key.as_ref().and_then(|hex| {
if hex.len() != 64 {
return None;
}
let mut key = [0u8; 32]; let mut key = [0u8; 32];
for (i, chunk) in hex.as_bytes().chunks(2).enumerate() { for (i, chunk) in hex.as_bytes().chunks(2).enumerate() {
let s = std::str::from_utf8(chunk).unwrap_or("00"); let s = std::str::from_utf8(chunk).unwrap_or("00");
@@ -347,7 +350,8 @@ async fn run_replay(path: &str, args: &Args) -> anyhow::Result<()> {
while let Some((ts_us, pkt)) = reader.next_packet()? { while let Some((ts_us, pkt)) = reader.next_packet()? {
let now = Instant::now(); let now = Instant::now();
let idx = find_or_create_participant(&mut participants, pkt.header.seq, pkt.header.codec_id); let idx =
find_or_create_participant(&mut participants, pkt.header.seq, pkt.header.codec_id);
participants[idx].ingest(&pkt, now); participants[idx].ingest(&pkt, now);
total_packets += 1; total_packets += 1;
@@ -362,8 +366,10 @@ async fn run_replay(path: &str, args: &Args) -> anyhow::Result<()> {
if decrypt_ok <= 5 || decrypt_ok % 100 == 0 { if decrypt_ok <= 5 || decrypt_ok % 100 == 0 {
eprintln!( eprintln!(
" decrypt ok: seq={} codec={:?} payload={}B → plaintext={}B", " decrypt ok: seq={} codec={:?} payload={}B → plaintext={}B",
pkt.header.seq, pkt.header.codec_id, pkt.header.seq,
pkt.payload.len(), plaintext.len() pkt.header.codec_id,
pkt.payload.len(),
plaintext.len()
); );
} }
} }
@@ -402,7 +408,13 @@ async fn run_replay(path: &str, args: &Args) -> anyhow::Result<()> {
// Generate HTML if requested // Generate HTML if requested
if let Some(html_path) = &args.html { if let Some(html_path) = &args.html {
generate_html_report(html_path, &participants, &timeline, total_packets, &reader.header)?; generate_html_report(
html_path,
&participants,
&timeline,
total_packets,
&reader.header,
)?;
eprintln!("HTML report: {}", html_path); eprintln!("HTML report: {}", html_path);
} }
@@ -603,7 +615,11 @@ async fn run_no_tui(
} }
fn print_stats(participants: &[ParticipantStats], total: u64) { fn print_stats(participants: &[ParticipantStats], total: u64) {
eprintln!("--- {} participants | {} total packets ---", participants.len(), total); eprintln!(
"--- {} participants | {} total packets ---",
participants.len(),
total
);
for p in participants { for p in participants {
eprintln!( eprintln!(
" {}: {} pkts, {:.1}% loss, {:.0}ms jitter, {:?}, {:.0}s", " {}: {} pkts, {:.1}% loss, {:.0}ms jitter, {:?}, {:.0}s",
@@ -693,10 +709,7 @@ async fn run_tui(
// Always restore terminal, even on error // Always restore terminal, even on error
crossterm::terminal::disable_raw_mode()?; crossterm::terminal::disable_raw_mode()?;
crossterm::execute!( crossterm::execute!(std::io::stdout(), crossterm::terminal::LeaveAlternateScreen)?;
std::io::stdout(),
crossterm::terminal::LeaveAlternateScreen
)?;
result result
} }
@@ -735,7 +748,11 @@ fn draw_ui(
total_packets, total_packets,
elapsed_str elapsed_str
)) ))
.block(Block::default().borders(Borders::ALL).title(" Protocol Analyzer ")); .block(
Block::default()
.borders(Borders::ALL)
.title(" Protocol Analyzer "),
);
f.render_widget(header, chunks[0]); f.render_widget(header, chunks[0]);
// Participant table // Participant table
@@ -780,9 +797,11 @@ fn draw_ui(
Constraint::Length(10), // Duration Constraint::Length(10), // Duration
]; ];
let table = Table::new(rows, widths) let table = Table::new(rows, widths).header(header_row).block(
.header(header_row) Block::default()
.block(Block::default().borders(Borders::ALL).title(" Participants ")); .borders(Borders::ALL)
.title(" Participants "),
);
f.render_widget(table, chunks[1]); f.render_widget(table, chunks[1]);
// Footer // Footer
@@ -832,7 +851,10 @@ async fn main() -> anyhow::Result<()> {
let _crypto_session: Option<std::sync::Mutex<wzp_crypto::ChaChaSession>> = let _crypto_session: Option<std::sync::Mutex<wzp_crypto::ChaChaSession>> =
if let Some(ref key_hex) = args.key { if let Some(ref key_hex) = args.key {
if key_hex.len() != 64 { if key_hex.len() != 64 {
eprintln!("Error: --key must be 64 hex characters (32 bytes). Got {} chars.", key_hex.len()); eprintln!(
"Error: --key must be 64 hex characters (32 bytes). Got {} chars.",
key_hex.len()
);
std::process::exit(1); std::process::exit(1);
} }
let mut key_bytes = [0u8; 32]; let mut key_bytes = [0u8; 32];
@@ -841,9 +863,9 @@ async fn main() -> anyhow::Result<()> {
key_bytes[i] = u8::from_str_radix(hex_str, 16).unwrap_or(0); key_bytes[i] = u8::from_str_radix(hex_str, 16).unwrap_or(0);
} }
eprintln!("Encrypted payload decoding enabled (key loaded)."); eprintln!("Encrypted payload decoding enabled (key loaded).");
Some(std::sync::Mutex::new( Some(std::sync::Mutex::new(wzp_crypto::ChaChaSession::new(
wzp_crypto::ChaChaSession::new(key_bytes), key_bytes,
)) )))
} else { } else {
None None
}; };
@@ -854,14 +876,12 @@ async fn main() -> anyhow::Result<()> {
} }
// Live mode requires relay and room // Live mode requires relay and room
let relay = args let relay = args.relay.as_deref().ok_or_else(|| {
.relay anyhow::anyhow!("relay address required for live mode (use --replay for offline)")
.as_deref() })?;
.ok_or_else(|| anyhow::anyhow!("relay address required for live mode (use --replay for offline)"))?; let room = args.room.as_deref().ok_or_else(|| {
let room = args anyhow::anyhow!("--room required for live mode (use --replay for offline)")
.room })?;
.as_deref()
.ok_or_else(|| anyhow::anyhow!("--room required for live mode (use --replay for offline)"))?;
// TLS crypto provider // TLS crypto provider
let _ = rustls::crypto::ring::default_provider().install_default(); let _ = rustls::crypto::ring::default_provider().install_default();
@@ -899,6 +919,7 @@ async fn main() -> anyhow::Result<()> {
// Auth if token provided // Auth if token provided
if let Some(ref token) = args.token { if let Some(ref token) = args.token {
let auth = wzp_proto::SignalMessage::AuthToken { let auth = wzp_proto::SignalMessage::AuthToken {
version: default_signal_version(),
token: token.clone(), token: token.clone(),
}; };
transport.send_signal(&auth).await?; transport.send_signal(&auth).await?;

View File

@@ -6,10 +6,10 @@
//! Audio callbacks are **lock-free**: they read/write directly to an `AudioRing` //! Audio callbacks are **lock-free**: they read/write directly to an `AudioRing`
//! (atomic SPSC ring buffer). No Mutex, no channel, no allocation on the hot path. //! (atomic SPSC ring buffer). No Mutex, no channel, no allocation on the hot path.
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use anyhow::{anyhow, Context}; use anyhow::{Context, anyhow};
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait}; use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
use cpal::{SampleFormat, SampleRate, StreamConfig}; use cpal::{SampleFormat, SampleRate, StreamConfig};
use tracing::{info, warn}; use tracing::{info, warn};
@@ -78,7 +78,10 @@ impl AudioCapture {
return; return;
} }
if !logged.swap(true, Ordering::Relaxed) { if !logged.swap(true, Ordering::Relaxed) {
eprintln!("[audio] capture callback: {} f32 samples", data.len()); eprintln!(
"[audio] capture callback: {} f32 samples",
data.len()
);
} }
let mut tmp = [0i16; FRAME_SAMPLES]; let mut tmp = [0i16; FRAME_SAMPLES];
for chunk in data.chunks(FRAME_SAMPLES) { for chunk in data.chunks(FRAME_SAMPLES) {
@@ -103,7 +106,10 @@ impl AudioCapture {
return; return;
} }
if !logged.swap(true, Ordering::Relaxed) { if !logged.swap(true, Ordering::Relaxed) {
eprintln!("[audio] capture callback: {} i16 samples", data.len()); eprintln!(
"[audio] capture callback: {} i16 samples",
data.len()
);
} }
ring.write(data); ring.write(data);
}, },

View File

@@ -54,13 +54,13 @@
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, OnceLock}; use std::sync::{Arc, Mutex, OnceLock};
use anyhow::{anyhow, Context}; use anyhow::{Context, anyhow};
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait}; use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
use cpal::{SampleFormat, SampleRate, StreamConfig}; use cpal::{SampleFormat, SampleRate, StreamConfig};
use tracing::{info, warn}; use tracing::{info, warn};
use webrtc_audio_processing::{ use webrtc_audio_processing::{
Config, EchoCancellation, EchoCancellationSuppressionLevel, InitializationConfig, Config, EchoCancellation, EchoCancellationSuppressionLevel, InitializationConfig,
NoiseSuppression, NoiseSuppressionLevel, Processor, NUM_SAMPLES_PER_FRAME, NUM_SAMPLES_PER_FRAME, NoiseSuppression, NoiseSuppressionLevel, Processor,
}; };
use crate::audio_ring::AudioRing; use crate::audio_ring::AudioRing;
@@ -97,8 +97,8 @@ fn get_or_init_processor() -> anyhow::Result<Arc<Mutex<Processor>>> {
num_render_channels: APM_NUM_CHANNELS as i32, num_render_channels: APM_NUM_CHANNELS as i32,
..Default::default() ..Default::default()
}; };
let mut processor = Processor::new(&init_config) let mut processor =
.map_err(|e| anyhow!("webrtc APM init failed: {e:?}"))?; Processor::new(&init_config).map_err(|e| anyhow!("webrtc APM init failed: {e:?}"))?;
let config = Config { let config = Config {
echo_cancellation: Some(EchoCancellation { echo_cancellation: Some(EchoCancellation {

View File

@@ -5,8 +5,8 @@
//! to the speaker, so it can cancel the echo from the mic signal internally. //! to the speaker, so it can cancel the echo from the mic signal internally.
//! This is the same engine FaceTime and other Apple apps use. //! This is the same engine FaceTime and other Apple apps use.
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use anyhow::Context; use anyhow::Context;
use coreaudio::audio_unit::audio_format::LinearPcmFlags; use coreaudio::audio_unit::audio_format::LinearPcmFlags;
@@ -28,6 +28,60 @@ pub struct VpioAudio {
playout_ring: Arc<AudioRing>, playout_ring: Arc<AudioRing>,
_audio_unit: AudioUnit, _audio_unit: AudioUnit,
running: Arc<AtomicBool>, running: Arc<AtomicBool>,
stats: Arc<VpioStats>,
}
/// Render/capture counters for diagnosing macOS VoiceProcessingIO.
///
/// These are atomics because CoreAudio callbacks run on realtime audio
/// threads. The Tauri engine polls snapshots from a normal async task and
/// emits them to the call debug log.
#[derive(Default)]
pub struct VpioStats {
capture_callbacks: AtomicU64,
capture_samples: AtomicU64,
render_callbacks: AtomicU64,
render_requested_samples: AtomicU64,
render_read_samples: AtomicU64,
render_underrun_callbacks: AtomicU64,
render_nonzero_callbacks: AtomicU64,
render_last_requested: AtomicU64,
render_last_read: AtomicU64,
render_last_rms: AtomicU64,
render_last_ring_available: AtomicU64,
}
#[derive(Clone, Copy, Debug)]
pub struct VpioStatsSnapshot {
pub capture_callbacks: u64,
pub capture_samples: u64,
pub render_callbacks: u64,
pub render_requested_samples: u64,
pub render_read_samples: u64,
pub render_underrun_callbacks: u64,
pub render_nonzero_callbacks: u64,
pub render_last_requested: u64,
pub render_last_read: u64,
pub render_last_rms: u64,
pub render_last_ring_available: u64,
}
impl VpioStats {
pub fn snapshot(&self) -> VpioStatsSnapshot {
VpioStatsSnapshot {
capture_callbacks: self.capture_callbacks.load(Ordering::Relaxed),
capture_samples: self.capture_samples.load(Ordering::Relaxed),
render_callbacks: self.render_callbacks.load(Ordering::Relaxed),
render_requested_samples: self.render_requested_samples.load(Ordering::Relaxed),
render_read_samples: self.render_read_samples.load(Ordering::Relaxed),
render_underrun_callbacks: self.render_underrun_callbacks.load(Ordering::Relaxed),
render_nonzero_callbacks: self.render_nonzero_callbacks.load(Ordering::Relaxed),
render_last_requested: self.render_last_requested.load(Ordering::Relaxed),
render_last_read: self.render_last_read.load(Ordering::Relaxed),
render_last_rms: self.render_last_rms.load(Ordering::Relaxed),
render_last_ring_available: self.render_last_ring_available.load(Ordering::Relaxed),
}
}
} }
impl VpioAudio { impl VpioAudio {
@@ -36,6 +90,7 @@ impl VpioAudio {
let capture_ring = Arc::new(AudioRing::new()); let capture_ring = Arc::new(AudioRing::new());
let playout_ring = Arc::new(AudioRing::new()); let playout_ring = Arc::new(AudioRing::new());
let running = Arc::new(AtomicBool::new(true)); let running = Arc::new(AtomicBool::new(true));
let stats = Arc::new(VpioStats::default());
let mut au = AudioUnit::new(IOType::VoiceProcessingIO) let mut au = AudioUnit::new(IOType::VoiceProcessingIO)
.context("failed to create VoiceProcessingIO audio unit")?; .context("failed to create VoiceProcessingIO audio unit")?;
@@ -98,6 +153,7 @@ impl VpioAudio {
// Set up input callback (mic capture with AEC applied) // Set up input callback (mic capture with AEC applied)
let cap_ring = capture_ring.clone(); let cap_ring = capture_ring.clone();
let cap_running = running.clone(); let cap_running = running.clone();
let cap_stats = stats.clone();
let logged = Arc::new(AtomicBool::new(false)); let logged = Arc::new(AtomicBool::new(false));
au.set_input_callback( au.set_input_callback(
move |args: render_callback::Args<data::NonInterleaved<f32>>| { move |args: render_callback::Args<data::NonInterleaved<f32>>| {
@@ -106,6 +162,10 @@ impl VpioAudio {
} }
let mut buffers = args.data.channels(); let mut buffers = args.data.channels();
if let Some(ch) = buffers.next() { if let Some(ch) = buffers.next() {
cap_stats.capture_callbacks.fetch_add(1, Ordering::Relaxed);
cap_stats
.capture_samples
.fetch_add(ch.len() as u64, Ordering::Relaxed);
if !logged.swap(true, Ordering::Relaxed) { if !logged.swap(true, Ordering::Relaxed) {
eprintln!("[vpio] capture callback: {} f32 samples", ch.len()); eprintln!("[vpio] capture callback: {} f32 samples", ch.len());
} }
@@ -125,28 +185,80 @@ impl VpioAudio {
// Set up output callback (speaker playback — AEC uses this as reference) // Set up output callback (speaker playback — AEC uses this as reference)
let play_ring = playout_ring.clone(); let play_ring = playout_ring.clone();
let render_stats = stats.clone();
let logged_render = Arc::new(AtomicBool::new(false));
au.set_render_callback( au.set_render_callback(
move |mut args: render_callback::Args<data::NonInterleaved<f32>>| { move |mut args: render_callback::Args<data::NonInterleaved<f32>>| {
let mut buffers = args.data.channels_mut(); let mut buffers = args.data.channels_mut();
if let Some(ch) = buffers.next() { if let Some(ch) = buffers.next() {
render_stats
.render_callbacks
.fetch_add(1, Ordering::Relaxed);
render_stats
.render_requested_samples
.fetch_add(ch.len() as u64, Ordering::Relaxed);
render_stats
.render_last_requested
.store(ch.len() as u64, Ordering::Relaxed);
let mut tmp = [0i16; FRAME_SAMPLES]; let mut tmp = [0i16; FRAME_SAMPLES];
let mut total_read = 0usize;
let mut sum_sq = 0u64;
let ring_available = play_ring.available();
for chunk in ch.chunks_mut(FRAME_SAMPLES) { for chunk in ch.chunks_mut(FRAME_SAMPLES) {
let n = chunk.len(); let n = chunk.len();
let read = play_ring.read(&mut tmp[..n]); let read = play_ring.read(&mut tmp[..n]);
total_read += read;
for i in 0..read { for i in 0..read {
let s = tmp[i] as i64;
sum_sq = sum_sq.saturating_add((s * s) as u64);
chunk[i] = tmp[i] as f32 / i16::MAX as f32; chunk[i] = tmp[i] as f32 / i16::MAX as f32;
} }
for i in read..n { for i in read..n {
chunk[i] = 0.0; chunk[i] = 0.0;
} }
} }
render_stats
.render_read_samples
.fetch_add(total_read as u64, Ordering::Relaxed);
render_stats
.render_last_read
.store(total_read as u64, Ordering::Relaxed);
render_stats
.render_last_ring_available
.store(ring_available as u64, Ordering::Relaxed);
if total_read == 0 {
render_stats
.render_underrun_callbacks
.fetch_add(1, Ordering::Relaxed);
}
let rms = if total_read > 0 {
((sum_sq as f64 / total_read as f64).sqrt()) as u64
} else {
0
};
render_stats.render_last_rms.store(rms, Ordering::Relaxed);
if rms > 0 {
render_stats
.render_nonzero_callbacks
.fetch_add(1, Ordering::Relaxed);
}
if !logged_render.swap(true, Ordering::Relaxed) {
eprintln!(
"[vpio] render callback: {} f32 samples, ring_available={}, ring_read={}, rms={}",
ch.len(),
ring_available,
total_read,
rms
);
}
} }
Ok(()) Ok(())
}, },
) )
.context("failed to set render callback")?; .context("failed to set render callback")?;
au.initialize().context("failed to initialize VoiceProcessingIO")?; au.initialize()
.context("failed to initialize VoiceProcessingIO")?;
au.start().context("failed to start VoiceProcessingIO")?; au.start().context("failed to start VoiceProcessingIO")?;
info!("VoiceProcessingIO started (OS-level AEC enabled)"); info!("VoiceProcessingIO started (OS-level AEC enabled)");
@@ -156,6 +268,7 @@ impl VpioAudio {
playout_ring, playout_ring,
_audio_unit: au, _audio_unit: au,
running, running,
stats,
}) })
} }
@@ -167,6 +280,10 @@ impl VpioAudio {
&self.playout_ring &self.playout_ring
} }
pub fn stats(&self) -> Arc<VpioStats> {
self.stats.clone()
}
pub fn stop(&self) { pub fn stop(&self) {
self.running.store(false, Ordering::Relaxed); self.running.store(false, Ordering::Relaxed);
} }

View File

@@ -15,24 +15,24 @@
//! `wzp-client`'s lib.rs can transparently re-export either one as //! `wzp-client`'s lib.rs can transparently re-export either one as
//! `AudioCapture`. //! `AudioCapture`.
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use anyhow::{anyhow, Context}; use anyhow::{Context, anyhow};
use tracing::{info, warn}; use tracing::{info, warn};
use windows::core::{Interface, GUID}; use windows::Win32::Foundation::{BOOL, CloseHandle, WAIT_OBJECT_0};
use windows::Win32::Foundation::{CloseHandle, BOOL, WAIT_OBJECT_0};
use windows::Win32::Media::Audio::{ use windows::Win32::Media::Audio::{
eCapture, eCommunications, AudioCategory_Communications, AudioClientProperties,
IAudioCaptureClient, IAudioClient, IAudioClient2, IMMDeviceEnumerator, MMDeviceEnumerator,
AUDCLNT_SHAREMODE_SHARED, AUDCLNT_STREAMFLAGS_AUTOCONVERTPCM, AUDCLNT_SHAREMODE_SHARED, AUDCLNT_STREAMFLAGS_AUTOCONVERTPCM,
AUDCLNT_STREAMFLAGS_EVENTCALLBACK, AUDCLNT_STREAMFLAGS_SRC_DEFAULT_QUALITY, WAVEFORMATEX, AUDCLNT_STREAMFLAGS_EVENTCALLBACK, AUDCLNT_STREAMFLAGS_SRC_DEFAULT_QUALITY,
WAVE_FORMAT_PCM, AudioCategory_Communications, AudioClientProperties, IAudioCaptureClient, IAudioClient,
IAudioClient2, IMMDeviceEnumerator, MMDeviceEnumerator, WAVE_FORMAT_PCM, WAVEFORMATEX,
eCapture, eCommunications,
}; };
use windows::Win32::System::Com::{ use windows::Win32::System::Com::{
CoCreateInstance, CoInitializeEx, CoUninitialize, CLSCTX_ALL, COINIT_MULTITHREADED, CLSCTX_ALL, COINIT_MULTITHREADED, CoCreateInstance, CoInitializeEx, CoUninitialize,
}; };
use windows::Win32::System::Threading::{CreateEventW, WaitForSingleObject, INFINITE}; use windows::Win32::System::Threading::{CreateEventW, INFINITE, WaitForSingleObject};
use windows::core::{GUID, Interface};
use crate::audio_ring::AudioRing; use crate::audio_ring::AudioRing;
@@ -138,8 +138,7 @@ unsafe fn capture_thread_main(
} }
let _com_guard = ComGuard; let _com_guard = ComGuard;
let enumerator: IMMDeviceEnumerator = let enumerator: IMMDeviceEnumerator = CoCreateInstance(&MMDeviceEnumerator, None, CLSCTX_ALL)
CoCreateInstance(&MMDeviceEnumerator, None, CLSCTX_ALL)
.context("CoCreateInstance(MMDeviceEnumerator) failed")?; .context("CoCreateInstance(MMDeviceEnumerator) failed")?;
// eCommunications role (not eConsole) — this picks the device the user // eCommunications role (not eConsole) — this picks the device the user
@@ -206,12 +205,13 @@ unsafe fn capture_thread_main(
&wave_format, &wave_format,
Some(&GUID::zeroed()), Some(&GUID::zeroed()),
) )
.context("IAudioClient::Initialize failed — Windows rejected communications-mode 48k mono i16")?; .context(
"IAudioClient::Initialize failed — Windows rejected communications-mode 48k mono i16",
)?;
// Event-driven capture: Windows signals this handle each time a new // Event-driven capture: Windows signals this handle each time a new
// audio packet is available. We wait on it from the loop below. // audio packet is available. We wait on it from the loop below.
let event = CreateEventW(None, false, false, None) let event = CreateEventW(None, false, false, None).context("CreateEventW failed")?;
.context("CreateEventW failed")?;
audio_client audio_client
.SetEventHandle(event) .SetEventHandle(event)
.context("SetEventHandle failed")?; .context("SetEventHandle failed")?;
@@ -285,10 +285,8 @@ unsafe fn capture_thread_main(
// Because we asked for 48 kHz mono i16, each frame is // Because we asked for 48 kHz mono i16, each frame is
// exactly one i16. Windows's AUTOCONVERTPCM handles the // exactly one i16. Windows's AUTOCONVERTPCM handles the
// conversion from whatever the engine mix format is. // conversion from whatever the engine mix format is.
let samples = std::slice::from_raw_parts( let samples =
buffer_ptr as *const i16, std::slice::from_raw_parts(buffer_ptr as *const i16, num_frames as usize);
num_frames as usize,
);
ring.write(samples); ring.write(samples);
} }

View File

@@ -6,8 +6,8 @@ use std::time::{Duration, Instant};
use wzp_crypto::ChaChaSession; use wzp_crypto::ChaChaSession;
use wzp_fec::{RaptorQFecDecoder, RaptorQFecEncoder}; use wzp_fec::{RaptorQFecDecoder, RaptorQFecEncoder};
use wzp_proto::traits::{CryptoSession, FecDecoder, FecEncoder};
use wzp_proto::QualityProfile; use wzp_proto::QualityProfile;
use wzp_proto::traits::{CryptoSession, FecDecoder, FecEncoder};
use crate::call::{CallConfig, CallDecoder, CallEncoder}; use crate::call::{CallConfig, CallDecoder, CallEncoder};
@@ -151,7 +151,7 @@ pub fn bench_fec_recovery(loss_pct: f32) -> FecResult {
let mut total_repair_bytes = 0usize; let mut total_repair_bytes = 0usize;
for block_idx in 0..num_blocks { for block_idx in 0..num_blocks {
let block_id = (block_idx % 256) as u8; let block_id = (block_idx % 65536) as u16;
// Create fresh encoder and decoder for each block // Create fresh encoder and decoder for each block
let mut fec_enc = RaptorQFecEncoder::new(frames_per_block, 256); let mut fec_enc = RaptorQFecEncoder::new(frames_per_block, 256);
@@ -170,7 +170,7 @@ pub fn bench_fec_recovery(loss_pct: f32) -> FecResult {
// Collect all symbols: source + repair // Collect all symbols: source + repair
struct Symbol { struct Symbol {
index: u8, index: u16,
is_repair: bool, is_repair: bool,
data: Vec<u8>, data: Vec<u8>,
} }
@@ -180,7 +180,7 @@ pub fn bench_fec_recovery(loss_pct: f32) -> FecResult {
// For add_symbol we need to provide the raw data; the decoder pads internally // For add_symbol we need to provide the raw data; the decoder pads internally
total_source_bytes += sym.len(); total_source_bytes += sym.len();
all_symbols.push(Symbol { all_symbols.push(Symbol {
index: i as u8, index: i as u16,
is_repair: false, is_repair: false,
data: sym.clone(), data: sym.clone(),
}); });
@@ -201,9 +201,13 @@ pub fn bench_fec_recovery(loss_pct: f32) -> FecResult {
// Deterministic shuffle for reproducibility using a simple seed // Deterministic shuffle for reproducibility using a simple seed
// We use a basic Fisher-Yates with a fixed-per-block seed // We use a basic Fisher-Yates with a fixed-per-block seed
let mut indices: Vec<usize> = (0..all_symbols.len()).collect(); let mut indices: Vec<usize> = (0..all_symbols.len()).collect();
let mut seed = (block_idx as u64).wrapping_mul(6364136223846793005).wrapping_add(1); let mut seed = (block_idx as u64)
.wrapping_mul(6364136223846793005)
.wrapping_add(1);
for i in (1..indices.len()).rev() { for i in (1..indices.len()).rev() {
seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407); seed = seed
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
let j = (seed >> 33) as usize % (i + 1); let j = (seed >> 33) as usize % (i + 1);
indices.swap(i, j); indices.swap(i, j);
} }
@@ -259,17 +263,36 @@ pub fn bench_encrypt_decrypt() -> CryptoResult {
}) })
.collect(); .collect();
let header = b"bench-header"; // Build valid v2 MediaHeader bytes — encrypt/decrypt now derive nonces from
// header.seq and require a parseable MediaHeader (WIRE_SIZE bytes minimum).
use wzp_proto::packet::MediaHeader;
use wzp_proto::{CodecId, MediaType};
let mut total_bytes: usize = 0; let mut total_bytes: usize = 0;
let start = Instant::now(); let start = Instant::now();
for payload in &payloads { for (i, payload) in payloads.iter().enumerate() {
let hdr = MediaHeader {
version: 2,
flags: 0,
media_type: MediaType::Audio,
codec_id: CodecId::Opus24k,
stream_id: 0,
fec_ratio: 0,
seq: i as u32,
timestamp: (i as u32).wrapping_mul(20),
fec_block: 0,
};
let mut header_bytes = Vec::with_capacity(MediaHeader::WIRE_SIZE);
hdr.write_to(&mut header_bytes);
let mut ciphertext = Vec::with_capacity(payload.len() + 16); let mut ciphertext = Vec::with_capacity(payload.len() + 16);
encryptor.encrypt(header, payload, &mut ciphertext).unwrap(); encryptor
.encrypt(&header_bytes, payload, &mut ciphertext)
.unwrap();
let mut plaintext = Vec::with_capacity(payload.len()); let mut plaintext = Vec::with_capacity(payload.len());
decryptor decryptor
.decrypt(header, &ciphertext, &mut plaintext) .decrypt(&header_bytes, &ciphertext, &mut plaintext)
.unwrap(); .unwrap();
total_bytes += payload.len(); total_bytes += payload.len();

View File

@@ -24,8 +24,14 @@ fn run_codec() {
print_header("Codec Roundtrip (Opus 24kbps)"); print_header("Codec Roundtrip (Opus 24kbps)");
let r = bench::bench_codec_roundtrip(); let r = bench::bench_codec_roundtrip();
print_row("Frames", &format!("{}", r.frames)); print_row("Frames", &format!("{}", r.frames));
print_row("Encode total", &format!("{:.2} ms", r.total_encode.as_secs_f64() * 1000.0)); print_row(
print_row("Decode total", &format!("{:.2} ms", r.total_decode.as_secs_f64() * 1000.0)); "Encode total",
&format!("{:.2} ms", r.total_encode.as_secs_f64() * 1000.0),
);
print_row(
"Decode total",
&format!("{:.2} ms", r.total_decode.as_secs_f64() * 1000.0),
);
print_row("Avg encode", &format!("{:.1} us", r.avg_encode_us)); print_row("Avg encode", &format!("{:.1} us", r.avg_encode_us));
print_row("Avg decode", &format!("{:.1} us", r.avg_decode_us)); print_row("Avg decode", &format!("{:.1} us", r.avg_decode_us));
print_row("Throughput", &format!("{:.0} frames/sec", r.frames_per_sec)); print_row("Throughput", &format!("{:.0} frames/sec", r.frames_per_sec));
@@ -41,7 +47,10 @@ fn run_fec(loss_pct: f32) {
print_row("Recovery rate", &format!("{:.1}%", r.recovery_rate_pct)); print_row("Recovery rate", &format!("{:.1}%", r.recovery_rate_pct));
print_row("Source bytes", &format!("{}", r.total_source_bytes)); print_row("Source bytes", &format!("{}", r.total_source_bytes));
print_row("Repair (overhead) bytes", &format!("{}", r.overhead_bytes)); print_row("Repair (overhead) bytes", &format!("{}", r.overhead_bytes));
print_row("Total time", &format!("{:.2} ms", r.total_time.as_secs_f64() * 1000.0)); print_row(
"Total time",
&format!("{:.2} ms", r.total_time.as_secs_f64() * 1000.0),
);
print_footer(); print_footer();
} }
@@ -49,7 +58,10 @@ fn run_crypto() {
print_header("Crypto (ChaCha20-Poly1305)"); print_header("Crypto (ChaCha20-Poly1305)");
let r = bench::bench_encrypt_decrypt(); let r = bench::bench_encrypt_decrypt();
print_row("Packets", &format!("{}", r.packets)); print_row("Packets", &format!("{}", r.packets));
print_row("Total time", &format!("{:.2} ms", r.total_time.as_secs_f64() * 1000.0)); print_row(
"Total time",
&format!("{:.2} ms", r.total_time.as_secs_f64() * 1000.0),
);
print_row("Throughput", &format!("{:.0} pkt/sec", r.packets_per_sec)); print_row("Throughput", &format!("{:.0} pkt/sec", r.packets_per_sec));
print_row("Bandwidth", &format!("{:.2} MB/sec", r.megabytes_per_sec)); print_row("Bandwidth", &format!("{:.2} MB/sec", r.megabytes_per_sec));
print_row("Avg latency", &format!("{:.2} us", r.avg_latency_us)); print_row("Avg latency", &format!("{:.2} us", r.avg_latency_us));
@@ -60,9 +72,18 @@ fn run_pipeline() {
print_header("Full Pipeline (E2E)"); print_header("Full Pipeline (E2E)");
let r = bench::bench_full_pipeline(); let r = bench::bench_full_pipeline();
print_row("Frames", &format!("{}", r.frames)); print_row("Frames", &format!("{}", r.frames));
print_row("Encode pipeline", &format!("{:.2} ms", r.total_encode_pipeline.as_secs_f64() * 1000.0)); print_row(
print_row("Decode pipeline", &format!("{:.2} ms", r.total_decode_pipeline.as_secs_f64() * 1000.0)); "Encode pipeline",
print_row("Avg E2E latency", &format!("{:.1} us/frame", r.avg_e2e_latency_us)); &format!("{:.2} ms", r.total_encode_pipeline.as_secs_f64() * 1000.0),
);
print_row(
"Decode pipeline",
&format!("{:.2} ms", r.total_decode_pipeline.as_secs_f64() * 1000.0),
);
print_row(
"Avg E2E latency",
&format!("{:.1} us/frame", r.avg_e2e_latency_us),
);
print_row("PCM in", &format!("{} bytes", r.pcm_bytes_in)); print_row("PCM in", &format!("{} bytes", r.pcm_bytes_in));
print_row("Wire out", &format!("{} bytes", r.wire_bytes_out)); print_row("Wire out", &format!("{} bytes", r.wire_bytes_out));
print_row("Overhead ratio", &format!("{:.3}x", r.overhead_ratio)); print_row("Overhead ratio", &format!("{:.3}x", r.overhead_ratio));

View File

@@ -165,10 +165,7 @@ pub fn generate_dialer_targets(
// First: all known ports (guaranteed targets) // First: all known ports (guaranteed targets)
for &port in known_ports { for &port in known_ports {
targets.push(SocketAddr::new( targets.push(SocketAddr::new(std::net::IpAddr::V4(acceptor_ip), port));
std::net::IpAddr::V4(acceptor_ip),
port,
));
} }
// Fill remaining with random ports (birthday attack) // Fill remaining with random ports (birthday attack)
@@ -178,10 +175,7 @@ pub fn generate_dialer_targets(
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
for _ in 0..remaining { for _ in 0..remaining {
let port = rng.gen_range(1024..=65535u16); let port = rng.gen_range(1024..=65535u16);
let addr = SocketAddr::new( let addr = SocketAddr::new(std::net::IpAddr::V4(acceptor_ip), port);
std::net::IpAddr::V4(acceptor_ip),
port,
);
if !targets.contains(&addr) { if !targets.contains(&addr) {
targets.push(addr); targets.push(addr);
} }
@@ -339,7 +333,10 @@ mod tests {
fn acceptor_ports_serializes() { fn acceptor_ports_serializes() {
let result = AcceptorPorts { let result = AcceptorPorts {
external_ip: Some(Ipv4Addr::new(203, 0, 113, 5)), external_ip: Some(Ipv4Addr::new(203, 0, 113, 5)),
ports: vec![PortMapping { local_port: 12345, external_port: 54321 }], ports: vec![PortMapping {
local_port: 12345,
external_port: 54321,
}],
attempted: 32, attempted: 32,
succeeded: 1, succeeded: 1,
}; };

View File

@@ -13,11 +13,11 @@ use wzp_codec::{
}; };
use wzp_fec::{RaptorQFecDecoder, RaptorQFecEncoder}; use wzp_fec::{RaptorQFecDecoder, RaptorQFecEncoder};
use wzp_proto::jitter::{JitterBuffer, PlayoutResult}; use wzp_proto::jitter::{JitterBuffer, PlayoutResult};
use wzp_proto::packet::QualityReport;
use wzp_proto::packet::{MediaHeader, MediaPacket, MiniFrameContext}; use wzp_proto::packet::{MediaHeader, MediaPacket, MiniFrameContext};
use wzp_proto::quality::AdaptiveQualityController; use wzp_proto::quality::AdaptiveQualityController;
use wzp_proto::traits::{AudioDecoder, AudioEncoder, FecDecoder, FecEncoder}; use wzp_proto::traits::{AudioDecoder, AudioEncoder, FecDecoder, FecEncoder};
use wzp_proto::packet::QualityReport; use wzp_proto::{CodecId, MediaType, QualityProfile};
use wzp_proto::{CodecId, QualityProfile};
/// Configuration for a call session. /// Configuration for a call session.
pub struct CallConfig { pub struct CallConfig {
@@ -205,7 +205,7 @@ pub struct CallEncoder {
/// Current profile. /// Current profile.
profile: QualityProfile, profile: QualityProfile,
/// Outbound sequence counter. /// Outbound sequence counter.
seq: u16, seq: u32,
/// Current FEC block. /// Current FEC block.
block_id: u8, block_id: u8,
/// Frame index within current block. /// Frame index within current block.
@@ -318,17 +318,15 @@ impl CallEncoder {
if self.cn_counter % 10 == 0 { if self.cn_counter % 10 == 0 {
let cn_pkt = MediaPacket { let cn_pkt = MediaPacket {
header: MediaHeader { header: MediaHeader {
version: 0, version: 2,
is_repair: false, flags: 0,
media_type: MediaType::Audio,
codec_id: CodecId::ComfortNoise, codec_id: CodecId::ComfortNoise,
has_quality_report: false, stream_id: 0,
fec_ratio_encoded: 0, fec_ratio: 0,
seq: self.seq, seq: self.seq,
timestamp: self.timestamp_ms, timestamp: self.timestamp_ms,
fec_block: self.block_id, fec_block: u16::from(self.block_id),
fec_symbol: 0,
reserved: 0,
csrc_count: 0,
}, },
payload: Bytes::from(vec![self.cn_level as u8]), payload: Bytes::from(vec![self.cn_level as u8]),
quality_report: None, quality_report: None,
@@ -354,30 +352,31 @@ impl CallEncoder {
// can cleanly identify "no RaptorQ block to assemble" and new // can cleanly identify "no RaptorQ block to assemble" and new
// receivers can short-circuit their FEC ingest path. // receivers can short-circuit their FEC ingest path.
let is_opus = self.profile.codec.is_opus(); let is_opus = self.profile.codec.is_opus();
let (fec_block, fec_symbol, fec_ratio_encoded) = if is_opus { let (fec_block, fec_ratio) = if is_opus {
(0u8, 0u8, 0u8) (0u16, 0u8)
} else { } else {
( (
self.block_id, u16::from(self.block_id) | (u16::from(self.frame_in_block) << 8),
self.frame_in_block,
MediaHeader::encode_fec_ratio(self.profile.fec_ratio), MediaHeader::encode_fec_ratio(self.profile.fec_ratio),
) )
}; };
// Build source media packet // Build source media packet
let mut flags = 0u8;
if self.pending_quality_report.is_some() {
flags |= MediaHeader::FLAG_QUALITY;
}
let source_pkt = MediaPacket { let source_pkt = MediaPacket {
header: MediaHeader { header: MediaHeader {
version: 0, version: 2,
is_repair: false, flags,
media_type: MediaType::Audio,
codec_id: self.profile.codec, codec_id: self.profile.codec,
has_quality_report: self.pending_quality_report.is_some(), stream_id: 0,
fec_ratio_encoded, fec_ratio,
seq: self.seq, seq: self.seq,
timestamp: self.timestamp_ms, timestamp: self.timestamp_ms,
fec_block, fec_block,
fec_symbol,
reserved: 0,
csrc_count: 0,
}, },
payload: Bytes::from(encoded.clone()), payload: Bytes::from(encoded.clone()),
quality_report: self.pending_quality_report.take(), quality_report: self.pending_quality_report.take(),
@@ -402,19 +401,15 @@ impl CallEncoder {
for (sym_idx, repair_data) in repairs { for (sym_idx, repair_data) in repairs {
output.push(MediaPacket { output.push(MediaPacket {
header: MediaHeader { header: MediaHeader {
version: 0, version: 2,
is_repair: true, flags: MediaHeader::FLAG_REPAIR,
media_type: MediaType::Audio,
codec_id: self.profile.codec, codec_id: self.profile.codec,
has_quality_report: false, stream_id: 0,
fec_ratio_encoded: MediaHeader::encode_fec_ratio( fec_ratio: MediaHeader::encode_fec_ratio(self.profile.fec_ratio),
self.profile.fec_ratio,
),
seq: self.seq, seq: self.seq,
timestamp: self.timestamp_ms, timestamp: self.timestamp_ms,
fec_block: self.block_id, fec_block: u16::from(self.block_id) | (sym_idx << 8),
fec_symbol: sym_idx,
reserved: 0,
csrc_count: 0,
}, },
payload: Bytes::from(repair_data), payload: Bytes::from(repair_data),
quality_report: None, quality_report: None,
@@ -508,7 +503,7 @@ pub struct CallDecoder {
last_good_dred: DredState, last_good_dred: DredState,
/// Sequence number of the packet that produced `last_good_dred`. `None` /// Sequence number of the packet that produced `last_good_dred`. `None`
/// if no packet has yielded DRED state yet (cold start or legacy sender). /// if no packet has yielded DRED state yet (cold start or legacy sender).
last_good_dred_seq: Option<u16>, last_good_dred_seq: Option<u32>,
/// Phase 4 telemetry counter: gaps recovered via DRED reconstruction. /// Phase 4 telemetry counter: gaps recovered via DRED reconstruction.
pub dred_reconstructions: u64, pub dred_reconstructions: u64,
/// Phase 4 telemetry counter: gaps filled via classical Opus PLC /// Phase 4 telemetry counter: gaps filled via classical Opus PLC
@@ -571,8 +566,8 @@ impl CallDecoder {
if !packet.header.codec_id.is_opus() { if !packet.header.codec_id.is_opus() {
let _ = self.fec_dec.add_symbol( let _ = self.fec_dec.add_symbol(
packet.header.fec_block, packet.header.fec_block,
packet.header.fec_symbol, packet.header.fec_block >> 8,
packet.header.is_repair, packet.header.is_repair(),
&packet.payload, &packet.payload,
); );
} }
@@ -582,7 +577,7 @@ impl CallDecoder {
// swap with the cached `last_good_dred` so later gap reconstruction // swap with the cached `last_good_dred` so later gap reconstruction
// has fresh neural redundancy to draw from. Parsing happens before // has fresh neural redundancy to draw from. Parsing happens before
// the jitter push because the jitter buffer consumes the packet. // the jitter push because the jitter buffer consumes the packet.
if packet.header.codec_id.is_opus() && !packet.header.is_repair { if packet.header.codec_id.is_opus() && !packet.header.is_repair() {
match self match self
.dred_decoder .dred_decoder
.parse_into(&mut self.dred_parse_scratch, &packet.payload) .parse_into(&mut self.dred_parse_scratch, &packet.payload)
@@ -611,7 +606,7 @@ impl CallDecoder {
// Source packets (Opus or Codec2) go to the jitter buffer for decode. // Source packets (Opus or Codec2) go to the jitter buffer for decode.
// Repair packets never reach the jitter buffer; for Codec2 they're // Repair packets never reach the jitter buffer; for Codec2 they're
// used by the FEC decoder above, for Opus they're dropped here. // used by the FEC decoder above, for Opus they're dropped here.
if !packet.header.is_repair { if !packet.header.is_repair() {
self.jitter.push(packet); self.jitter.push(packet);
} }
} }
@@ -646,6 +641,7 @@ impl CallDecoder {
fec_ratio: 0.3, fec_ratio: 0.3,
frame_duration_ms: 20, frame_duration_ms: 20,
frames_per_block: 5, frames_per_block: 5,
..QualityProfile::GOOD
}, },
CodecId::Opus6k => QualityProfile::DEGRADED, CodecId::Opus6k => QualityProfile::DEGRADED,
CodecId::Opus32k => QualityProfile::STUDIO_32K, CodecId::Opus32k => QualityProfile::STUDIO_32K,
@@ -656,9 +652,13 @@ impl CallDecoder {
fec_ratio: 0.5, fec_ratio: 0.5,
frame_duration_ms: 20, frame_duration_ms: 20,
frames_per_block: 5, frames_per_block: 5,
..QualityProfile::GOOD
}, },
CodecId::Codec2_1200 => QualityProfile::CATASTROPHIC, CodecId::Codec2_1200 => QualityProfile::CATASTROPHIC,
CodecId::ComfortNoise => QualityProfile::GOOD, CodecId::ComfortNoise => QualityProfile::GOOD,
CodecId::H264Baseline | CodecId::H265Main | CodecId::Av1Main => {
panic!("video codec passed to audio decoder")
}
} }
} }
@@ -711,12 +711,12 @@ impl CallDecoder {
if let Some(last_seq) = self.last_good_dred_seq { if let Some(last_seq) = self.last_good_dred_seq {
// How many frames ahead of the missing seq is the // How many frames ahead of the missing seq is the
// last-good packet? Use wrapping arithmetic for the // last-good packet? Use wrapping arithmetic for the
// u16 seq space. // u32 seq space.
let seq_delta = last_seq.wrapping_sub(seq); let seq_delta = last_seq.wrapping_sub(seq);
// Reject stale or backward state. u16 wraparound // Reject stale or backward state. u32 wraparound
// would make a "seq went backward" delta very large; // would make a "seq went backward" delta very large;
// cap at a sane forward-looking window. // cap at a sane forward-looking window.
const MAX_SEQ_DELTA: u16 = 128; const MAX_SEQ_DELTA: u32 = 128;
if seq_delta > 0 && seq_delta <= MAX_SEQ_DELTA { if seq_delta > 0 && seq_delta <= MAX_SEQ_DELTA {
let frame_samples = let frame_samples =
(48_000 * self.profile.frame_duration_ms as i32) / 1000; (48_000 * self.profile.frame_duration_ms as i32) / 1000;
@@ -785,7 +785,7 @@ impl CallDecoder {
/// Phase 3b introspection: sequence number of the most recently parsed /// Phase 3b introspection: sequence number of the most recently parsed
/// valid DRED state, or `None` if no Opus packet has yielded DRED data /// valid DRED state, or `None` if no Opus packet has yielded DRED data
/// yet. Used by tests to debug reconstruction eligibility. /// yet. Used by tests to debug reconstruction eligibility.
pub fn last_good_dred_seq(&self) -> Option<u16> { pub fn last_good_dred_seq(&self) -> Option<u32> {
self.last_good_dred_seq self.last_good_dred_seq
} }
@@ -852,7 +852,7 @@ mod tests {
let packets = enc.encode_frame(&pcm).unwrap(); let packets = enc.encode_frame(&pcm).unwrap();
assert!(!packets.is_empty()); assert!(!packets.is_empty());
assert_eq!(packets[0].header.seq, 0); assert_eq!(packets[0].header.seq, 0);
assert!(!packets[0].header.is_repair); assert!(!packets[0].header.is_repair());
} }
/// Phase 2: Opus packets have zero FEC header fields — no block, no /// Phase 2: Opus packets have zero FEC header fields — no block, no
@@ -875,10 +875,9 @@ mod tests {
assert_eq!(packets.len(), 1, "Opus must emit exactly 1 source packet"); assert_eq!(packets.len(), 1, "Opus must emit exactly 1 source packet");
let hdr = &packets[0].header; let hdr = &packets[0].header;
assert!(hdr.codec_id.is_opus()); assert!(hdr.codec_id.is_opus());
assert!(!hdr.is_repair); assert!(!hdr.is_repair());
assert_eq!(hdr.fec_block, 0, "Opus fec_block must be 0"); assert_eq!(hdr.fec_block, 0, "Opus fec_block must be 0");
assert_eq!(hdr.fec_symbol, 0, "Opus fec_symbol must be 0"); assert_eq!(hdr.fec_ratio, 0, "Opus fec_ratio must be 0");
assert_eq!(hdr.fec_ratio_encoded, 0, "Opus fec_ratio_encoded must be 0");
} }
/// Phase 2: Opus never emits repair packets, regardless of how many /// Phase 2: Opus never emits repair packets, regardless of how many
@@ -902,7 +901,7 @@ mod tests {
for _ in 0..20 { for _ in 0..20 {
let packets = enc.encode_frame(&pcm).unwrap(); let packets = enc.encode_frame(&pcm).unwrap();
total_packets += packets.len(); total_packets += packets.len();
repair_count += packets.iter().filter(|p| p.header.is_repair).count(); repair_count += packets.iter().filter(|p| p.header.is_repair()).count();
} }
assert_eq!(repair_count, 0, "Opus must emit zero repair packets"); assert_eq!(repair_count, 0, "Opus must emit zero repair packets");
assert_eq!( assert_eq!(
@@ -934,7 +933,7 @@ mod tests {
for _ in 0..16 { for _ in 0..16 {
let packets = enc.encode_frame(&pcm).unwrap(); let packets = enc.encode_frame(&pcm).unwrap();
for p in &packets { for p in &packets {
if p.header.is_repair { if p.header.is_repair() {
repair_count += 1; repair_count += 1;
} }
} }
@@ -953,17 +952,15 @@ mod tests {
let pkt = MediaPacket { let pkt = MediaPacket {
header: MediaHeader { header: MediaHeader {
version: 0, version: 2,
is_repair: false, flags: 0,
media_type: MediaType::Audio,
codec_id: CodecId::Opus24k, codec_id: CodecId::Opus24k,
has_quality_report: false, stream_id: 0,
fec_ratio_encoded: 0, fec_ratio: 0,
seq: 0, seq: 0,
timestamp: 0, timestamp: 0,
fec_block: 0, fec_block: 0,
fec_symbol: 0,
reserved: 0,
csrc_count: 0,
}, },
payload: Bytes::from(vec![0u8; 60]), payload: Bytes::from(vec![0u8; 60]),
quality_report: None, quality_report: None,
@@ -1025,17 +1022,15 @@ mod tests {
encoded.truncate(n); encoded.truncate(n);
let pkt = MediaPacket { let pkt = MediaPacket {
header: MediaHeader { header: MediaHeader {
version: 0, version: 2,
is_repair: false, flags: 0,
media_type: MediaType::Audio,
codec_id: CodecId::Opus24k, codec_id: CodecId::Opus24k,
has_quality_report: false, stream_id: 0,
fec_ratio_encoded: 0, fec_ratio: 0,
seq: i, seq: i as u32,
timestamp: (i as u32) * 20, timestamp: (i as u32) * 20,
fec_block: 0, fec_block: 0,
fec_symbol: 0,
reserved: 0,
csrc_count: 0,
}, },
payload: Bytes::from(encoded), payload: Bytes::from(encoded),
quality_report: None, quality_report: None,
@@ -1105,9 +1100,7 @@ mod tests {
let dred_delta = dec.dred_reconstructions - baseline_dred; let dred_delta = dec.dred_reconstructions - baseline_dred;
let plc_delta = dec.classical_plc_invocations - baseline_plc; let plc_delta = dec.classical_plc_invocations - baseline_plc;
eprintln!( eprintln!("[phase3b probe] post-drain: dred_delta={dred_delta} plc_delta={plc_delta}");
"[phase3b probe] post-drain: dred_delta={dred_delta} plc_delta={plc_delta}"
);
assert!( assert!(
dred_delta >= 1, dred_delta >= 1,
"expected ≥1 DRED reconstruction on single-packet loss, \ "expected ≥1 DRED reconstruction on single-packet loss, \
@@ -1168,7 +1161,7 @@ mod tests {
let packets = enc.encode_frame(&pcm).unwrap(); let packets = enc.encode_frame(&pcm).unwrap();
for pkt in packets { for pkt in packets {
// Drop every 5th source packet to simulate loss. // Drop every 5th source packet to simulate loss.
if !pkt.header.is_repair && i % 5 == 3 { if !pkt.header.is_repair() && i % 5 == 3 {
continue; continue;
} }
dec.ingest(pkt); dec.ingest(pkt);
@@ -1322,20 +1315,18 @@ mod tests {
// ---- JitterStats telemetry tests ---- // ---- JitterStats telemetry tests ----
fn make_test_packet(seq: u16) -> MediaPacket { fn make_test_packet(seq: u32) -> MediaPacket {
MediaPacket { MediaPacket {
header: MediaHeader { header: MediaHeader {
version: 0, version: 2,
is_repair: false, flags: 0,
media_type: MediaType::Audio,
codec_id: CodecId::Opus24k, codec_id: CodecId::Opus24k,
has_quality_report: false, stream_id: 0,
fec_ratio_encoded: 0, fec_ratio: 0,
seq, seq,
timestamp: seq as u32 * 20, timestamp: seq * 20,
fec_block: 0, fec_block: 0,
fec_symbol: seq as u8,
reserved: 0,
csrc_count: 0,
}, },
payload: Bytes::from(vec![0u8; 60]), payload: Bytes::from(vec![0u8; 60]),
quality_report: None, quality_report: None,
@@ -1347,7 +1338,7 @@ mod tests {
let config = CallConfig::default(); let config = CallConfig::default();
let mut dec = CallDecoder::new(&config); let mut dec = CallDecoder::new(&config);
for i in 0..5u16 { for i in 0..5u32 {
dec.ingest(make_test_packet(i)); dec.ingest(make_test_packet(i));
} }
@@ -1377,7 +1368,7 @@ mod tests {
let mut dec = CallDecoder::new(&config); let mut dec = CallDecoder::new(&config);
// Generate some stats: ingest packets and trigger underruns on empty buffer // Generate some stats: ingest packets and trigger underruns on empty buffer
for i in 0..3u16 { for i in 0..3u32 {
dec.ingest(make_test_packet(i)); dec.ingest(make_test_packet(i));
} }
// Also call decode on empty decoder to get underruns // Also call decode on empty decoder to get underruns
@@ -1456,10 +1447,7 @@ mod tests {
cn_packets >= 1, cn_packets >= 1,
"should have at least one CN packet, got {cn_packets}" "should have at least one CN packet, got {cn_packets}"
); );
assert!( assert!(enc.frames_suppressed > 0, "frames_suppressed should be > 0");
enc.frames_suppressed > 0,
"frames_suppressed should be > 0"
);
} }
// ---- DredTuner integration tests ---- // ---- DredTuner integration tests ----
@@ -1506,7 +1494,10 @@ mod tests {
// Verify the encoder still works after tuning. // Verify the encoder still works after tuning.
let pcm = voice_frame_20ms(0); let pcm = voice_frame_20ms(0);
let packets = enc.encode_frame(&pcm).unwrap(); let packets = enc.encode_frame(&pcm).unwrap();
assert!(!packets.is_empty(), "encoder must still produce packets after DRED tuning"); assert!(
!packets.is_empty(),
"encoder must still produce packets after DRED tuning"
);
} }
/// DredTuner jitter spike triggers pre-emptive DRED boost to ceiling. /// DredTuner jitter spike triggers pre-emptive DRED boost to ceiling.
@@ -1524,11 +1515,15 @@ mod tests {
// Jitter spikes to 40ms (8x baseline of ~5ms). // Jitter spikes to 40ms (8x baseline of ~5ms).
let tuning = tuner.update(0.0, 50, 40); let tuning = tuner.update(0.0, 50, 40);
assert!(tuner.spike_boost_active(), "jitter spike should activate boost"); assert!(
tuner.spike_boost_active(),
"jitter spike should activate boost"
);
assert!(tuning.is_some()); assert!(tuning.is_some());
// Ceiling for Opus24k is 50 frames = 500 ms. // Ceiling for Opus24k is 50 frames = 500 ms.
assert_eq!( assert_eq!(
tuning.unwrap().dred_frames, 50, tuning.unwrap().dred_frames,
50,
"spike should push to ceiling" "spike should push to ceiling"
); );
} }
@@ -1604,12 +1599,73 @@ mod tests {
let pcm = voice_frame_20ms(0); let pcm = voice_frame_20ms(0);
let packets = enc.encode_frame(&pcm).unwrap(); let packets = enc.encode_frame(&pcm).unwrap();
assert!(!packets.is_empty()); assert!(!packets.is_empty());
assert!(packets[0].header.has_quality_report, "first packet should have quality report"); assert!(
packets[0].header.has_quality(),
"first packet should have quality report"
);
assert!(packets[0].quality_report.is_some()); assert!(packets[0].quality_report.is_some());
// Next frame should NOT have quality_report (it was consumed) // Next frame should NOT have quality_report (it was consumed)
let packets2 = enc.encode_frame(&voice_frame_20ms(960)).unwrap(); let packets2 = enc.encode_frame(&voice_frame_20ms(960)).unwrap();
assert!(!packets2[0].header.has_quality_report, "second packet should not have quality report"); assert!(
!packets2[0].header.has_quality(),
"second packet should not have quality report"
);
assert!(packets2[0].quality_report.is_none()); assert!(packets2[0].quality_report.is_none());
} }
#[test]
fn quality_report_aead_tamper_fails_decrypt() {
use wzp_crypto::ChaChaSession;
use wzp_proto::CryptoSession;
// Build a packet with a QualityReport trailer.
let pkt = MediaPacket {
header: MediaHeader {
version: 2,
flags: MediaHeader::FLAG_QUALITY,
media_type: MediaType::Audio,
codec_id: CodecId::Opus24k,
stream_id: 0,
fec_ratio: 10,
seq: 42,
timestamp: 1000,
fec_block: 0,
},
payload: Bytes::from(vec![0xAB; 60]),
quality_report: Some(QualityReport::from_path_stats(5.0, 80, 10)),
};
// Serialize: header || payload || quality_report
let wire = pkt.to_bytes();
assert_eq!(
wire.len(),
MediaHeader::WIRE_SIZE + pkt.payload.len() + QualityReport::WIRE_SIZE
);
let header_bytes = &wire[..MediaHeader::WIRE_SIZE];
let plaintext = &wire[MediaHeader::WIRE_SIZE..];
// Encrypt with ChaCha20-Poly1305 (header as AAD, payload+QR as plaintext).
let mut alice = ChaChaSession::new([0xAA; 32]);
let mut bob = ChaChaSession::new([0xAA; 32]);
let mut ciphertext = Vec::new();
alice
.encrypt(header_bytes, plaintext, &mut ciphertext)
.unwrap();
// Tamper with a byte in the QualityReport region (last 4 bytes of plaintext
// → last 4 bytes of ciphertext for ChaCha20 stream cipher).
let qr_offset_in_plaintext = plaintext.len() - QualityReport::WIRE_SIZE;
let tamper_idx = qr_offset_in_plaintext;
ciphertext[tamper_idx] ^= 0xFF;
// Decryption must fail because the AEAD tag no longer matches.
let mut decrypted = Vec::new();
let result = bob.decrypt(header_bytes, &ciphertext, &mut decrypted);
assert!(
result.is_err(),
"tampering with QualityReport inside AEAD payload must cause decryption failure"
);
}
} }

View File

@@ -17,7 +17,7 @@ use std::sync::Arc;
use tracing::{error, info}; use tracing::{error, info};
use wzp_client::call::{CallConfig, CallDecoder, CallEncoder}; use wzp_client::call::{CallConfig, CallDecoder, CallEncoder};
use wzp_proto::MediaTransport; use wzp_proto::{MediaTransport, default_signal_version};
const FRAME_SAMPLES: usize = 960; // 20ms @ 48kHz const FRAME_SAMPLES: usize = 960; // 20ms @ 48kHz
@@ -108,7 +108,11 @@ fn parse_args() -> CliArgs {
"--signal" => signal = true, "--signal" => signal = true,
"--call" => { "--call" => {
i += 1; i += 1;
call_target = Some(args.get(i).expect("--call requires a fingerprint").to_string()); call_target = Some(
args.get(i)
.expect("--call requires a fingerprint")
.to_string(),
);
} }
"--send-tone" => { "--send-tone" => {
i += 1; i += 1;
@@ -185,8 +189,12 @@ fn parse_args() -> CliArgs {
); );
} }
"--sweep" => sweep = true, "--sweep" => sweep = true,
"--netcheck" => { netcheck = true; } "--netcheck" => {
"--version-check" => { version_check = true; } netcheck = true;
}
"--version-check" => {
version_check = true;
}
"--help" | "-h" => { "--help" | "-h" => {
eprintln!("Usage: wzp-client [options] [relay-addr]"); eprintln!("Usage: wzp-client [options] [relay-addr]");
eprintln!(); eprintln!();
@@ -197,13 +205,19 @@ fn parse_args() -> CliArgs {
eprintln!(" --record <file.raw> Record received audio to raw PCM file"); eprintln!(" --record <file.raw> Record received audio to raw PCM file");
eprintln!(" --echo-test <secs> Run automated echo quality test"); eprintln!(" --echo-test <secs> Run automated echo quality test");
eprintln!(" --drift-test <secs> Run automated clock-drift measurement"); eprintln!(" --drift-test <secs> Run automated clock-drift measurement");
eprintln!(" --sweep Run jitter buffer parameter sweep (local, no network)"); eprintln!(
eprintln!(" --seed <hex> Identity seed (64 hex chars, featherChat compatible)"); " --sweep Run jitter buffer parameter sweep (local, no network)"
);
eprintln!(
" --seed <hex> Identity seed (64 hex chars, featherChat compatible)"
);
eprintln!(" --mnemonic <words...> Identity seed as BIP39 mnemonic (24 words)"); eprintln!(" --mnemonic <words...> Identity seed as BIP39 mnemonic (24 words)");
eprintln!(" --room <name> Room name (hashed for privacy before sending)"); eprintln!(" --room <name> Room name (hashed for privacy before sending)");
eprintln!(" --token <token> featherChat bearer token for relay auth"); eprintln!(" --token <token> featherChat bearer token for relay auth");
eprintln!(" --metrics-file <path> Write JSONL telemetry to file (1 line/sec)"); eprintln!(" --metrics-file <path> Write JSONL telemetry to file (1 line/sec)");
eprintln!(" (48kHz mono s16le, play with ffplay -f s16le -ar 48000 -ch_layout mono file.raw)"); eprintln!(
" (48kHz mono s16le, play with ffplay -f s16le -ar 48000 -ch_layout mono file.raw)"
);
eprintln!(); eprintln!();
eprintln!("Default relay: 127.0.0.1:4433"); eprintln!("Default relay: 127.0.0.1:4433");
std::process::exit(0); std::process::exit(0);
@@ -265,9 +279,7 @@ async fn main() -> anyhow::Result<()> {
if cli.netcheck { if cli.netcheck {
let config = wzp_client::netcheck::NetcheckConfig { let config = wzp_client::netcheck::NetcheckConfig {
stun_config: wzp_client::stun::StunConfig::default(), stun_config: wzp_client::stun::StunConfig::default(),
relays: vec![ relays: vec![("relay".into(), cli.relay_addr)],
("relay".into(), cli.relay_addr),
],
timeout: std::time::Duration::from_secs(5), timeout: std::time::Duration::from_secs(5),
test_portmap: true, test_portmap: true,
test_ipv6: true, test_ipv6: true,
@@ -283,7 +295,8 @@ async fn main() -> anyhow::Result<()> {
let client_config = wzp_transport::client_config(); let client_config = wzp_transport::client_config();
let bind_addr: SocketAddr = "0.0.0.0:0".parse()?; let bind_addr: SocketAddr = "0.0.0.0:0".parse()?;
let endpoint = wzp_transport::create_endpoint(bind_addr, None)?; let endpoint = wzp_transport::create_endpoint(bind_addr, None)?;
let conn = wzp_transport::connect(&endpoint, cli.relay_addr, "version", client_config).await?; let conn =
wzp_transport::connect(&endpoint, cli.relay_addr, "version", client_config).await?;
match conn.accept_uni().await { match conn.accept_uni().await {
Ok(mut recv) => { Ok(mut recv) => {
let data = recv.read_to_end(256).await.unwrap_or_default(); let data = recv.read_to_end(256).await.unwrap_or_default();
@@ -291,7 +304,10 @@ async fn main() -> anyhow::Result<()> {
println!("{} {}", cli.relay_addr, version.trim()); println!("{} {}", cli.relay_addr, version.trim());
} }
Err(e) => { Err(e) => {
eprintln!("relay {} does not support version query: {e}", cli.relay_addr); eprintln!(
"relay {} does not support version query: {e}",
cli.relay_addr
);
} }
} }
endpoint.close(0u32.into(), b"done"); endpoint.close(0u32.into(), b"done");
@@ -331,8 +347,7 @@ async fn main() -> anyhow::Result<()> {
"0.0.0.0:0".parse()? "0.0.0.0:0".parse()?
}; };
let endpoint = wzp_transport::create_endpoint(bind_addr, None)?; let endpoint = wzp_transport::create_endpoint(bind_addr, None)?;
let connection = let connection = wzp_transport::connect(&endpoint, cli.relay_addr, &sni, client_config).await?;
wzp_transport::connect(&endpoint, cli.relay_addr, &sni, client_config).await?;
info!("Connected to relay"); info!("Connected to relay");
@@ -343,9 +358,11 @@ async fn main() -> anyhow::Result<()> {
{ {
let shutdown_transport = transport.clone(); let shutdown_transport = transport.clone();
tokio::spawn(async move { tokio::spawn(async move {
let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) let mut sigterm =
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
.expect("failed to register SIGTERM handler"); .expect("failed to register SIGTERM handler");
let mut sigint = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt()) let mut sigint =
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt())
.expect("failed to register SIGINT handler"); .expect("failed to register SIGINT handler");
tokio::select! { tokio::select! {
_ = sigterm.recv() => { info!("SIGTERM received, closing connection..."); } _ = sigterm.recv() => { info!("SIGTERM received, closing connection..."); }
@@ -354,13 +371,16 @@ async fn main() -> anyhow::Result<()> {
// Close the QUIC connection immediately (APPLICATION_CLOSE frame). // Close the QUIC connection immediately (APPLICATION_CLOSE frame).
// Don't call process::exit — let the main task detect the closed // Don't call process::exit — let the main task detect the closed
// connection and perform clean shutdown (e.g., save recordings). // connection and perform clean shutdown (e.g., save recordings).
shutdown_transport.connection().close(0u32.into(), b"shutdown"); shutdown_transport
.connection()
.close(0u32.into(), b"shutdown");
}); });
} }
// Send auth token if provided (relay with --auth-url expects this first) // Send auth token if provided (relay with --auth-url expects this first)
if let Some(ref token) = cli.token { if let Some(ref token) = cli.token {
let auth = wzp_proto::SignalMessage::AuthToken { let auth = wzp_proto::SignalMessage::AuthToken {
version: default_signal_version(),
token: token.clone(), token: token.clone(),
}; };
transport.send_signal(&auth).await?; transport.send_signal(&auth).await?;
@@ -368,21 +388,29 @@ async fn main() -> anyhow::Result<()> {
} }
// Crypto handshake — establishes verified identity + session key // Crypto handshake — establishes verified identity + session key
let _crypto_session = wzp_client::handshake::perform_handshake( let hs = wzp_client::handshake::perform_handshake(
&*transport, &*transport,
&seed.0, &seed.0,
None, // alias — desktop client doesn't set one yet None, // alias — desktop client doesn't set one yet
).await?; )
info!("crypto handshake complete"); .await?;
info!(video_codec = ?hs.video_codec, "crypto handshake complete");
// Wrap the transport so all media I/O goes through AEAD encryption.
let enc_transport: Arc<dyn wzp_proto::MediaTransport> = Arc::new(
wzp_client::encrypted_transport::EncryptingTransport::new(transport.clone(), hs.session),
);
if cli.live { if cli.live {
#[cfg(feature = "audio")] #[cfg(feature = "audio")]
{ {
return run_live(transport).await; return run_live(enc_transport).await;
} }
#[cfg(not(feature = "audio"))] #[cfg(not(feature = "audio"))]
{ {
anyhow::bail!("--live requires the 'audio' feature (build with: cargo build --features audio)"); anyhow::bail!(
"--live requires the 'audio' feature (build with: cargo build --features audio)"
);
} }
} else if let Some(secs) = cli.echo_test_secs { } else if let Some(secs) = cli.echo_test_secs {
let result = wzp_client::echo_test::run_echo_test(&*transport, secs, 5.0).await?; let result = wzp_client::echo_test::run_echo_test(&*transport, secs, 5.0).await?;
@@ -399,14 +427,20 @@ async fn main() -> anyhow::Result<()> {
transport.close().await?; transport.close().await?;
Ok(()) Ok(())
} else if cli.send_tone_secs.is_some() || cli.send_file.is_some() || cli.record_file.is_some() { } else if cli.send_tone_secs.is_some() || cli.send_file.is_some() || cli.record_file.is_some() {
run_file_mode(transport, cli.send_tone_secs, cli.send_file, cli.record_file).await run_file_mode(
enc_transport,
cli.send_tone_secs,
cli.send_file,
cli.record_file,
)
.await
} else { } else {
run_silence(transport).await run_silence(enc_transport).await
} }
} }
/// Send silence frames (connectivity test). /// Send silence frames (connectivity test).
async fn run_silence(transport: Arc<wzp_transport::QuinnTransport>) -> anyhow::Result<()> { async fn run_silence(transport: Arc<dyn wzp_proto::MediaTransport>) -> anyhow::Result<()> {
let config = CallConfig::default(); let config = CallConfig::default();
let mut encoder = CallEncoder::new(&config); let mut encoder = CallEncoder::new(&config);
@@ -420,7 +454,7 @@ async fn run_silence(transport: Arc<wzp_transport::QuinnTransport>) -> anyhow::R
for i in 0..250u32 { for i in 0..250u32 {
let packets = encoder.encode_frame(&pcm)?; let packets = encoder.encode_frame(&pcm)?;
for pkt in &packets { for pkt in &packets {
if pkt.header.is_repair { if pkt.header.is_repair() {
total_repair += 1; total_repair += 1;
} else { } else {
total_source += 1; total_source += 1;
@@ -445,6 +479,7 @@ async fn run_silence(transport: Arc<wzp_transport::QuinnTransport>) -> anyhow::R
info!(total_source, total_repair, total_bytes, "done — closing"); info!(total_source, total_repair, total_bytes, "done — closing");
let hangup = wzp_proto::SignalMessage::Hangup { let hangup = wzp_proto::SignalMessage::Hangup {
version: default_signal_version(),
reason: wzp_proto::HangupReason::Normal, reason: wzp_proto::HangupReason::Normal,
call_id: None, call_id: None,
}; };
@@ -455,7 +490,7 @@ async fn run_silence(transport: Arc<wzp_transport::QuinnTransport>) -> anyhow::R
/// File/tone mode: send a test tone or audio file, and/or record received audio. /// File/tone mode: send a test tone or audio file, and/or record received audio.
async fn run_file_mode( async fn run_file_mode(
transport: Arc<wzp_transport::QuinnTransport>, transport: Arc<dyn wzp_proto::MediaTransport>,
send_tone_secs: Option<u32>, send_tone_secs: Option<u32>,
send_file: Option<String>, send_file: Option<String>,
record_file: Option<String>, record_file: Option<String>,
@@ -470,21 +505,28 @@ async fn run_file_mode(
// Read raw PCM file (48kHz mono s16le) // Read raw PCM file (48kHz mono s16le)
let bytes = match std::fs::read(path) { let bytes = match std::fs::read(path) {
Ok(b) => b, Ok(b) => b,
Err(e) => { error!("read {path}: {e}"); return; } Err(e) => {
error!("read {path}: {e}");
return;
}
}; };
let samples: Vec<i16> = bytes.chunks_exact(2) let samples: Vec<i16> = bytes
.chunks_exact(2)
.map(|c| i16::from_le_bytes([c[0], c[1]])) .map(|c| i16::from_le_bytes([c[0], c[1]]))
.collect(); .collect();
let duration = samples.len() as f64 / 48_000.0; let duration = samples.len() as f64 / 48_000.0;
info!(file = %path, duration = format!("{:.1}s", duration), "sending audio file"); info!(file = %path, duration = format!("{:.1}s", duration), "sending audio file");
samples.chunks(FRAME_SAMPLES) samples
.chunks(FRAME_SAMPLES)
.filter(|c| c.len() == FRAME_SAMPLES) .filter(|c| c.len() == FRAME_SAMPLES)
.map(|c| c.to_vec()) .map(|c| c.to_vec())
.collect() .collect()
} else if let Some(secs) = send_tone_secs { } else if let Some(secs) = send_tone_secs {
let total = (secs as u64) * 50; let total = (secs as u64) * 50;
info!(seconds = secs, frames = total, "sending 440Hz tone"); info!(seconds = secs, frames = total, "sending 440Hz tone");
(0..total).map(|i| generate_sine_frame(440.0, 48_000, i)).collect() (0..total)
.map(|i| generate_sine_frame(440.0, 48_000, i))
.collect()
} else { } else {
// No sending, just wait // No sending, just wait
tokio::signal::ctrl_c().await.ok(); tokio::signal::ctrl_c().await.ok();
@@ -508,7 +550,7 @@ async fn run_file_mode(
} }
}; };
for pkt in &packets { for pkt in &packets {
if pkt.header.is_repair { if pkt.header.is_repair() {
total_repair += 1; total_repair += 1;
} else { } else {
total_source += 1; total_source += 1;
@@ -556,7 +598,7 @@ async fn run_file_mode(
result = recv_transport.recv_media() => { result = recv_transport.recv_media() => {
match result { match result {
Ok(Some(pkt)) => { Ok(Some(pkt)) => {
let is_repair = pkt.header.is_repair; let is_repair = pkt.header.is_repair();
decoder.ingest(pkt); decoder.ingest(pkt);
if !is_repair { if !is_repair {
if let Some(n) = decoder.decode_next(&mut pcm_buf) { if let Some(n) = decoder.decode_next(&mut pcm_buf) {
@@ -597,6 +639,7 @@ async fn run_file_mode(
// Send Hangup signal so the relay knows we're done // Send Hangup signal so the relay knows we're done
let hangup = wzp_proto::SignalMessage::Hangup { let hangup = wzp_proto::SignalMessage::Hangup {
version: default_signal_version(),
reason: wzp_proto::HangupReason::Normal, reason: wzp_proto::HangupReason::Normal,
call_id: None, call_id: None,
}; };
@@ -636,7 +679,7 @@ async fn run_file_mode(
/// Live mode: capture from mic, encode, send; receive, decode, play. /// Live mode: capture from mic, encode, send; receive, decode, play.
#[cfg(feature = "audio")] #[cfg(feature = "audio")]
async fn run_live(transport: Arc<wzp_transport::QuinnTransport>) -> anyhow::Result<()> { async fn run_live(transport: Arc<dyn wzp_proto::MediaTransport>) -> anyhow::Result<()> {
use wzp_client::audio_io::{AudioCapture, AudioPlayback}; use wzp_client::audio_io::{AudioCapture, AudioPlayback};
let capture = AudioCapture::start()?; let capture = AudioCapture::start()?;
@@ -689,7 +732,7 @@ async fn run_live(transport: Arc<wzp_transport::QuinnTransport>) -> anyhow::Resu
loop { loop {
match recv_transport.recv_media().await { match recv_transport.recv_media().await {
Ok(Some(pkt)) => { Ok(Some(pkt)) => {
let is_repair = pkt.header.is_repair; let is_repair = pkt.header.is_repair();
decoder.ingest(pkt); decoder.ingest(pkt);
// Only decode for source packets (1 source = 1 audio frame). // Only decode for source packets (1 source = 1 audio frame).
// Repair packets feed the FEC decoder but don't produce audio. // Repair packets feed the FEC decoder but don't produce audio.
@@ -734,7 +777,7 @@ async fn run_signal_mode(
token: Option<String>, token: Option<String>,
call_target: Option<String>, call_target: Option<String>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
use wzp_proto::SignalMessage; use wzp_proto::{SignalMessage, default_signal_version};
let identity = seed.derive_identity(); let identity = seed.derive_identity();
let pub_id = identity.public_identity(); let pub_id = identity.public_identity();
@@ -756,22 +799,34 @@ async fn run_signal_mode(
// Auth if token provided // Auth if token provided
if let Some(ref tok) = token { if let Some(ref tok) = token {
transport.send_signal(&SignalMessage::AuthToken { token: tok.clone() }).await?; transport
.send_signal(&SignalMessage::AuthToken {
version: default_signal_version(),
token: tok.clone(),
})
.await?;
} }
// Register presence (signature not verified in Phase 1) // Register presence (signature not verified in Phase 1)
transport.send_signal(&SignalMessage::RegisterPresence { transport
.send_signal(&SignalMessage::RegisterPresence {
version: default_signal_version(),
identity_pub, identity_pub,
signature: vec![], // Phase 1: not verified signature: vec![], // Phase 1: not verified
alias: None, alias: None,
}).await?; })
.await?;
// Wait for ack // Wait for ack
match transport.recv_signal().await? { match transport.recv_signal().await? {
Some(SignalMessage::RegisterPresenceAck { success: true, .. }) => { Some(SignalMessage::RegisterPresenceAck { success: true, .. }) => {
info!(fingerprint = %fp, "registered on relay — waiting for calls"); info!(fingerprint = %fp, "registered on relay — waiting for calls");
} }
Some(SignalMessage::RegisterPresenceAck { success: false, error, .. }) => { Some(SignalMessage::RegisterPresenceAck {
success: false,
error,
..
}) => {
anyhow::bail!("registration failed: {}", error.unwrap_or_default()); anyhow::bail!("registration failed: {}", error.unwrap_or_default());
} }
other => { other => {
@@ -782,10 +837,17 @@ async fn run_signal_mode(
// If --call specified, place the call // If --call specified, place the call
if let Some(ref target) = call_target { if let Some(ref target) = call_target {
info!(target = %target, "placing direct call..."); info!(target = %target, "placing direct call...");
let call_id = format!("{:016x}", std::time::SystemTime::now() let call_id = format!(
.duration_since(std::time::UNIX_EPOCH).unwrap().as_nanos()); "{:016x}",
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_nanos()
);
transport.send_signal(&SignalMessage::DirectCallOffer { transport
.send_signal(&SignalMessage::DirectCallOffer {
version: default_signal_version(),
caller_fingerprint: fp.clone(), caller_fingerprint: fp.clone(),
caller_alias: None, caller_alias: None,
target_fingerprint: target.clone(), target_fingerprint: target.clone(),
@@ -800,7 +862,8 @@ async fn run_signal_mode(
caller_local_addrs: Vec::new(), caller_local_addrs: Vec::new(),
caller_mapped_addr: None, caller_mapped_addr: None,
caller_build_version: None, caller_build_version: None,
}).await?; })
.await?;
} }
// Signal recv loop — handle incoming signals // Signal recv loop — handle incoming signals
@@ -811,10 +874,15 @@ async fn run_signal_mode(
loop { loop {
match signal_transport.recv_signal().await { match signal_transport.recv_signal().await {
Ok(Some(msg)) => match msg { Ok(Some(msg)) => match msg {
SignalMessage::CallRinging { call_id } => { SignalMessage::CallRinging { call_id, .. } => {
info!(call_id = %call_id, "ringing..."); info!(call_id = %call_id, "ringing...");
} }
SignalMessage::DirectCallOffer { caller_fingerprint, caller_alias, call_id, .. } => { SignalMessage::DirectCallOffer {
caller_fingerprint,
caller_alias,
call_id,
..
} => {
info!( info!(
from = %caller_fingerprint, from = %caller_fingerprint,
alias = ?caller_alias, alias = ?caller_alias,
@@ -822,7 +890,9 @@ async fn run_signal_mode(
"incoming call — auto-accepting (generic)" "incoming call — auto-accepting (generic)"
); );
// Auto-accept for CLI testing // Auto-accept for CLI testing
let _ = signal_transport.send_signal(&SignalMessage::DirectCallAnswer { let _ = signal_transport
.send_signal(&SignalMessage::DirectCallAnswer {
version: default_signal_version(),
call_id, call_id,
accept_mode: wzp_proto::CallAcceptMode::AcceptGeneric, accept_mode: wzp_proto::CallAcceptMode::AcceptGeneric,
identity_pub: Some(identity_pub), identity_pub: Some(identity_pub),
@@ -835,12 +905,25 @@ async fn run_signal_mode(
callee_local_addrs: Vec::new(), callee_local_addrs: Vec::new(),
callee_mapped_addr: None, callee_mapped_addr: None,
callee_build_version: None, callee_build_version: None,
}).await; })
.await;
} }
SignalMessage::DirectCallAnswer { call_id, accept_mode, .. } => { SignalMessage::DirectCallAnswer {
call_id,
accept_mode,
..
} => {
info!(call_id = %call_id, mode = ?accept_mode, "call answered"); info!(call_id = %call_id, mode = ?accept_mode, "call answered");
} }
SignalMessage::CallSetup { call_id, room, relay_addr: setup_relay, peer_direct_addr: _, peer_local_addrs: _, peer_mapped_addr: _ } => { SignalMessage::CallSetup {
call_id,
room,
relay_addr: setup_relay,
peer_direct_addr: _,
peer_local_addrs: _,
peer_mapped_addr: _,
..
} => {
info!(call_id = %call_id, room = %room, relay = %setup_relay, "call setup — connecting to media room"); info!(call_id = %call_id, room = %room, relay = %setup_relay, "call setup — connecting to media room");
// Connect to the media room // Connect to the media room
@@ -848,18 +931,28 @@ async fn run_signal_mode(
let media_cfg = wzp_transport::client_config(); let media_cfg = wzp_transport::client_config();
match wzp_transport::connect(&endpoint, media_relay, &room, media_cfg).await { match wzp_transport::connect(&endpoint, media_relay, &room, media_cfg).await {
Ok(media_conn) => { Ok(media_conn) => {
let media_transport = Arc::new(wzp_transport::QuinnTransport::new(media_conn)); let media_transport =
Arc::new(wzp_transport::QuinnTransport::new(media_conn));
// Crypto handshake // Crypto handshake
match wzp_client::handshake::perform_handshake(&*media_transport, &my_seed, None).await { match wzp_client::handshake::perform_handshake(
Ok(_session) => { &*media_transport,
info!("media connected — sending tone (press Ctrl+C to hang up)"); &my_seed,
None,
)
.await
{
Ok(_hs) => {
info!(
"media connected — sending tone (press Ctrl+C to hang up)"
);
// Simple tone sender for testing // Simple tone sender for testing
let mt = media_transport.clone(); let mt = media_transport.clone();
let send_task = tokio::spawn(async move { let send_task = tokio::spawn(async move {
let config = wzp_client::call::CallConfig::default(); let config = wzp_client::call::CallConfig::default();
let mut encoder = wzp_client::call::CallEncoder::new(&config); let mut encoder =
wzp_client::call::CallEncoder::new(&config);
let duration = tokio::time::Duration::from_millis(20); let duration = tokio::time::Duration::from_millis(20);
loop { loop {
let pcm: Vec<i16> = (0..FRAME_SAMPLES) let pcm: Vec<i16> = (0..FRAME_SAMPLES)
@@ -867,7 +960,9 @@ async fn run_signal_mode(
.collect(); .collect();
if let Ok(pkts) = encoder.encode_frame(&pcm) { if let Ok(pkts) = encoder.encode_frame(&pcm) {
for pkt in &pkts { for pkt in &pkts {
if mt.send_media(pkt).await.is_err() { return; } if mt.send_media(pkt).await.is_err() {
return;
}
} }
} }
tokio::time::sleep(duration).await; tokio::time::sleep(duration).await;
@@ -890,6 +985,7 @@ async fn run_signal_mode(
_ = tokio::signal::ctrl_c() => { _ = tokio::signal::ctrl_c() => {
info!("hanging up..."); info!("hanging up...");
let _ = signal_transport.send_signal(&SignalMessage::Hangup { let _ = signal_transport.send_signal(&SignalMessage::Hangup {
version: default_signal_version(),
reason: wzp_proto::HangupReason::Normal, reason: wzp_proto::HangupReason::Normal,
call_id: None, call_id: None,
}).await; }).await;

View File

@@ -144,7 +144,7 @@ pub async fn run_drift_test(
} }
match tokio::time::timeout(Duration::from_millis(2), transport.recv_media()).await { match tokio::time::timeout(Duration::from_millis(2), transport.recv_media()).await {
Ok(Ok(Some(pkt))) => { Ok(Ok(Some(pkt))) => {
let is_repair = pkt.header.is_repair; let is_repair = pkt.header.is_repair();
decoder.ingest(pkt); decoder.ingest(pkt);
if !is_repair { if !is_repair {
if let Some(_n) = decoder.decode_next(&mut pcm_buf) { if let Some(_n) = decoder.decode_next(&mut pcm_buf) {
@@ -180,7 +180,7 @@ pub async fn run_drift_test(
while Instant::now() < drain_deadline { while Instant::now() < drain_deadline {
match tokio::time::timeout(Duration::from_millis(100), transport.recv_media()).await { match tokio::time::timeout(Duration::from_millis(100), transport.recv_media()).await {
Ok(Ok(Some(pkt))) => { Ok(Ok(Some(pkt))) => {
let is_repair = pkt.header.is_repair; let is_repair = pkt.header.is_repair();
decoder.ingest(pkt); decoder.ingest(pkt);
if !is_repair { if !is_repair {
if let Some(_n) = decoder.decode_next(&mut pcm_buf) { if let Some(_n) = decoder.decode_next(&mut pcm_buf) {
@@ -234,7 +234,10 @@ pub fn print_drift_report(result: &DriftResult) {
println!(); println!();
println!("Expected duration: {} ms", result.expected_duration_ms); println!("Expected duration: {} ms", result.expected_duration_ms);
println!("Actual duration: {} ms", result.actual_duration_ms); println!("Actual duration: {} ms", result.actual_duration_ms);
println!("Drift: {} ms ({:+.4}%)", result.drift_ms, result.drift_pct); println!(
"Drift: {} ms ({:+.4}%)",
result.drift_ms, result.drift_pct
);
println!(); println!();
// Interpretation // Interpretation
@@ -246,9 +249,15 @@ pub fn print_drift_report(result: &DriftResult) {
} else if abs_drift < 20 { } else if abs_drift < 20 {
println!("Result: GOOD -- drift is within acceptable bounds (<20 ms)."); println!("Result: GOOD -- drift is within acceptable bounds (<20 ms).");
} else if abs_drift < 100 { } else if abs_drift < 100 {
println!("Result: FAIR -- noticeable drift ({} ms). Clock sync may be needed.", abs_drift); println!(
"Result: FAIR -- noticeable drift ({} ms). Clock sync may be needed.",
abs_drift
);
} else { } else {
println!("Result: POOR -- significant drift ({} ms). Investigate clock sources.", abs_drift); println!(
"Result: POOR -- significant drift ({} ms). Investigate clock sources.",
abs_drift
);
} }
println!(); println!();
} }

View File

@@ -299,10 +299,16 @@ pub async fn race(
socket2::Domain::IPV4, socket2::Domain::IPV4,
socket2::Type::DGRAM, socket2::Type::DGRAM,
Some(socket2::Protocol::UDP), Some(socket2::Protocol::UDP),
).map_err(|e| format!("socket: {e}"))?; )
sock.set_reuse_address(true).map_err(|e| format!("reuseaddr: {e}"))?; .map_err(|e| format!("socket: {e}"))?;
sock.set_reuse_address(true)
.map_err(|e| format!("reuseaddr: {e}"))?;
// macOS/BSD/Linux also need SO_REUSEPORT // macOS/BSD/Linux also need SO_REUSEPORT
#[cfg(any(target_os = "macos", target_os = "linux", target_os = "android"))] #[cfg(any(
target_os = "macos",
target_os = "linux",
target_os = "android"
))]
{ {
// socket2 exposes set_reuse_port on unix // socket2 exposes set_reuse_port on unix
unsafe { unsafe {
@@ -316,12 +322,14 @@ pub async fn race(
); );
} }
} }
sock.set_nonblocking(true).map_err(|e| format!("nonblock: {e}"))?; sock.set_nonblocking(true)
.map_err(|e| format!("nonblock: {e}"))?;
let bind_addr: SocketAddr = SocketAddr::new( let bind_addr: SocketAddr = SocketAddr::new(
std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED), std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED),
local_addr.port(), local_addr.port(),
); );
sock.bind(&bind_addr.into()).map_err(|e| format!("bind :{}: {e}", local_addr.port()))?; sock.bind(&bind_addr.into())
.map_err(|e| format!("bind :{}: {e}", local_addr.port()))?;
let std_sock: StdUdpSocket = sock.into(); let std_sock: StdUdpSocket = sock.into();
for addr in &tickle_addrs { for addr in &tickle_addrs {
let _ = std_sock.send_to(&[0u8; 1], addr); let _ = std_sock.send_to(&[0u8; 1], addr);
@@ -469,13 +477,8 @@ pub async fn race(
candidate_idx = idx, candidate_idx = idx,
"dual_path: dialing candidate" "dual_path: dialing candidate"
); );
let result = wzp_transport::connect( let result =
&ep, wzp_transport::connect(&ep, candidate, &sni, client_cfg).await;
candidate,
&sni,
client_cfg,
)
.await;
let elapsed = start.elapsed().as_millis() as u32; let elapsed = start.elapsed().as_millis() as u32;
let diag_result = match &result { let diag_result = match &result {
Ok(_) => "ok".to_string(), Ok(_) => "ok".to_string(),
@@ -604,9 +607,7 @@ pub async fn race(
"dual_path: racing direct vs relay" "dual_path: racing direct vs relay"
); );
let mut direct_task = tokio::spawn( let mut direct_task = tokio::spawn(tokio::time::timeout(Duration::from_secs(4), direct_fut));
tokio::time::timeout(Duration::from_secs(4), direct_fut),
);
let mut relay_task = tokio::spawn(async move { let mut relay_task = tokio::spawn(async move {
// Keep the 500ms head start so direct has a chance // Keep the 500ms head start so direct has a chance
tokio::time::sleep(Duration::from_millis(500)).await; tokio::time::sleep(Duration::from_millis(500)).await;
@@ -695,8 +696,12 @@ pub async fn race(
// If it doesn't, we still proceed with just the winner. // If it doesn't, we still proceed with just the winner.
if direct_result.is_none() { if direct_result.is_none() {
match tokio::time::timeout(Duration::from_secs(1), direct_task).await { match tokio::time::timeout(Duration::from_secs(1), direct_task).await {
Ok(Ok(Ok(Ok(t)))) => { direct_result = Some(Ok(t)); } Ok(Ok(Ok(Ok(t)))) => {
Ok(Ok(Ok(Err(e)))) => { direct_result = Some(Err(anyhow::anyhow!("{e}"))); } direct_result = Some(Ok(t));
}
Ok(Ok(Ok(Err(e)))) => {
direct_result = Some(Err(anyhow::anyhow!("{e}")));
}
_ => { _ => {
direct_result = Some(Err(anyhow::anyhow!("direct: no result in grace period"))); direct_result = Some(Err(anyhow::anyhow!("direct: no result in grace period")));
// Fill timeout diags for candidates that never reported. // Fill timeout diags for candidates that never reported.
@@ -719,9 +724,15 @@ pub async fn race(
} }
if relay_result.is_none() { if relay_result.is_none() {
match tokio::time::timeout(Duration::from_secs(1), relay_task).await { match tokio::time::timeout(Duration::from_secs(1), relay_task).await {
Ok(Ok(Ok(Ok(t)))) => { relay_result = Some(Ok(t)); } Ok(Ok(Ok(Ok(t)))) => {
Ok(Ok(Ok(Err(e)))) => { relay_result = Some(Err(anyhow::anyhow!("{e}"))); } relay_result = Some(Ok(t));
_ => { relay_result = Some(Err(anyhow::anyhow!("relay: no result in grace period"))); } }
Ok(Ok(Ok(Err(e)))) => {
relay_result = Some(Err(anyhow::anyhow!("{e}")));
}
_ => {
relay_result = Some(Err(anyhow::anyhow!("relay: no result in grace period")));
}
} }
} }
@@ -736,22 +747,21 @@ pub async fn race(
); );
if !direct_ok && !relay_ok { if !direct_ok && !relay_ok {
return Err(anyhow::anyhow!("both paths failed: no media transport available")); return Err(anyhow::anyhow!(
"both paths failed: no media transport available"
));
} }
let _ = (direct_ep, relay_ep, ipv6_endpoint); let _ = (direct_ep, relay_ep, ipv6_endpoint);
let candidate_diags = diags_collector.lock() let candidate_diags = diags_collector
.lock()
.map(|d| d.clone()) .map(|d| d.clone())
.unwrap_or_default(); .unwrap_or_default();
Ok(RaceResult { Ok(RaceResult {
direct_transport: direct_result direct_transport: direct_result.and_then(|r| r.ok()).map(|t| Arc::new(t)),
.and_then(|r| r.ok()) relay_transport: relay_result.and_then(|r| r.ok()).map(|t| Arc::new(t)),
.map(|t| Arc::new(t)),
relay_transport: relay_result
.and_then(|r| r.ok())
.map(|t| Arc::new(t)),
local_winner, local_winner,
candidate_diags, candidate_diags,
}) })
@@ -777,7 +787,10 @@ mod tests {
assert_eq!(order.len(), 4); assert_eq!(order.len(), 4);
assert_eq!(order[0], "192.168.1.10:4433".parse::<SocketAddr>().unwrap()); assert_eq!(order[0], "192.168.1.10:4433".parse::<SocketAddr>().unwrap());
assert_eq!(order[1], "10.0.0.5:4433".parse::<SocketAddr>().unwrap()); assert_eq!(order[1], "10.0.0.5:4433".parse::<SocketAddr>().unwrap());
assert_eq!(order[2], "198.51.100.42:12345".parse::<SocketAddr>().unwrap()); assert_eq!(
order[2],
"198.51.100.42:12345".parse::<SocketAddr>().unwrap()
);
assert_eq!(order[3], "203.0.113.5:4433".parse::<SocketAddr>().unwrap()); assert_eq!(order[3], "203.0.113.5:4433".parse::<SocketAddr>().unwrap());
} }
@@ -805,7 +818,10 @@ mod tests {
let order = candidates.dial_order(); let order = candidates.dial_order();
assert_eq!(order.len(), 1); assert_eq!(order.len(), 1);
assert_eq!(order[0], "198.51.100.42:12345".parse::<SocketAddr>().unwrap()); assert_eq!(
order[0],
"198.51.100.42:12345".parse::<SocketAddr>().unwrap()
);
} }
#[test] #[test]

View File

@@ -166,7 +166,7 @@ pub async fn run_echo_test(
match tokio::time::timeout(Duration::from_millis(2), transport.recv_media()).await { match tokio::time::timeout(Duration::from_millis(2), transport.recv_media()).await {
Ok(Ok(Some(pkt))) => { Ok(Ok(Some(pkt))) => {
total_packets_received += 1; total_packets_received += 1;
let is_repair = pkt.header.is_repair; let is_repair = pkt.header.is_repair();
decoder.ingest(pkt); decoder.ingest(pkt);
if !is_repair { if !is_repair {
if let Some(n) = decoder.decode_next(&mut pcm_buf) { if let Some(n) = decoder.decode_next(&mut pcm_buf) {
@@ -184,7 +184,8 @@ pub async fn run_echo_test(
let time_offset = start.elapsed().as_secs_f64(); let time_offset = start.elapsed().as_secs_f64();
// Compare sent vs received for this window // Compare sent vs received for this window
let sent_start = (window_idx as u64 * frames_per_window * FRAME_SAMPLES as u64) as usize; let sent_start =
(window_idx as u64 * frames_per_window * FRAME_SAMPLES as u64) as usize;
let sent_end = sent_start + (window_frames_sent as usize * FRAME_SAMPLES); let sent_end = sent_start + (window_frames_sent as usize * FRAME_SAMPLES);
let sent_window = if sent_end <= sent_pcm.len() { let sent_window = if sent_end <= sent_pcm.len() {
&sent_pcm[sent_start..sent_end] &sent_pcm[sent_start..sent_end]
@@ -192,7 +193,9 @@ pub async fn run_echo_test(
&sent_pcm[sent_start..] &sent_pcm[sent_start..]
}; };
let recv_start = recv_pcm.len().saturating_sub(window_frames_received as usize * FRAME_SAMPLES); let recv_start = recv_pcm
.len()
.saturating_sub(window_frames_received as usize * FRAME_SAMPLES);
let recv_window = &recv_pcm[recv_start..]; let recv_window = &recv_pcm[recv_start..];
let peak = recv_window.iter().map(|s| s.abs()).max().unwrap_or(0); let peak = recv_window.iter().map(|s| s.abs()).max().unwrap_or(0);
@@ -256,7 +259,7 @@ pub async fn run_echo_test(
match tokio::time::timeout(Duration::from_millis(100), transport.recv_media()).await { match tokio::time::timeout(Duration::from_millis(100), transport.recv_media()).await {
Ok(Ok(Some(pkt))) => { Ok(Ok(Some(pkt))) => {
total_packets_received += 1; total_packets_received += 1;
let is_repair = pkt.header.is_repair; let is_repair = pkt.header.is_repair();
decoder.ingest(pkt); decoder.ingest(pkt);
if !is_repair { if !is_repair {
decoder.decode_next(&mut pcm_buf); decoder.decode_next(&mut pcm_buf);
@@ -310,8 +313,14 @@ pub fn print_report(result: &EchoTestResult) {
let status = if w.is_silent { " !" } else { " " }; let status = if w.is_silent { " !" } else { " " };
println!( println!(
"{:>3}{}{:>5.1}s │ {:>4}{:>4}{:>5.1}% │ {:>5.1}{:.3}", "{:>3}{}{:>5.1}s │ {:>4}{:>4}{:>5.1}% │ {:>5.1}{:.3}",
w.index, status, w.time_offset_secs, w.frames_sent, w.frames_received, w.index,
w.loss_pct, w.snr_db, w.correlation status,
w.time_offset_secs,
w.frames_sent,
w.frames_received,
w.loss_pct,
w.snr_db,
w.correlation
); );
} }
println!("└───────┴─────────┴──────┴──────┴─────────┴───────┴───────┘"); println!("└───────┴─────────┴──────┴──────┴─────────┴───────┴───────┘");
@@ -321,18 +330,28 @@ pub fn print_report(result: &EchoTestResult) {
let first_half: Vec<_> = result.windows[..result.windows.len() / 2].to_vec(); let first_half: Vec<_> = result.windows[..result.windows.len() / 2].to_vec();
let second_half: Vec<_> = result.windows[result.windows.len() / 2..].to_vec(); let second_half: Vec<_> = result.windows[result.windows.len() / 2..].to_vec();
let avg_loss_first = first_half.iter().map(|w| w.loss_pct).sum::<f32>() / first_half.len() as f32; let avg_loss_first =
let avg_loss_second = second_half.iter().map(|w| w.loss_pct).sum::<f32>() / second_half.len() as f32; first_half.iter().map(|w| w.loss_pct).sum::<f32>() / first_half.len() as f32;
let avg_corr_first = first_half.iter().map(|w| w.correlation).sum::<f32>() / first_half.len() as f32; let avg_loss_second =
let avg_corr_second = second_half.iter().map(|w| w.correlation).sum::<f32>() / second_half.len() as f32; second_half.iter().map(|w| w.loss_pct).sum::<f32>() / second_half.len() as f32;
let avg_corr_first =
first_half.iter().map(|w| w.correlation).sum::<f32>() / first_half.len() as f32;
let avg_corr_second =
second_half.iter().map(|w| w.correlation).sum::<f32>() / second_half.len() as f32;
println!(); println!();
if avg_loss_second > avg_loss_first + 5.0 { if avg_loss_second > avg_loss_first + 5.0 {
println!("WARNING: Quality degradation detected!"); println!("WARNING: Quality degradation detected!");
println!(" Loss increased from {:.1}% to {:.1}% over time", avg_loss_first, avg_loss_second); println!(
" Loss increased from {:.1}% to {:.1}% over time",
avg_loss_first, avg_loss_second
);
} }
if avg_corr_second < avg_corr_first - 0.1 { if avg_corr_second < avg_corr_first - 0.1 {
println!("WARNING: Signal correlation dropped from {:.3} to {:.3}", avg_corr_first, avg_corr_second); println!(
"WARNING: Signal correlation dropped from {:.3} to {:.3}",
avg_corr_first, avg_corr_second
);
} }
if avg_loss_second <= avg_loss_first + 5.0 && avg_corr_second >= avg_corr_first - 0.1 { if avg_loss_second <= avg_loss_first + 5.0 && avg_corr_second >= avg_corr_first - 0.1 {
println!("Quality is STABLE over the test duration."); println!("Quality is STABLE over the test duration.");

View File

@@ -0,0 +1,213 @@
//! `EncryptingTransport` — wraps any `MediaTransport` with a `CryptoSession`.
//!
//! All outbound `send_media` calls encrypt the payload before handing off to
//! the inner transport; all inbound `recv_media` calls decrypt after receiving.
//! Signal, quality, and close are forwarded unchanged.
//!
//! The quality report travels in plaintext so the relay can make QoS decisions
//! without being able to decrypt media content.
use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use bytes::Bytes;
use wzp_proto::{
CryptoSession, MediaHeader, MediaPacket, MediaTransport, PathQuality, SignalMessage,
TransportError,
};
/// Wraps a `MediaTransport` and applies AEAD encryption/decryption to media payloads.
pub struct EncryptingTransport {
inner: Arc<dyn MediaTransport>,
session: Mutex<Box<dyn CryptoSession>>,
}
impl EncryptingTransport {
pub fn new(inner: Arc<dyn MediaTransport>, session: Box<dyn CryptoSession>) -> Self {
Self {
inner,
session: Mutex::new(session),
}
}
}
#[async_trait]
impl MediaTransport for EncryptingTransport {
async fn send_media(&self, packet: &MediaPacket) -> Result<(), TransportError> {
let mut header_bytes = Vec::with_capacity(MediaHeader::WIRE_SIZE);
packet.header.write_to(&mut header_bytes);
let mut ciphertext = Vec::new();
self.session
.lock()
.unwrap()
.encrypt(&header_bytes, &packet.payload, &mut ciphertext)
.map_err(|e| TransportError::Internal(format!("encrypt: {e}")))?;
let encrypted = MediaPacket {
header: packet.header,
payload: Bytes::from(ciphertext),
quality_report: packet.quality_report.clone(),
};
self.inner.send_media(&encrypted).await
}
async fn recv_media(&self) -> Result<Option<MediaPacket>, TransportError> {
let packet = match self.inner.recv_media().await? {
Some(p) => p,
None => return Ok(None),
};
let mut header_bytes = Vec::with_capacity(MediaHeader::WIRE_SIZE);
packet.header.write_to(&mut header_bytes);
let mut plaintext = Vec::new();
self.session
.lock()
.unwrap()
.decrypt(&header_bytes, &packet.payload, &mut plaintext)
.map_err(|e| TransportError::Internal(format!("decrypt: {e}")))?;
Ok(Some(MediaPacket {
header: packet.header,
payload: Bytes::from(plaintext),
quality_report: packet.quality_report,
}))
}
async fn send_signal(&self, msg: &SignalMessage) -> Result<(), TransportError> {
self.inner.send_signal(msg).await
}
async fn recv_signal(&self) -> Result<Option<SignalMessage>, TransportError> {
self.inner.recv_signal().await
}
fn path_quality(&self) -> PathQuality {
self.inner.path_quality()
}
async fn close(&self) -> Result<(), TransportError> {
self.inner.close().await
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Mutex as StdMutex;
use wzp_crypto::ChaChaSession;
use wzp_proto::{CodecId, MediaType};
struct LoopbackTransport {
sent: StdMutex<Vec<MediaPacket>>,
}
impl LoopbackTransport {
fn new() -> Arc<Self> {
Arc::new(Self {
sent: StdMutex::new(Vec::new()),
})
}
fn take_sent(&self) -> Vec<MediaPacket> {
self.sent.lock().unwrap().drain(..).collect()
}
}
#[async_trait]
impl MediaTransport for LoopbackTransport {
async fn send_media(&self, packet: &MediaPacket) -> Result<(), TransportError> {
self.sent.lock().unwrap().push(packet.clone());
Ok(())
}
async fn recv_media(&self) -> Result<Option<MediaPacket>, TransportError> {
Ok(None)
}
async fn send_signal(&self, _msg: &SignalMessage) -> Result<(), TransportError> {
Ok(())
}
async fn recv_signal(&self) -> Result<Option<SignalMessage>, TransportError> {
Ok(None)
}
fn path_quality(&self) -> PathQuality {
PathQuality::default()
}
async fn close(&self) -> Result<(), TransportError> {
Ok(())
}
}
fn make_header(seq: u32) -> MediaHeader {
MediaHeader {
version: 2,
flags: 0,
media_type: MediaType::Audio,
codec_id: CodecId::Opus24k,
stream_id: 0,
fec_ratio: 0,
seq,
timestamp: seq * 20,
fec_block: 0,
}
}
#[tokio::test]
async fn payload_is_encrypted_on_wire() {
let key = [0x42u8; 32];
let session: Box<dyn CryptoSession> = Box::new(ChaChaSession::new(key));
let loopback = LoopbackTransport::new();
let enc = EncryptingTransport::new(loopback.clone(), session);
let header = make_header(1);
let plaintext = b"secret audio frame";
let pkt = MediaPacket {
header,
payload: Bytes::from_static(plaintext),
quality_report: None,
};
enc.send_media(&pkt).await.unwrap();
let sent = loopback.take_sent();
assert_eq!(sent.len(), 1);
assert_eq!(sent[0].header, header, "header must be preserved");
assert_ne!(
sent[0].payload.as_ref(),
plaintext.as_ref(),
"plaintext must not appear on wire"
);
// Ciphertext is longer by exactly the AEAD tag (16 bytes)
assert_eq!(sent[0].payload.len(), plaintext.len() + 16);
}
#[tokio::test]
async fn encrypt_then_decrypt_roundtrip() {
let key = [0x42u8; 32];
let send_session: Box<dyn CryptoSession> = Box::new(ChaChaSession::new(key));
let mut recv_session = ChaChaSession::new(key);
let loopback = LoopbackTransport::new();
let enc = EncryptingTransport::new(loopback.clone(), send_session);
let header = make_header(5);
let plaintext = b"hello encrypted world";
let pkt = MediaPacket {
header,
payload: Bytes::from_static(plaintext),
quality_report: None,
};
enc.send_media(&pkt).await.unwrap();
let sent = loopback.take_sent();
let wire_pkt = &sent[0];
let mut header_bytes = Vec::new();
header.write_to(&mut header_bytes);
let mut decrypted = Vec::new();
recv_session
.decrypt(&header_bytes, &wire_pkt.payload, &mut decrypted)
.expect("decrypt should succeed with matching key");
assert_eq!(&decrypted[..], plaintext);
}
}

View File

@@ -99,14 +99,15 @@ pub fn signal_to_call_type(signal: &SignalMessage) -> CallSignalType {
SignalMessage::LossRecoveryUpdate { .. } => CallSignalType::Offer, // reuse (telemetry) SignalMessage::LossRecoveryUpdate { .. } => CallSignalType::Offer, // reuse (telemetry)
SignalMessage::Ping { .. } | SignalMessage::Pong { .. } => CallSignalType::Offer, SignalMessage::Ping { .. } | SignalMessage::Pong { .. } => CallSignalType::Offer,
SignalMessage::AuthToken { .. } => CallSignalType::Offer, SignalMessage::AuthToken { .. } => CallSignalType::Offer,
SignalMessage::Hold => CallSignalType::Hold, SignalMessage::Hold { .. } => CallSignalType::Hold,
SignalMessage::Unhold => CallSignalType::Unhold, SignalMessage::Unhold { .. } => CallSignalType::Unhold,
SignalMessage::Mute => CallSignalType::Mute, SignalMessage::Mute { .. } => CallSignalType::Mute,
SignalMessage::Unmute => CallSignalType::Unmute, SignalMessage::Unmute { .. } => CallSignalType::Unmute,
SignalMessage::Transfer { .. } => CallSignalType::Transfer, SignalMessage::Transfer { .. } => CallSignalType::Transfer,
SignalMessage::TransferAck => CallSignalType::Offer, // reuse SignalMessage::TransferAck { .. } => CallSignalType::Offer, // reuse
SignalMessage::PresenceUpdate { .. } => CallSignalType::Offer, // reuse SignalMessage::PresenceUpdate { .. } => CallSignalType::Offer, // reuse
SignalMessage::RouteQuery { .. } => CallSignalType::Offer, // reuse SignalMessage::RouteQuery { .. } => CallSignalType::Offer, // reuse
SignalMessage::TransportFeedback { .. } => CallSignalType::Offer, // reuse (BWE)
SignalMessage::RouteResponse { .. } => CallSignalType::Offer, // reuse SignalMessage::RouteResponse { .. } => CallSignalType::Offer, // reuse
SignalMessage::SessionForward { .. } => CallSignalType::Offer, // reuse SignalMessage::SessionForward { .. } => CallSignalType::Offer, // reuse
SignalMessage::SessionForwardAck { .. } => CallSignalType::Offer, // reuse SignalMessage::SessionForwardAck { .. } => CallSignalType::Offer, // reuse
@@ -118,14 +119,14 @@ pub fn signal_to_call_type(signal: &SignalMessage) -> CallSignalType {
SignalMessage::DirectCallAnswer { .. } => CallSignalType::Answer, SignalMessage::DirectCallAnswer { .. } => CallSignalType::Answer,
SignalMessage::CallSetup { .. } => CallSignalType::Offer, // relay-only SignalMessage::CallSetup { .. } => CallSignalType::Offer, // relay-only
SignalMessage::CallRinging { .. } => CallSignalType::Ringing, SignalMessage::CallRinging { .. } => CallSignalType::Ringing,
SignalMessage::RegisterPresence { .. } SignalMessage::RegisterPresence { .. } | SignalMessage::RegisterPresenceAck { .. } => {
| SignalMessage::RegisterPresenceAck { .. } => CallSignalType::Offer, // relay-only CallSignalType::Offer
} // relay-only
// NAT reflection is a client↔relay control exchange that // NAT reflection is a client↔relay control exchange that
// never crosses the featherChat bridge — if it ever reaches // never crosses the featherChat bridge — if it ever reaches
// this mapper something is wrong, but we still have to give // this mapper something is wrong, but we still have to give
// an answer. "Offer" is the generic catch-all. // an answer. "Offer" is the generic catch-all.
SignalMessage::Reflect SignalMessage::Reflect | SignalMessage::ReflectResponse { .. } => CallSignalType::Offer, // control-plane
| SignalMessage::ReflectResponse { .. } => CallSignalType::Offer, // control-plane
// Phase 4 cross-relay forwarding envelope — strictly a // Phase 4 cross-relay forwarding envelope — strictly a
// relay-to-relay message, never rides the featherChat // relay-to-relay message, never rides the featherChat
// bridge. Catch-all mapping for completeness. // bridge. Catch-all mapping for completeness.
@@ -140,6 +141,9 @@ pub fn signal_to_call_type(signal: &SignalMessage) -> CallSignalType {
| SignalMessage::QualityCapability { .. } => CallSignalType::Offer, // quality negotiation | SignalMessage::QualityCapability { .. } => CallSignalType::Offer, // quality negotiation
SignalMessage::PresenceList { .. } => CallSignalType::Offer, // lobby presence SignalMessage::PresenceList { .. } => CallSignalType::Offer, // lobby presence
SignalMessage::QualityDirective { .. } => CallSignalType::Offer, // relay-initiated SignalMessage::QualityDirective { .. } => CallSignalType::Offer, // relay-initiated
SignalMessage::Nack { .. }
| SignalMessage::PictureLossIndication { .. }
| SignalMessage::SetPriorityMode { .. } => CallSignalType::Offer, // relay-initiated (video loss recovery)
} }
} }
@@ -147,15 +151,20 @@ pub fn signal_to_call_type(signal: &SignalMessage) -> CallSignalType {
mod tests { mod tests {
use super::*; use super::*;
use wzp_proto::QualityProfile; use wzp_proto::QualityProfile;
use wzp_proto::default_signal_version;
#[test] #[test]
fn payload_roundtrip() { fn payload_roundtrip() {
let signal = SignalMessage::CallOffer { let signal = SignalMessage::CallOffer {
version: default_signal_version(),
identity_pub: [1u8; 32], identity_pub: [1u8; 32],
ephemeral_pub: [2u8; 32], ephemeral_pub: [2u8; 32],
signature: vec![3u8; 64], signature: vec![3u8; 64],
supported_profiles: vec![QualityProfile::GOOD], supported_profiles: vec![QualityProfile::GOOD],
alias: None, alias: None,
protocol_version: 2,
supported_versions: vec![2],
video_codecs: vec![],
}; };
let encoded = encode_call_payload(&signal, Some("relay.example.com:4433"), Some("myroom")); let encoded = encode_call_payload(&signal, Some("relay.example.com:4433"), Some("myroom"));
@@ -169,29 +178,53 @@ mod tests {
#[test] #[test]
fn signal_type_mapping() { fn signal_type_mapping() {
let offer = SignalMessage::CallOffer { let offer = SignalMessage::CallOffer {
version: default_signal_version(),
identity_pub: [0; 32], identity_pub: [0; 32],
ephemeral_pub: [0; 32], ephemeral_pub: [0; 32],
signature: vec![], signature: vec![],
supported_profiles: vec![], supported_profiles: vec![],
alias: None, alias: None,
protocol_version: 2,
supported_versions: vec![2],
video_codecs: vec![],
}; };
assert!(matches!(signal_to_call_type(&offer), CallSignalType::Offer)); assert!(matches!(signal_to_call_type(&offer), CallSignalType::Offer));
let hangup = SignalMessage::Hangup { let hangup = SignalMessage::Hangup {
version: default_signal_version(),
reason: wzp_proto::HangupReason::Normal, reason: wzp_proto::HangupReason::Normal,
call_id: None, call_id: None,
}; };
assert!(matches!(signal_to_call_type(&hangup), CallSignalType::Hangup)); assert!(matches!(
signal_to_call_type(&hangup),
CallSignalType::Hangup
));
assert!(matches!(signal_to_call_type(&SignalMessage::Hold), CallSignalType::Hold)); assert!(matches!(
assert!(matches!(signal_to_call_type(&SignalMessage::Unhold), CallSignalType::Unhold)); signal_to_call_type(&SignalMessage::Hold { version: default_signal_version() }),
assert!(matches!(signal_to_call_type(&SignalMessage::Mute), CallSignalType::Mute)); CallSignalType::Hold
assert!(matches!(signal_to_call_type(&SignalMessage::Unmute), CallSignalType::Unmute)); ));
assert!(matches!(
signal_to_call_type(&SignalMessage::Unhold { version: default_signal_version() }),
CallSignalType::Unhold
));
assert!(matches!(
signal_to_call_type(&SignalMessage::Mute { version: default_signal_version() }),
CallSignalType::Mute
));
assert!(matches!(
signal_to_call_type(&SignalMessage::Unmute { version: default_signal_version() }),
CallSignalType::Unmute
));
let transfer = SignalMessage::Transfer { let transfer = SignalMessage::Transfer {
version: default_signal_version(),
target_fingerprint: "abc".to_string(), target_fingerprint: "abc".to_string(),
relay_addr: None, relay_addr: None,
}; };
assert!(matches!(signal_to_call_type(&transfer), CallSignalType::Transfer)); assert!(matches!(
signal_to_call_type(&transfer),
CallSignalType::Transfer
));
} }
} }

View File

@@ -4,7 +4,60 @@
//! send `CallOffer` → recv `CallAnswer` → derive shared `CryptoSession`. //! send `CallOffer` → recv `CallAnswer` → derive shared `CryptoSession`.
use wzp_crypto::{CryptoSession, KeyExchange, WarzoneKeyExchange}; use wzp_crypto::{CryptoSession, KeyExchange, WarzoneKeyExchange};
use wzp_proto::{MediaTransport, QualityProfile, SignalMessage}; use wzp_proto::{
CodecId, HangupReason, MediaTransport, QualityProfile, SignalMessage, default_signal_version,
};
/// Result of a successful client-side handshake.
pub struct HandshakeResult {
pub session: Box<dyn CryptoSession>,
/// Video codec agreed with the relay. `None` if peer is audio-only.
pub video_codec: Option<CodecId>,
}
/// Errors that can occur during the client-side cryptographic handshake.
#[derive(Debug)]
pub enum HandshakeError {
ConnectionClosed,
ProtocolVersionMismatch { server_supported: Vec<u8> },
UnexpectedSignal(&'static str),
SignatureVerificationFailed,
KeyDerivation(String),
Transport(wzp_proto::TransportError),
}
impl std::fmt::Display for HandshakeError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::ConnectionClosed => write!(f, "connection closed before receiving CallAnswer"),
Self::ProtocolVersionMismatch { server_supported } => {
write!(
f,
"protocol version mismatch: server supports {server_supported:?}"
)
}
Self::UnexpectedSignal(expected) => write!(f, "expected CallAnswer, got {expected}"),
Self::SignatureVerificationFailed => write!(f, "callee signature verification failed"),
Self::KeyDerivation(msg) => write!(f, "key derivation failed: {msg}"),
Self::Transport(e) => write!(f, "transport error: {e}"),
}
}
}
impl std::error::Error for HandshakeError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::Transport(e) => Some(e),
_ => None,
}
}
}
impl From<wzp_proto::TransportError> for HandshakeError {
fn from(e: wzp_proto::TransportError) -> Self {
Self::Transport(e)
}
}
/// Perform the client (caller) side of the cryptographic handshake. /// Perform the client (caller) side of the cryptographic handshake.
/// ///
@@ -18,7 +71,7 @@ pub async fn perform_handshake(
transport: &dyn MediaTransport, transport: &dyn MediaTransport,
seed: &[u8; 32], seed: &[u8; 32],
alias: Option<&str>, alias: Option<&str>,
) -> Result<Box<dyn CryptoSession>, anyhow::Error> { ) -> Result<HandshakeResult, HandshakeError> {
// 1. Create key exchange from identity seed // 1. Create key exchange from identity seed
let mut kx = WarzoneKeyExchange::from_identity_seed(seed); let mut kx = WarzoneKeyExchange::from_identity_seed(seed);
let identity_pub = kx.identity_public_key(); let identity_pub = kx.identity_public_key();
@@ -34,6 +87,7 @@ pub async fn perform_handshake(
// 4. Send CallOffer // 4. Send CallOffer
let offer = SignalMessage::CallOffer { let offer = SignalMessage::CallOffer {
version: default_signal_version(),
identity_pub, identity_pub,
ephemeral_pub, ephemeral_pub,
signature, signature,
@@ -46,28 +100,43 @@ pub async fn perform_handshake(
QualityProfile::CATASTROPHIC, QualityProfile::CATASTROPHIC,
], ],
alias: alias.map(|s| s.to_string()), alias: alias.map(|s| s.to_string()),
protocol_version: 2,
supported_versions: vec![2],
video_codecs: vec![CodecId::Av1Main, CodecId::H264Baseline, CodecId::H265Main],
}; };
transport.send_signal(&offer).await?; transport
.send_signal(&offer)
.await
.map_err(HandshakeError::Transport)?;
// 5. Wait for CallAnswer // 5. Wait for CallAnswer — 10s timeout guards against relay not responding.
let answer = transport let answer = tokio::time::timeout(
.recv_signal() std::time::Duration::from_secs(10),
.await? transport.recv_signal(),
.ok_or_else(|| anyhow::anyhow!("connection closed before receiving CallAnswer"))?; )
.await
.map_err(|_| HandshakeError::Transport(wzp_proto::TransportError::Timeout { ms: 10_000 }))?
.map_err(HandshakeError::Transport)?
.ok_or(HandshakeError::ConnectionClosed)?;
let (callee_identity_pub, callee_ephemeral_pub, callee_signature, _chosen_profile) = match answer let (callee_identity_pub, callee_ephemeral_pub, callee_signature, _chosen_profile, video_codec) =
{ match answer {
SignalMessage::CallAnswer { SignalMessage::CallAnswer {
identity_pub, identity_pub,
ephemeral_pub, ephemeral_pub,
signature, signature,
chosen_profile, chosen_profile,
} => (identity_pub, ephemeral_pub, signature, chosen_profile), video_codec,
other => { ..
return Err(anyhow::anyhow!( } => (identity_pub, ephemeral_pub, signature, chosen_profile, video_codec),
"expected CallAnswer, got {:?}", SignalMessage::Hangup {
std::mem::discriminant(&other) reason: HangupReason::ProtocolVersionMismatch { server_supported },
)) ..
} => {
return Err(HandshakeError::ProtocolVersionMismatch { server_supported });
}
_ => {
return Err(HandshakeError::UnexpectedSignal("CallAnswer"));
} }
}; };
@@ -76,13 +145,15 @@ pub async fn perform_handshake(
verify_data.extend_from_slice(&callee_ephemeral_pub); verify_data.extend_from_slice(&callee_ephemeral_pub);
verify_data.extend_from_slice(b"call-answer"); verify_data.extend_from_slice(b"call-answer");
if !WarzoneKeyExchange::verify(&callee_identity_pub, &verify_data, &callee_signature) { if !WarzoneKeyExchange::verify(&callee_identity_pub, &verify_data, &callee_signature) {
return Err(anyhow::anyhow!("callee signature verification failed")); return Err(HandshakeError::SignatureVerificationFailed);
} }
// 7. Derive session // 7. Derive session
let session = kx.derive_session(&callee_ephemeral_pub)?; let session = kx
.derive_session(&callee_ephemeral_pub)
.map_err(|e| HandshakeError::KeyDerivation(e.to_string()))?;
Ok(session) Ok(HandshakeResult { session, video_codec })
} }
#[cfg(test)] #[cfg(test)]
@@ -104,4 +175,30 @@ mod tests {
&sig, &sig,
)); ));
} }
#[test]
fn handshake_result_carries_video_codec() {
// Verify that HandshakeResult has both fields accessible and that
// None is the correct default for audio-only peers.
let mut kx = WarzoneKeyExchange::from_identity_seed(&[0x55; 32]);
kx.generate_ephemeral();
let session = kx.derive_session(&[0u8; 32]).unwrap();
let hs = HandshakeResult { session, video_codec: None };
assert!(hs.video_codec.is_none());
let mut kx2 = WarzoneKeyExchange::from_identity_seed(&[0x66; 32]);
kx2.generate_ephemeral();
let session2 = kx2.derive_session(&[0u8; 32]).unwrap();
let hs2 = HandshakeResult { session: session2, video_codec: Some(CodecId::Av1Main) };
assert_eq!(hs2.video_codec, Some(CodecId::Av1Main));
}
#[test]
fn offer_contains_three_video_codecs() {
// The offer sent in perform_handshake always includes the three codecs
// declared in order: AV1 > H264 > H265. Verify via the const list.
let offered = vec![CodecId::Av1Main, CodecId::H264Baseline, CodecId::H265Main];
assert_eq!(offered.len(), 3);
assert_eq!(offered[0], CodecId::Av1Main, "AV1 must be preferred");
}
} }

View File

@@ -17,7 +17,7 @@ use std::net::SocketAddr;
use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::atomic::{AtomicU32, Ordering};
use std::time::Duration; use std::time::Duration;
use wzp_proto::SignalMessage; use wzp_proto::{SignalMessage, default_signal_version};
use crate::dual_path::PeerCandidates; use crate::dual_path::PeerCandidates;
use crate::portmap; use crate::portmap;
@@ -106,14 +106,9 @@ impl IceAgent {
); );
let reflexive = stun_result.ok().and_then(|r| r.ok()); let reflexive = stun_result.ok().and_then(|r| r.ok());
let mapped = portmap_result let mapped = portmap_result.ok().flatten().map(|m| m.external_addr);
.ok() let local =
.flatten() reflect::local_host_candidates(self.config.local_v4_port, self.config.local_v6_port);
.map(|m| m.external_addr);
let local = reflect::local_host_candidates(
self.config.local_v4_port,
self.config.local_v6_port,
);
tracing::info!( tracing::info!(
generation, generation,
@@ -138,6 +133,7 @@ impl IceAgent {
let candidates = self.gather().await; let candidates = self.gather().await;
let update = SignalMessage::CandidateUpdate { let update = SignalMessage::CandidateUpdate {
version: default_signal_version(),
call_id: self.call_id.clone(), call_id: self.call_id.clone(),
reflexive_addr: candidates.reflexive.map(|a| a.to_string()), reflexive_addr: candidates.reflexive.map(|a| a.to_string()),
local_addrs: candidates.local.iter().map(|a| a.to_string()).collect(), local_addrs: candidates.local.iter().map(|a| a.to_string()).collect(),
@@ -151,10 +147,7 @@ impl IceAgent {
/// Process a peer's candidate update. Returns `Some(PeerCandidates)` /// Process a peer's candidate update. Returns `Some(PeerCandidates)`
/// if the update is newer than the last-seen generation, `None` /// if the update is newer than the last-seen generation, `None`
/// if it's stale. /// if it's stale.
pub fn apply_peer_update( pub fn apply_peer_update(&self, update: &SignalMessage) -> Option<PeerCandidates> {
&self,
update: &SignalMessage,
) -> Option<PeerCandidates> {
let (reflexive_addr, local_addrs, mapped_addr, generation) = match update { let (reflexive_addr, local_addrs, mapped_addr, generation) = match update {
SignalMessage::CandidateUpdate { SignalMessage::CandidateUpdate {
reflexive_addr, reflexive_addr,
@@ -177,16 +170,9 @@ impl IceAgent {
return None; return None;
} }
let reflexive = reflexive_addr let reflexive = reflexive_addr.as_deref().and_then(|s| s.parse().ok());
.as_deref() let local: Vec<SocketAddr> = local_addrs.iter().filter_map(|s| s.parse().ok()).collect();
.and_then(|s| s.parse().ok()); let mapped = mapped_addr.as_deref().and_then(|s| s.parse().ok());
let local: Vec<SocketAddr> = local_addrs
.iter()
.filter_map(|s| s.parse().ok())
.collect();
let mapped = mapped_addr
.as_deref()
.and_then(|s| s.parse().ok());
tracing::info!( tracing::info!(
generation, generation,
@@ -221,6 +207,7 @@ mod tests {
// First update (gen=1) should succeed. // First update (gen=1) should succeed.
let update1 = SignalMessage::CandidateUpdate { let update1 = SignalMessage::CandidateUpdate {
version: default_signal_version(),
call_id: "test-call".into(), call_id: "test-call".into(),
reflexive_addr: Some("203.0.113.5:4433".into()), reflexive_addr: Some("203.0.113.5:4433".into()),
local_addrs: vec!["192.168.1.10:4433".into()], local_addrs: vec!["192.168.1.10:4433".into()],
@@ -238,6 +225,7 @@ mod tests {
// Same generation (gen=1) should be rejected. // Same generation (gen=1) should be rejected.
let update1b = SignalMessage::CandidateUpdate { let update1b = SignalMessage::CandidateUpdate {
version: default_signal_version(),
call_id: "test-call".into(), call_id: "test-call".into(),
reflexive_addr: Some("198.51.100.9:4433".into()), reflexive_addr: Some("198.51.100.9:4433".into()),
local_addrs: vec![], local_addrs: vec![],
@@ -248,6 +236,7 @@ mod tests {
// Older generation (gen=0) should be rejected. // Older generation (gen=0) should be rejected.
let update0 = SignalMessage::CandidateUpdate { let update0 = SignalMessage::CandidateUpdate {
version: default_signal_version(),
call_id: "test-call".into(), call_id: "test-call".into(),
reflexive_addr: Some("10.0.0.1:4433".into()), reflexive_addr: Some("10.0.0.1:4433".into()),
local_addrs: vec![], local_addrs: vec![],
@@ -258,6 +247,7 @@ mod tests {
// Newer generation (gen=2) should succeed. // Newer generation (gen=2) should succeed.
let update2 = SignalMessage::CandidateUpdate { let update2 = SignalMessage::CandidateUpdate {
version: default_signal_version(),
call_id: "test-call".into(), call_id: "test-call".into(),
reflexive_addr: Some("198.51.100.9:5555".into()), reflexive_addr: Some("198.51.100.9:5555".into()),
local_addrs: vec![], local_addrs: vec![],
@@ -302,12 +292,10 @@ mod tests {
let agent = IceAgent::new("test-call".into(), IceAgentConfig::default()); let agent = IceAgent::new("test-call".into(), IceAgentConfig::default());
let update = SignalMessage::CandidateUpdate { let update = SignalMessage::CandidateUpdate {
version: default_signal_version(),
call_id: "test-call".into(), call_id: "test-call".into(),
reflexive_addr: Some("203.0.113.5:4433".into()), reflexive_addr: Some("203.0.113.5:4433".into()),
local_addrs: vec![ local_addrs: vec!["192.168.1.10:4433".into(), "10.0.0.5:4433".into()],
"192.168.1.10:4433".into(),
"10.0.0.5:4433".into(),
],
mapped_addr: Some("198.51.100.42:12345".into()), mapped_addr: Some("198.51.100.42:12345".into()),
generation: 1, generation: 1,
}; };
@@ -333,6 +321,7 @@ mod tests {
let agent = IceAgent::new("test".into(), IceAgentConfig::default()); let agent = IceAgent::new("test".into(), IceAgentConfig::default());
let update = SignalMessage::CandidateUpdate { let update = SignalMessage::CandidateUpdate {
version: default_signal_version(),
call_id: "test".into(), call_id: "test".into(),
reflexive_addr: None, reflexive_addr: None,
local_addrs: vec![], local_addrs: vec![],
@@ -351,6 +340,7 @@ mod tests {
let agent = IceAgent::new("test".into(), IceAgentConfig::default()); let agent = IceAgent::new("test".into(), IceAgentConfig::default());
let update = SignalMessage::CandidateUpdate { let update = SignalMessage::CandidateUpdate {
version: default_signal_version(),
call_id: "test".into(), call_id: "test".into(),
reflexive_addr: Some("not-an-addr".into()), reflexive_addr: Some("not-an-addr".into()),
local_addrs: vec![ local_addrs: vec![
@@ -382,7 +372,9 @@ mod tests {
async fn gather_returns_candidates_even_with_no_stun() { async fn gather_returns_candidates_even_with_no_stun() {
// With default config (port 0 = no portmap, STUN will timeout // With default config (port 0 = no portmap, STUN will timeout
// quickly on loopback), gather should still return host candidates. // quickly on loopback), gather should still return host candidates.
let agent = IceAgent::new("test".into(), IceAgentConfig { let agent = IceAgent::new(
"test".into(),
IceAgentConfig {
stun_config: stun::StunConfig { stun_config: stun::StunConfig {
servers: vec![], // no servers = quick failure servers: vec![], // no servers = quick failure
timeout: Duration::from_millis(100), timeout: Duration::from_millis(100),
@@ -391,7 +383,8 @@ mod tests {
gather_timeout: Duration::from_millis(200), gather_timeout: Duration::from_millis(200),
local_v4_port: 12345, local_v4_port: 12345,
local_v6_port: None, local_v6_port: None,
}); },
);
let candidates = agent.gather().await; let candidates = agent.gather().await;
assert_eq!(candidates.generation, 0); assert_eq!(candidates.generation, 0);
@@ -405,7 +398,9 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn re_gather_produces_signal_message() { async fn re_gather_produces_signal_message() {
let agent = IceAgent::new("call-42".into(), IceAgentConfig { let agent = IceAgent::new(
"call-42".into(),
IceAgentConfig {
stun_config: stun::StunConfig { stun_config: stun::StunConfig {
servers: vec![], servers: vec![],
timeout: Duration::from_millis(50), timeout: Duration::from_millis(50),
@@ -414,7 +409,8 @@ mod tests {
gather_timeout: Duration::from_millis(100), gather_timeout: Duration::from_millis(100),
local_v4_port: 4433, local_v4_port: 4433,
local_v6_port: None, local_v6_port: None,
}); },
);
let (candidates, signal) = agent.re_gather().await; let (candidates, signal) = agent.re_gather().await;
assert_eq!(candidates.generation, 0); assert_eq!(candidates.generation, 0);

View File

@@ -27,15 +27,16 @@ pub mod audio_wasapi;
#[cfg(all(feature = "linux-aec", target_os = "linux"))] #[cfg(all(feature = "linux-aec", target_os = "linux"))]
pub mod audio_linux_aec; pub mod audio_linux_aec;
pub mod bench; pub mod bench;
pub mod birthday;
pub mod call; pub mod call;
pub mod encrypted_transport;
pub mod drift_test; pub mod drift_test;
pub mod dual_path;
pub mod echo_test; pub mod echo_test;
pub mod featherchat; pub mod featherchat;
pub mod handshake; pub mod handshake;
pub mod dual_path;
pub mod metrics;
pub mod birthday;
pub mod ice_agent; pub mod ice_agent;
pub mod metrics;
pub mod netcheck; pub mod netcheck;
pub mod portmap; pub mod portmap;
pub mod reflect; pub mod reflect;

View File

@@ -178,7 +178,10 @@ mod tests {
// Immediate second write should be skipped (60s interval). // Immediate second write should be skipped (60s interval).
let second = writer.maybe_write(&snap).unwrap(); let second = writer.maybe_write(&snap).unwrap();
assert!(!second, "second write should be skipped — interval not elapsed"); assert!(
!second,
"second write should be skipped — interval not elapsed"
);
// Clean up. // Clean up.
let _ = std::fs::remove_file(&path); let _ = std::fs::remove_file(&path);

View File

@@ -112,22 +112,30 @@ pub async fn run_netcheck(config: &NetcheckConfig) -> NetcheckReport {
let ipv6_fut = test_ipv6(config.test_ipv6, config.timeout); let ipv6_fut = test_ipv6(config.test_ipv6, config.timeout);
let port_alloc_fut = stun::detect_port_allocation(&config.stun_config); let port_alloc_fut = stun::detect_port_allocation(&config.stun_config);
let (stun_probes, relay_latencies, portmap_result, gateway_result, ipv6_reachable, port_alloc_result) = let (
tokio::join!(stun_fut, relay_fut, portmap_fut, gateway_result_fut(gateway_fut), ipv6_fut, port_alloc_fut); stun_probes,
relay_latencies,
portmap_result,
gateway_result,
ipv6_reachable,
port_alloc_result,
) = tokio::join!(
stun_fut,
relay_fut,
portmap_fut,
gateway_result_fut(gateway_fut),
ipv6_fut,
port_alloc_fut
);
// Classify NAT from STUN probes. // Classify NAT from STUN probes.
let (nat_type, consensus_addr) = reflect::classify_nat(&stun_probes); let (nat_type, consensus_addr) = reflect::classify_nat(&stun_probes);
// Determine STUN latency (first successful probe). // Determine STUN latency (first successful probe).
let stun_latency_ms = stun_probes let stun_latency_ms = stun_probes.iter().filter_map(|p| p.latency_ms).min();
.iter()
.filter_map(|p| p.latency_ms)
.min();
// IPv4 reachable if any STUN probe succeeded. // IPv4 reachable if any STUN probe succeeded.
let ipv4_reachable = stun_probes let ipv4_reachable = stun_probes.iter().any(|p| p.observed_addr.is_some());
.iter()
.any(|p| p.observed_addr.is_some());
// Preferred relay = lowest RTT. // Preferred relay = lowest RTT.
let preferred_relay = relay_latencies let preferred_relay = relay_latencies
@@ -176,10 +184,7 @@ pub async fn run_netcheck(config: &NetcheckConfig) -> NetcheckReport {
} }
/// Probe relay latencies via reflect. /// Probe relay latencies via reflect.
async fn probe_relays( async fn probe_relays(relays: &[(String, SocketAddr)], timeout: Duration) -> Vec<RelayLatency> {
relays: &[(String, SocketAddr)],
timeout: Duration,
) -> Vec<RelayLatency> {
if relays.is_empty() { if relays.is_empty() {
return Vec::new(); return Vec::new();
} }
@@ -223,10 +228,7 @@ async fn probe_relays(
} }
/// Attempt port mapping and return the mapping if successful. /// Attempt port mapping and return the mapping if successful.
async fn probe_portmap( async fn probe_portmap(enabled: bool, local_port: u16) -> Option<portmap::PortMapping> {
enabled: bool,
local_port: u16,
) -> Option<portmap::PortMapping> {
if !enabled || local_port == 0 { if !enabled || local_port == 0 {
return None; return None;
} }
@@ -251,7 +253,9 @@ async fn test_ipv6(enabled: bool, timeout: Duration) -> bool {
let sock = tokio::net::UdpSocket::bind("[::]:0").await.ok()?; let sock = tokio::net::UdpSocket::bind("[::]:0").await.ok()?;
// Try Google's IPv6 STUN — if DNS resolves to an AAAA record // Try Google's IPv6 STUN — if DNS resolves to an AAAA record
// and we can send a packet, IPv6 is working. // and we can send a packet, IPv6 is working.
let addr = stun::resolve_stun_server("stun.l.google.com:19302").await.ok()?; let addr = stun::resolve_stun_server("stun.l.google.com:19302")
.await
.ok()?;
if addr.is_ipv6() { if addr.is_ipv6() {
sock.send_to(&[0u8; 1], addr).await.ok()?; sock.send_to(&[0u8; 1], addr).await.ok()?;
Some(true) Some(true)
@@ -276,10 +280,7 @@ pub fn format_report(report: &NetcheckReport) -> String {
let mut out = String::new(); let mut out = String::new();
out.push_str(&format!("=== WarzonePhone Netcheck ===\n\n")); out.push_str(&format!("=== WarzonePhone Netcheck ===\n\n"));
out.push_str(&format!( out.push_str(&format!("NAT Type: {:?}\n", report.nat_type));
"NAT Type: {:?}\n",
report.nat_type
));
out.push_str(&format!( out.push_str(&format!(
"Reflexive Addr: {}\n", "Reflexive Addr: {}\n",
report.reflexive_addr.as_deref().unwrap_or("(unknown)") report.reflexive_addr.as_deref().unwrap_or("(unknown)")
@@ -298,15 +299,17 @@ pub fn format_report(report: &NetcheckReport) -> String {
)); ));
if let Some(ref alloc) = report.port_allocation { if let Some(ref alloc) = report.port_allocation {
out.push_str(&format!( out.push_str(&format!("Port Alloc: {alloc}\n"));
"Port Alloc: {alloc}\n"
));
} }
out.push_str(&format!("\n--- Port Mapping ---\n")); out.push_str(&format!("\n--- Port Mapping ---\n"));
out.push_str(&format!( out.push_str(&format!(
"NAT-PMP: {} PCP: {} UPnP: {}\n", "NAT-PMP: {} PCP: {} UPnP: {}\n",
if report.nat_pmp_available { "yes" } else { "no" }, if report.nat_pmp_available {
"yes"
} else {
"no"
},
if report.pcp_available { "yes" } else { "no" }, if report.pcp_available { "yes" } else { "no" },
if report.upnp_available { "yes" } else { "no" }, if report.upnp_available { "yes" } else { "no" },
)); ));
@@ -321,8 +324,13 @@ pub fn format_report(report: &NetcheckReport) -> String {
" {}{} ({}ms){}\n", " {}{} ({}ms){}\n",
p.relay_name, p.relay_name,
p.observed_addr.as_deref().unwrap_or("failed"), p.observed_addr.as_deref().unwrap_or("failed"),
p.latency_ms.map(|ms| ms.to_string()).unwrap_or_else(|| "-".into()), p.latency_ms
p.error.as_ref().map(|e| format!(" [{e}]")).unwrap_or_default(), .map(|ms| ms.to_string())
.unwrap_or_else(|| "-".into()),
p.error
.as_ref()
.map(|e| format!(" [{e}]"))
.unwrap_or_default(),
)); ));
} }
} }
@@ -334,8 +342,13 @@ pub fn format_report(report: &NetcheckReport) -> String {
" {} ({}) → {}ms{}\n", " {} ({}) → {}ms{}\n",
r.name, r.name,
r.addr, r.addr,
r.rtt_ms.map(|ms| ms.to_string()).unwrap_or_else(|| "-".into()), r.rtt_ms
r.error.as_ref().map(|e| format!(" [{e}]")).unwrap_or_default(), .map(|ms| ms.to_string())
.unwrap_or_else(|| "-".into()),
r.error
.as_ref()
.map(|e| format!(" [{e}]"))
.unwrap_or_default(),
)); ));
} }
if let Some(ref pref) = report.preferred_relay { if let Some(ref pref) = report.preferred_relay {

View File

@@ -279,8 +279,15 @@ async fn try_natpmp(
// Step 2: request port mapping // Step 2: request port mapping
// Request same port as internal (preferred); 7200s lifetime (standard) // Request same port as internal (preferred); 7200s lifetime (standard)
let (mapped_port, lifetime) = let (mapped_port, lifetime) = natpmp_map_udp(
natpmp_map_udp(&socket, gw_addr, internal_port, internal_port, 7200, timeout).await?; &socket,
gw_addr,
internal_port,
internal_port,
7200,
timeout,
)
.await?;
let lifetime_dur = Duration::from_secs(lifetime as u64); let lifetime_dur = Duration::from_secs(lifetime as u64);
Ok(PortMapping { Ok(PortMapping {
@@ -533,17 +540,12 @@ async fn fetch_url_simple(url: &str, timeout: Duration) -> Result<String, PortMa
.map_err(|e| PortMapError::Protocol(format!("parse {host_port}:80: {e}")))? .map_err(|e| PortMapError::Protocol(format!("parse {host_port}:80: {e}")))?
}; };
let mut stream = tokio::time::timeout( let mut stream = tokio::time::timeout(timeout, tokio::net::TcpStream::connect(addr))
timeout,
tokio::net::TcpStream::connect(addr),
)
.await .await
.map_err(|_| PortMapError::Timeout)? .map_err(|_| PortMapError::Timeout)?
.map_err(|e| PortMapError::Io(e.to_string()))?; .map_err(|e| PortMapError::Io(e.to_string()))?;
let request = format!( let request = format!("GET {path} HTTP/1.1\r\nHost: {host_port}\r\nConnection: close\r\n\r\n");
"GET {path} HTTP/1.1\r\nHost: {host_port}\r\nConnection: close\r\n\r\n"
);
stream stream
.write_all(request.as_bytes()) .write_all(request.as_bytes())
.await .await
@@ -593,10 +595,7 @@ async fn soap_post(
.map_err(|e| PortMapError::Protocol(format!("parse {host_port}:80: {e}")))? .map_err(|e| PortMapError::Protocol(format!("parse {host_port}:80: {e}")))?
}; };
let mut stream = tokio::time::timeout( let mut stream = tokio::time::timeout(timeout, tokio::net::TcpStream::connect(addr))
timeout,
tokio::net::TcpStream::connect(addr),
)
.await .await
.map_err(|_| PortMapError::Timeout)? .map_err(|_| PortMapError::Timeout)?
.map_err(|e| PortMapError::Io(e.to_string()))?; .map_err(|e| PortMapError::Io(e.to_string()))?;
@@ -662,9 +661,7 @@ fn extract_control_url(xml: &str, base_url: &str) -> Result<String, PortMapError
return Ok(control_path.to_string()); return Ok(control_path.to_string());
} }
// Build absolute URL from base // Build absolute URL from base
let base = base_url let base = base_url.strip_prefix("http://").unwrap_or(base_url);
.strip_prefix("http://")
.unwrap_or(base_url);
let host_port = base.split('/').next().unwrap_or(base); let host_port = base.split('/').next().unwrap_or(base);
return Ok(format!("http://{host_port}{control_path}")); return Ok(format!("http://{host_port}{control_path}"));
} }
@@ -681,7 +678,8 @@ async fn upnp_get_external_ip(
control_url: &str, control_url: &str,
timeout: Duration, timeout: Duration,
) -> Result<Ipv4Addr, PortMapError> { ) -> Result<Ipv4Addr, PortMapError> {
let body = "<u:GetExternalIPAddress xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\"/>"; let body =
"<u:GetExternalIPAddress xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\"/>";
let action = "urn:schemas-upnp-org:service:WANIPConnection:1#GetExternalIPAddress"; let action = "urn:schemas-upnp-org:service:WANIPConnection:1#GetExternalIPAddress";
let response = soap_post(control_url, action, body, timeout).await?; let response = soap_post(control_url, action, body, timeout).await?;
@@ -933,7 +931,10 @@ mod tests {
assert_eq!(request[0], 0); assert_eq!(request[0], 0);
assert_eq!(request[1], 1); assert_eq!(request[1], 1);
assert_eq!(u16::from_be_bytes([request[4], request[5]]), 12345); assert_eq!(u16::from_be_bytes([request[4], request[5]]), 12345);
assert_eq!(u32::from_be_bytes([request[8], request[9], request[10], request[11]]), 7200); assert_eq!(
u32::from_be_bytes([request[8], request[9], request[10], request[11]]),
7200
);
} }
#[test] #[test]

View File

@@ -30,8 +30,8 @@ use std::net::SocketAddr;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use serde::Serialize; use serde::Serialize;
use wzp_proto::{MediaTransport, SignalMessage}; use wzp_proto::{MediaTransport, SignalMessage, default_signal_version};
use wzp_transport::{client_config, create_endpoint, QuinnTransport}; use wzp_transport::{QuinnTransport, client_config, create_endpoint};
/// Result of one probe against one relay. Always returned so the /// Result of one probe against one relay. Always returned so the
/// UI can render per-relay status even when some fail. /// UI can render per-relay status even when some fail.
@@ -110,8 +110,7 @@ pub async fn probe_reflect_addr(
let start = Instant::now(); let start = Instant::now();
let probe = async { let probe = async {
// Open the signal connection. // Open the signal connection.
let conn = let conn = wzp_transport::connect(&endpoint, relay, "_signal", client_config())
wzp_transport::connect(&endpoint, relay, "_signal", client_config())
.await .await
.map_err(|e| format!("connect: {e}"))?; .map_err(|e| format!("connect: {e}"))?;
let transport = QuinnTransport::new(conn); let transport = QuinnTransport::new(conn);
@@ -124,6 +123,7 @@ pub async fn probe_reflect_addr(
// path does in desktop/src-tauri/src/lib.rs register_signal. // path does in desktop/src-tauri/src/lib.rs register_signal.
transport transport
.send_signal(&SignalMessage::RegisterPresence { .send_signal(&SignalMessage::RegisterPresence {
version: default_signal_version(),
identity_pub: [0u8; 32], identity_pub: [0u8; 32],
signature: vec![], signature: vec![],
alias: None, alias: None,
@@ -151,7 +151,7 @@ pub async fn probe_reflect_addr(
.map_err(|e| format!("send Reflect: {e}"))?; .map_err(|e| format!("send Reflect: {e}"))?;
match transport.recv_signal().await { match transport.recv_signal().await {
Ok(Some(SignalMessage::ReflectResponse { observed_addr })) => { Ok(Some(SignalMessage::ReflectResponse { observed_addr, .. })) => {
let parsed: SocketAddr = observed_addr let parsed: SocketAddr = observed_addr
.parse() .parse()
.map_err(|e| format!("parse observed_addr {observed_addr:?}: {e}"))?; .map_err(|e| format!("parse observed_addr {observed_addr:?}: {e}"))?;
@@ -540,10 +540,7 @@ mod tests {
#[test] #[test]
fn classify_two_identical_is_cone() { fn classify_two_identical_is_cone() {
let probes = vec![ let probes = vec![mk(Some("192.0.2.1:4433")), mk(Some("192.0.2.1:4433"))];
mk(Some("192.0.2.1:4433")),
mk(Some("192.0.2.1:4433")),
];
let (nt, addr) = classify_nat(&probes); let (nt, addr) = classify_nat(&probes);
assert_eq!(nt, NatType::Cone); assert_eq!(nt, NatType::Cone);
assert_eq!(addr.as_deref(), Some("192.0.2.1:4433")); assert_eq!(addr.as_deref(), Some("192.0.2.1:4433"));
@@ -551,10 +548,7 @@ mod tests {
#[test] #[test]
fn classify_same_ip_different_ports_is_symmetric() { fn classify_same_ip_different_ports_is_symmetric() {
let probes = vec![ let probes = vec![mk(Some("192.0.2.1:4433")), mk(Some("192.0.2.1:51234"))];
mk(Some("192.0.2.1:4433")),
mk(Some("192.0.2.1:51234")),
];
let (nt, addr) = classify_nat(&probes); let (nt, addr) = classify_nat(&probes);
assert_eq!(nt, NatType::SymmetricPort); assert_eq!(nt, NatType::SymmetricPort);
assert!(addr.is_none()); assert!(addr.is_none());
@@ -562,10 +556,7 @@ mod tests {
#[test] #[test]
fn classify_different_ips_is_multiple() { fn classify_different_ips_is_multiple() {
let probes = vec![ let probes = vec![mk(Some("192.0.2.1:4433")), mk(Some("198.51.100.9:4433"))];
mk(Some("192.0.2.1:4433")),
mk(Some("198.51.100.9:4433")),
];
let (nt, addr) = classify_nat(&probes); let (nt, addr) = classify_nat(&probes);
assert_eq!(nt, NatType::Multiple); assert_eq!(nt, NatType::Multiple);
assert!(addr.is_none()); assert!(addr.is_none());

View File

@@ -109,11 +109,9 @@ impl RelayMap {
/// Check if any entry has a stale probe (older than `max_age`). /// Check if any entry has a stale probe (older than `max_age`).
pub fn needs_reprobe(&self, max_age: Duration) -> bool { pub fn needs_reprobe(&self, max_age: Duration) -> bool {
self.entries.iter().any(|e| { self.entries.iter().any(|e| match e.last_probed {
match e.last_probed {
None => true, None => true,
Some(t) => t.elapsed() > max_age, Some(t) => t.elapsed() > max_age,
}
}) })
} }

View File

@@ -223,9 +223,7 @@ pub fn parse_binding_response(
pos = value_end + ((4 - (attr_len % 4)) % 4); pos = value_end + ((4 - (attr_len % 4)) % 4);
} }
xor_mapped xor_mapped.or(mapped).ok_or(StunError::NoMappedAddress)
.or(mapped)
.ok_or(StunError::NoMappedAddress)
} }
/// Parse a MAPPED-ADDRESS attribute value (RFC 5389 §15.1). /// Parse a MAPPED-ADDRESS attribute value (RFC 5389 §15.1).
@@ -279,10 +277,7 @@ fn parse_mapped_address(value: &[u8]) -> Result<SocketAddr, StunError> {
/// - Port: XOR with top 16 bits of magic cookie /// - Port: XOR with top 16 bits of magic cookie
/// - IPv4 address: XOR with magic cookie /// - IPv4 address: XOR with magic cookie
/// - IPv6 address: XOR with magic cookie || transaction ID /// - IPv6 address: XOR with magic cookie || transaction ID
fn parse_xor_mapped_address( fn parse_xor_mapped_address(value: &[u8], txn_id: &[u8; 12]) -> Result<SocketAddr, StunError> {
value: &[u8],
txn_id: &[u8; 12],
) -> Result<SocketAddr, StunError> {
if value.len() < 4 { if value.len() < 4 {
return Err(StunError::Malformed("XOR-MAPPED-ADDRESS too short".into())); return Err(StunError::Malformed("XOR-MAPPED-ADDRESS too short".into()));
} }
@@ -471,9 +466,7 @@ pub async fn discover_reflexive(config: &StunConfig) -> Result<SocketAddr, StunE
/// Unlike `discover_reflexive` (which returns on first success), this /// Unlike `discover_reflexive` (which returns on first success), this
/// waits for ALL servers and returns individual results — needed for /// waits for ALL servers and returns individual results — needed for
/// NAT type classification which requires 2+ observations. /// NAT type classification which requires 2+ observations.
pub async fn probe_stun_servers( pub async fn probe_stun_servers(config: &StunConfig) -> Vec<crate::reflect::NatProbeResult> {
config: &StunConfig,
) -> Vec<crate::reflect::NatProbeResult> {
use std::time::Instant; use std::time::Instant;
let mut set = tokio::task::JoinSet::new(); let mut set = tokio::task::JoinSet::new();
@@ -596,9 +589,7 @@ pub struct PortAllocationResult {
/// - No pattern → `Random` /// - No pattern → `Random`
/// ///
/// Requires at least 3 servers for reliable classification. /// Requires at least 3 servers for reliable classification.
pub async fn detect_port_allocation( pub async fn detect_port_allocation(config: &StunConfig) -> PortAllocationResult {
config: &StunConfig,
) -> PortAllocationResult {
if config.servers.len() < 2 { if config.servers.len() < 2 {
return PortAllocationResult { return PortAllocationResult {
allocation: PortAllocation::Unknown, allocation: PortAllocation::Unknown,
@@ -696,11 +687,15 @@ pub fn classify_port_allocation(ports: &[u16]) -> PortAllocation {
// Allow small jitter: if all deltas are within ±1 of each other, // Allow small jitter: if all deltas are within ±1 of each other,
// consider it sequential with the median delta. // consider it sequential with the median delta.
let all_close = deltas.iter().all(|&d| (d - first_delta).unsigned_abs() <= 1); let all_close = deltas
.iter()
.all(|&d| (d - first_delta).unsigned_abs() <= 1);
if all_close { if all_close {
// Use the most common delta (mode). // Use the most common delta (mode).
let median_delta = first_delta; let median_delta = first_delta;
return PortAllocation::Sequential { delta: median_delta }; return PortAllocation::Sequential {
delta: median_delta,
};
} }
// Check for consistent delta with occasional skip (some NATs // Check for consistent delta with occasional skip (some NATs
@@ -727,12 +722,7 @@ pub fn classify_port_allocation(ports: &[u16]) -> PortAllocation {
/// predicted ports centered around the most likely next value. /// predicted ports centered around the most likely next value.
/// The `offset` parameter accounts for additional flows that may /// The `offset` parameter accounts for additional flows that may
/// open between the probe and the actual connection attempt. /// open between the probe and the actual connection attempt.
pub fn predict_ports( pub fn predict_ports(last_port: u16, delta: i16, offset: u16, spread: u16) -> Vec<u16> {
last_port: u16,
delta: i16,
offset: u16,
spread: u16,
) -> Vec<u16> {
let base = last_port as i32 + (delta as i32 * (offset as i32 + 1)); let base = last_port as i32 + (delta as i32 * (offset as i32 + 1));
let mut ports = Vec::with_capacity((spread * 2 + 1) as usize); let mut ports = Vec::with_capacity((spread * 2 + 1) as usize);
for i in -(spread as i32)..=(spread as i32) { for i in -(spread as i32)..=(spread as i32) {
@@ -1217,7 +1207,11 @@ mod tests {
assert!(StunError::TxnMismatch.to_string().contains("mismatch")); assert!(StunError::TxnMismatch.to_string().contains("mismatch"));
assert!(StunError::NoMappedAddress.to_string().contains("MAPPED")); assert!(StunError::NoMappedAddress.to_string().contains("MAPPED"));
assert!(StunError::Io("test".into()).to_string().contains("test")); assert!(StunError::Io("test".into()).to_string().contains("test"));
assert!(StunError::DnsError("bad".into()).to_string().contains("bad")); assert!(
StunError::DnsError("bad".into())
.to_string()
.contains("bad")
);
assert!(StunError::ErrorResponse(420).to_string().contains("420")); assert!(StunError::ErrorResponse(420).to_string().contains("420"));
assert!(StunError::Malformed("x".into()).to_string().contains("x")); assert!(StunError::Malformed("x".into()).to_string().contains("x"));
} }
@@ -1244,7 +1238,10 @@ mod tests {
#[test] #[test]
fn classify_port_preserving() { fn classify_port_preserving() {
let ports = vec![4433, 4433, 4433, 4433, 4433]; let ports = vec![4433, 4433, 4433, 4433, 4433];
assert_eq!(classify_port_allocation(&ports), PortAllocation::PortPreserving); assert_eq!(
classify_port_allocation(&ports),
PortAllocation::PortPreserving
);
} }
#[test] #[test]
@@ -1290,7 +1287,10 @@ mod tests {
#[test] #[test]
fn classify_two_same_is_preserving() { fn classify_two_same_is_preserving() {
let ports = vec![4433, 4433]; let ports = vec![4433, 4433];
assert_eq!(classify_port_allocation(&ports), PortAllocation::PortPreserving); assert_eq!(
classify_port_allocation(&ports),
PortAllocation::PortPreserving
);
} }
#[test] #[test]
@@ -1359,8 +1359,14 @@ mod tests {
#[test] #[test]
fn port_allocation_display() { fn port_allocation_display() {
assert_eq!(PortAllocation::PortPreserving.to_string(), "port-preserving"); assert_eq!(
assert_eq!(PortAllocation::Sequential { delta: 1 }.to_string(), "sequential(delta=1)"); PortAllocation::PortPreserving.to_string(),
"port-preserving"
);
assert_eq!(
PortAllocation::Sequential { delta: 1 }.to_string(),
"sequential(delta=1)"
);
assert_eq!(PortAllocation::Random.to_string(), "random"); assert_eq!(PortAllocation::Random.to_string(), "random");
assert_eq!(PortAllocation::Unknown.to_string(), "unknown"); assert_eq!(PortAllocation::Unknown.to_string(), "unknown");
} }
@@ -1421,7 +1427,10 @@ mod tests {
let config = StunConfig::default(); let config = StunConfig::default();
let probes = probe_stun_servers(&config).await; let probes = probe_stun_servers(&config).await;
assert!(!probes.is_empty()); assert!(!probes.is_empty());
let successes: Vec<_> = probes.iter().filter(|p| p.observed_addr.is_some()).collect(); let successes: Vec<_> = probes
.iter()
.filter(|p| p.observed_addr.is_some())
.collect();
assert!( assert!(
!successes.is_empty(), !successes.is_empty(),
"at least one STUN server should respond" "at least one STUN server should respond"

View File

@@ -72,8 +72,7 @@ fn sine_frame(freq_hz: f32, frame_offset: u64) -> Vec<i16> {
/// decoder, pushes frames through the pipeline, and collects statistics. /// decoder, pushes frames through the pipeline, and collects statistics.
/// Combinations where `target_depth > max_depth` are skipped. /// Combinations where `target_depth > max_depth` are skipped.
pub fn run_local_sweep(config: &SweepConfig) -> Vec<SweepResult> { pub fn run_local_sweep(config: &SweepConfig) -> Vec<SweepResult> {
let frames_per_config = let frames_per_config = (config.test_duration_secs as u64) * (1000 / FRAME_DURATION_MS as u64);
(config.test_duration_secs as u64) * (1000 / FRAME_DURATION_MS as u64);
let mut results = Vec::new(); let mut results = Vec::new();

View File

@@ -19,7 +19,7 @@
use std::net::{Ipv4Addr, SocketAddr}; use std::net::{Ipv4Addr, SocketAddr};
use std::time::Duration; use std::time::Duration;
use wzp_client::dual_path::{race, PeerCandidates, WinningPath}; use wzp_client::dual_path::{PeerCandidates, WinningPath, race};
use wzp_client::reflect::Role; use wzp_client::reflect::Role;
use wzp_transport::{create_endpoint, server_config}; use wzp_transport::{create_endpoint, server_config};
@@ -125,8 +125,15 @@ async fn dual_path_direct_wins_on_loopback() {
.await .await
.expect("race must succeed"); .expect("race must succeed");
assert!(result.direct_transport.is_some(), "direct transport should be available"); assert!(
assert_eq!(result.local_winner, WinningPath::Direct, "direct should win on loopback"); result.direct_transport.is_some(),
"direct transport should be available"
);
assert_eq!(
result.local_winner,
WinningPath::Direct,
"direct should win on loopback"
);
// Cancel the acceptor accept task so the test finishes. // Cancel the acceptor accept task so the test finishes.
acceptor_accept_task.abort(); acceptor_accept_task.abort();
@@ -170,7 +177,10 @@ async fn dual_path_relay_wins_when_direct_is_dead() {
.await .await
.expect("race must succeed via relay fallback"); .expect("race must succeed via relay fallback");
assert!(result.relay_transport.is_some(), "relay transport should be available"); assert!(
result.relay_transport.is_some(),
"relay transport should be available"
);
assert_eq!( assert_eq!(
result.local_winner, result.local_winner,
WinningPath::Relay, WinningPath::Relay,

View File

@@ -6,12 +6,12 @@
use std::sync::Arc; use std::sync::Arc;
use async_trait::async_trait; use async_trait::async_trait;
use tokio::sync::mpsc;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tokio::sync::mpsc;
use wzp_proto::packet::MediaPacket; use wzp_proto::packet::MediaPacket;
use wzp_proto::traits::{MediaTransport, PathQuality}; use wzp_proto::traits::{MediaTransport, PathQuality};
use wzp_proto::{SignalMessage, TransportError}; use wzp_proto::{SignalMessage, TransportError, default_signal_version};
/// A mock transport backed by two mpsc channels (one per direction). /// A mock transport backed by two mpsc channels (one per direction).
/// ///
@@ -83,11 +83,15 @@ async fn full_handshake_both_sides_derive_same_session() {
// Run client and relay handshakes concurrently. // Run client and relay handshakes concurrently.
let (client_result, relay_result) = tokio::join!( let (client_result, relay_result) = tokio::join!(
wzp_client::handshake::perform_handshake(client_transport_clone.as_ref(), &client_seed, None), wzp_client::handshake::perform_handshake(
client_transport_clone.as_ref(),
&client_seed,
None
),
wzp_relay::handshake::accept_handshake(relay_transport_clone.as_ref(), &relay_seed), wzp_relay::handshake::accept_handshake(relay_transport_clone.as_ref(), &relay_seed),
); );
let mut client_session = client_result.expect("client handshake should succeed"); let client_hs = client_result.expect("client handshake should succeed");
let (mut relay_session, chosen_profile, _caller_fp, _caller_alias) = let (mut relay_session, chosen_profile, _caller_fp, _caller_alias) =
relay_result.expect("relay handshake should succeed"); relay_result.expect("relay handshake should succeed");
@@ -95,31 +99,53 @@ async fn full_handshake_both_sides_derive_same_session() {
assert_eq!(chosen_profile, wzp_proto::QualityProfile::GOOD); assert_eq!(chosen_profile, wzp_proto::QualityProfile::GOOD);
// Verify both sides can communicate: client encrypts, relay decrypts. // Verify both sides can communicate: client encrypts, relay decrypts.
let header = b"test-header"; // encrypt/decrypt derive nonces from MediaHeader.seq, so we need valid headers.
use wzp_proto::packet::MediaHeader;
use wzp_proto::{CodecId, MediaType};
let make_hdr = |seq: u32| {
let h = MediaHeader {
version: 2,
flags: 0,
media_type: MediaType::Audio,
codec_id: CodecId::Opus24k,
stream_id: 0,
fec_ratio: 0,
seq,
timestamp: seq.wrapping_mul(20),
fec_block: 0,
};
let mut b = Vec::new();
h.write_to(&mut b);
b
};
let header = make_hdr(0);
let plaintext = b"hello from client to relay"; let plaintext = b"hello from client to relay";
let mut client_session = client_hs.session;
let mut ciphertext = Vec::new(); let mut ciphertext = Vec::new();
client_session client_session
.encrypt(header, plaintext, &mut ciphertext) .encrypt(&header, plaintext, &mut ciphertext)
.expect("client encrypt should succeed"); .expect("client encrypt should succeed");
let mut decrypted = Vec::new(); let mut decrypted = Vec::new();
relay_session relay_session
.decrypt(header, &ciphertext, &mut decrypted) .decrypt(&header, &ciphertext, &mut decrypted)
.expect("relay decrypt should succeed"); .expect("relay decrypt should succeed");
assert_eq!(&decrypted[..], plaintext); assert_eq!(&decrypted[..], plaintext);
// Verify reverse direction: relay encrypts, client decrypts. // Verify reverse direction: relay encrypts, client decrypts.
let header2 = make_hdr(0); // relay's send_seq starts at 0
let plaintext2 = b"hello from relay to client"; let plaintext2 = b"hello from relay to client";
let mut ciphertext2 = Vec::new(); let mut ciphertext2 = Vec::new();
relay_session relay_session
.encrypt(header, plaintext2, &mut ciphertext2) .encrypt(&header2, plaintext2, &mut ciphertext2)
.expect("relay encrypt should succeed"); .expect("relay encrypt should succeed");
let mut decrypted2 = Vec::new(); let mut decrypted2 = Vec::new();
client_session client_session
.decrypt(header, &ciphertext2, &mut decrypted2) .decrypt(&header2, &ciphertext2, &mut decrypted2)
.expect("client decrypt should succeed"); .expect("client decrypt should succeed");
assert_eq!(&decrypted2[..], plaintext2); assert_eq!(&decrypted2[..], plaintext2);
@@ -147,11 +173,15 @@ async fn handshake_rejects_tampered_signature() {
let bad_signature = kx.sign(b"wrong-data-intentionally"); let bad_signature = kx.sign(b"wrong-data-intentionally");
let offer = SignalMessage::CallOffer { let offer = SignalMessage::CallOffer {
version: default_signal_version(),
identity_pub, identity_pub,
ephemeral_pub, ephemeral_pub,
signature: bad_signature, signature: bad_signature,
supported_profiles: vec![wzp_proto::QualityProfile::GOOD], supported_profiles: vec![wzp_proto::QualityProfile::GOOD],
alias: None, alias: None,
protocol_version: 2,
supported_versions: vec![2],
video_codecs: vec![],
}; };
client_transport_clone client_transport_clone
.send_signal(&offer) .send_signal(&offer)
@@ -175,3 +205,42 @@ async fn handshake_rejects_tampered_signature() {
Ok(_) => panic!("relay should reject tampered signature"), Ok(_) => panic!("relay should reject tampered signature"),
} }
} }
#[tokio::test]
async fn client_receives_protocol_version_mismatch() {
let (client_transport, relay_transport) = MockTransport::pair();
let client_seed = [0xAA_u8; 32];
// Spawn a fake relay that sends ProtocolVersionMismatch.
let relay_clone = Arc::clone(&relay_transport);
tokio::spawn(async move {
// Wait for the client's CallOffer.
let offer = relay_clone.recv_signal().await.unwrap().unwrap();
assert!(matches!(offer, SignalMessage::CallOffer { .. }));
// Respond with ProtocolVersionMismatch.
let mismatch = SignalMessage::Hangup {
version: default_signal_version(),
reason: wzp_proto::HangupReason::ProtocolVersionMismatch {
server_supported: vec![3],
},
call_id: None,
};
relay_clone.send_signal(&mismatch).await.unwrap();
});
let result =
wzp_client::handshake::perform_handshake(client_transport.as_ref(), &client_seed, None)
.await;
match result {
Err(wzp_client::handshake::HandshakeError::ProtocolVersionMismatch {
server_supported,
}) => {
assert_eq!(server_supported, vec![3]);
}
Err(other) => panic!("expected ProtocolVersionMismatch, got: {other:?}"),
Ok(_) => panic!("expected handshake to fail with ProtocolVersionMismatch"),
}
}

View File

@@ -83,8 +83,12 @@ fn long_session_no_drift() {
println!( println!(
"long_session_no_drift: decoded={frames_decoded}/{TOTAL_FRAMES}, \ "long_session_no_drift: decoded={frames_decoded}/{TOTAL_FRAMES}, \
underruns={}, overruns={}, depth={}, max_depth={}, late={}, lost={}", underruns={}, overruns={}, depth={}, max_depth={}, late={}, lost={}",
stats.underruns, stats.overruns, stats.current_depth, stats.max_depth_seen, stats.underruns,
stats.packets_late, stats.packets_lost, stats.overruns,
stats.current_depth,
stats.max_depth_seen,
stats.packets_late,
stats.packets_lost,
); );
// With 1 decode per tick over 3000 ticks, we expect ~3000 decoded frames // With 1 decode per tick over 3000 ticks, we expect ~3000 decoded frames
@@ -123,7 +127,7 @@ fn long_session_with_simulated_loss() {
for (j, pkt) in batch.into_iter().enumerate() { for (j, pkt) in batch.into_iter().enumerate() {
// Drop every 20th *source* (non-repair) packet to simulate ~5% loss. // Drop every 20th *source* (non-repair) packet to simulate ~5% loss.
if !pkt.header.is_repair && i % 20 == 0 && j == 0 { if !pkt.header.is_repair() && i % 20 == 0 && j == 0 {
continue; // drop this packet continue; // drop this packet
} }
decoder.ingest(pkt); decoder.ingest(pkt);
@@ -139,8 +143,12 @@ fn long_session_with_simulated_loss() {
println!( println!(
"long_session_with_simulated_loss: decoded={frames_decoded}/{TOTAL_FRAMES}, \ "long_session_with_simulated_loss: decoded={frames_decoded}/{TOTAL_FRAMES}, \
underruns={}, overruns={}, depth={}, max_depth={}, late={}, lost={}", underruns={}, overruns={}, depth={}, max_depth={}, late={}, lost={}",
stats.underruns, stats.overruns, stats.current_depth, stats.max_depth_seen, stats.underruns,
stats.packets_late, stats.packets_lost, stats.overruns,
stats.current_depth,
stats.max_depth_seen,
stats.packets_late,
stats.packets_lost,
); );
// With 5% artificial loss + FEC recovery + PLC, we should still get >90% decoded. // With 5% artificial loss + FEC recovery + PLC, we should still get >90% decoded.
@@ -150,6 +158,65 @@ fn long_session_with_simulated_loss() {
); );
} }
/// Verify that `MediaHeader::timestamp` continues monotonically across
/// rekey boundaries. Rekey is a crypto-layer operation (key material
/// rotation) and must not reset or interfere with framing state.
///
/// We simulate a 3000-frame session with two conceptual rekeys at frames
/// 1000 and 2000. The encoder's timestamp counter must advance
/// monotonically throughout.
#[test]
fn rekey_timestamp_monotonic() {
let config = test_config();
let mut encoder = CallEncoder::new(&config);
let mut timestamps = Vec::new();
// Phase 1: before first rekey
for i in 0..1000 {
let pcm = sine_frame(i);
let packets = encoder.encode_frame(&pcm).expect("encode");
for pkt in packets {
timestamps.push(pkt.header.timestamp);
}
}
// Phase 2: between first and second rekey
for i in 1000..2000 {
let pcm = sine_frame(i);
let packets = encoder.encode_frame(&pcm).expect("encode");
for pkt in packets {
timestamps.push(pkt.header.timestamp);
}
}
// Phase 3: after second rekey
for i in 2000..3000 {
let pcm = sine_frame(i);
let packets = encoder.encode_frame(&pcm).expect("encode");
for pkt in packets {
timestamps.push(pkt.header.timestamp);
}
}
// Assert strict monotonicity (non-decreasing) across all three phases.
for window in timestamps.windows(2) {
assert!(
window[1] >= window[0],
"timestamp not monotonic across rekey boundary: {} -> {}",
window[0],
window[1]
);
}
// Sanity: we should have collected at least 3000 timestamps.
assert!(
timestamps.len() >= 3000,
"expected >= 3000 timestamps, got {}",
timestamps.len()
);
}
/// Verify that the jitter buffer's decoded-frame count is consistent with its /// Verify that the jitter buffer's decoded-frame count is consistent with its
/// own internal statistics over a long session. /// own internal statistics over a long session.
#[test] #[test]

View File

@@ -114,11 +114,7 @@ impl EchoCanceller {
/// Number of delayed samples available to release. /// Number of delayed samples available to release.
fn delay_available(&self) -> usize { fn delay_available(&self) -> usize {
let buffered = self.delay_write - self.delay_read; let buffered = self.delay_write - self.delay_read;
if buffered > self.delay_samples { buffered.saturating_sub(self.delay_samples)
buffered - self.delay_samples
} else {
0
}
} }
/// Process a near-end (microphone) frame, removing the estimated echo. /// Process a near-end (microphone) frame, removing the estimated echo.
@@ -161,8 +157,8 @@ impl EchoCanceller {
let mut sum_near_sq: f64 = 0.0; let mut sum_near_sq: f64 = 0.0;
let mut sum_err_sq: f64 = 0.0; let mut sum_err_sq: f64 = 0.0;
for i in 0..n { for (i, sample) in nearend.iter_mut().enumerate() {
let near_f = nearend[i] as f32; let near_f = *sample as f32;
// Position of far-end "now" for this near-end sample. // Position of far-end "now" for this near-end sample.
let base = (self.far_pos + fl * ((n / fl) + 2) + i - n) % fl; let base = (self.far_pos + fl * ((n / fl) + 2) + i - n) % fl;
@@ -190,7 +186,7 @@ impl EchoCanceller {
} }
let out = error.clamp(-32768.0, 32767.0); let out = error.clamp(-32768.0, 32767.0);
nearend[i] = out as i16; *sample = out as i16;
sum_near_sq += (near_f as f64).powi(2); sum_near_sq += (near_f as f64).powi(2);
sum_err_sq += (out as f64).powi(2); sum_err_sq += (out as f64).powi(2);
@@ -325,7 +321,10 @@ mod tests {
// Feed 960 samples (= delay amount). No samples released yet. // Feed 960 samples (= delay amount). No samples released yet.
aec.feed_farend(&vec![1i16; 960]); aec.feed_farend(&vec![1i16; 960]);
// far_buf should still be all zeros (nothing released). // far_buf should still be all zeros (nothing released).
assert!(aec.far_buf.iter().all(|&s| s == 0.0), "nothing should be released yet"); assert!(
aec.far_buf.iter().all(|&s| s == 0.0),
"nothing should be released yet"
);
// Feed 480 more. 480 should be released to far_buf. // Feed 480 more. 480 should be released to far_buf.
aec.feed_farend(&vec![2i16; 480]); aec.feed_farend(&vec![2i16; 480]);

View File

@@ -211,9 +211,6 @@ mod tests {
fn agc_gain_db_at_unity() { fn agc_gain_db_at_unity() {
let agc = AutoGainControl::new(); let agc = AutoGainControl::new();
let db = agc.current_gain_db(); let db = agc.current_gain_db();
assert!( assert!(db.abs() < 0.01, "expected ~0 dB at unity gain, got {db}");
db.abs() < 0.01,
"expected ~0 dB at unity gain, got {db}"
);
} }
} }

View File

@@ -45,7 +45,7 @@ impl Codec2Decoder {
/// Number of compressed bytes per frame. /// Number of compressed bytes per frame.
fn bytes_per_frame(&self) -> usize { fn bytes_per_frame(&self) -> usize {
(self.inner.bits_per_frame() + 7) / 8 self.inner.bits_per_frame().div_ceil(8)
} }
} }

View File

@@ -45,7 +45,7 @@ impl Codec2Encoder {
/// Number of compressed bytes per frame. /// Number of compressed bytes per frame.
fn bytes_per_frame(&self) -> usize { fn bytes_per_frame(&self) -> usize {
(self.inner.bits_per_frame() + 7) / 8 self.inner.bits_per_frame().div_ceil(8)
} }
} }

View File

@@ -56,7 +56,7 @@ impl NoiseSupressor {
// f32 → i16 with clamping // f32 → i16 with clamping
for (i, &val) in output.iter().enumerate() { for (i, &val) in output.iter().enumerate() {
let clamped = val.max(-32768.0).min(32767.0); let clamped = val.clamp(-32768.0, 32767.0);
pcm[offset + i] = clamped as i16; pcm[offset + i] = clamped as i16;
} }
} }
@@ -99,7 +99,11 @@ mod tests {
} }
let original_len = pcm.len(); let original_len = pcm.len();
ns.process(&mut pcm); ns.process(&mut pcm);
assert_eq!(pcm.len(), original_len, "output length must match input length"); assert_eq!(
pcm.len(),
original_len,
"output length must match input length"
);
} }
#[test] #[test]

View File

@@ -71,9 +71,8 @@ impl DecoderHandle {
"opus_decoder_create failed: err={error}" "opus_decoder_create failed: err={error}"
))); )));
} }
let inner = NonNull::new(ptr).ok_or_else(|| { let inner = NonNull::new(ptr)
CodecError::DecodeFailed("opus_decoder_create returned null".into()) .ok_or_else(|| CodecError::DecodeFailed("opus_decoder_create returned null".into()))?;
})?;
Ok(Self { inner }) Ok(Self { inner })
} }
@@ -257,11 +256,7 @@ impl DredDecoderHandle {
/// The `dred_end` output is the silence gap at the tail of the DRED /// The `dred_end` output is the silence gap at the tail of the DRED
/// window; we subtract it from the total offset to give callers the /// window; we subtract it from the total offset to give callers the
/// truly usable sample count. /// truly usable sample count.
pub fn parse_into( pub fn parse_into(&mut self, state: &mut DredState, packet: &[u8]) -> Result<i32, CodecError> {
&mut self,
state: &mut DredState,
packet: &[u8],
) -> Result<i32, CodecError> {
if packet.is_empty() { if packet.is_empty() {
state.samples_available = 0; state.samples_available = 0;
return Ok(0); return Ok(0);
@@ -545,7 +540,10 @@ mod tests {
// to our sine wave because we fed a cold decoder only one warmup // to our sine wave because we fed a cold decoder only one warmup
// frame, but it should still produce non-silent speech-like output // frame, but it should still produce non-silent speech-like output
// since the DRED state was parsed from real speech content. // since the DRED state was parsed from real speech content.
let energy: u64 = recon_pcm.iter().map(|&s| (s as i32).unsigned_abs() as u64).sum(); let energy: u64 = recon_pcm
.iter()
.map(|&s| (s as i32).unsigned_abs() as u64)
.sum();
assert!( assert!(
energy > 0, energy > 0,
"reconstructed audio has zero total energy — DRED reconstruction produced silence" "reconstructed audio has zero total energy — DRED reconstruction produced silence"

View File

@@ -53,10 +53,7 @@ pub fn set_dred_verbose_logs(enabled: bool) {
/// The returned encoder accepts 48 kHz mono PCM regardless of the active /// The returned encoder accepts 48 kHz mono PCM regardless of the active
/// codec; resampling is handled internally when Codec2 is selected. /// codec; resampling is handled internally when Codec2 is selected.
pub fn create_encoder(profile: QualityProfile) -> Box<dyn AudioEncoder> { pub fn create_encoder(profile: QualityProfile) -> Box<dyn AudioEncoder> {
Box::new( Box::new(AdaptiveEncoder::new(profile).expect("failed to create adaptive encoder"))
AdaptiveEncoder::new(profile)
.expect("failed to create adaptive encoder"),
)
} }
/// Create an adaptive decoder starting at the given quality profile. /// Create an adaptive decoder starting at the given quality profile.
@@ -64,10 +61,7 @@ pub fn create_encoder(profile: QualityProfile) -> Box<dyn AudioEncoder> {
/// The returned decoder always produces 48 kHz mono PCM; upsampling from /// The returned decoder always produces 48 kHz mono PCM; upsampling from
/// Codec2's native 8 kHz is handled internally. /// Codec2's native 8 kHz is handled internally.
pub fn create_decoder(profile: QualityProfile) -> Box<dyn AudioDecoder> { pub fn create_decoder(profile: QualityProfile) -> Box<dyn AudioDecoder> {
Box::new( Box::new(AdaptiveDecoder::new(profile).expect("failed to create adaptive decoder"))
AdaptiveDecoder::new(profile)
.expect("failed to create adaptive decoder"),
)
} }
#[cfg(test)] #[cfg(test)]
@@ -82,6 +76,10 @@ mod codec2_tests {
fec_ratio: 0.5, fec_ratio: 0.5,
frame_duration_ms: 20, frame_duration_ms: 20,
frames_per_block: 5, frames_per_block: 5,
priority_mode: wzp_proto::PriorityMode::AudioFirst,
video_bitrate_kbps: None,
video_resolution: None,
video_fps: None,
} }
} }
@@ -210,7 +208,10 @@ mod codec2_tests {
let mut pcm_out_c2 = vec![0i16; 1920]; let mut pcm_out_c2 = vec![0i16; 1920];
let samples_c2 = dec.decode(&encoded_c2[..n_c2], &mut pcm_out_c2).unwrap(); let samples_c2 = dec.decode(&encoded_c2[..n_c2], &mut pcm_out_c2).unwrap();
assert_eq!(samples_c2, 1920, "should get 1920 samples at 48kHz after upsample"); assert_eq!(
samples_c2, 1920,
"should get 1920 samples at 48kHz after upsample"
);
// Step 3: Switch back to Opus. // Step 3: Switch back to Opus.
enc.set_profile(QualityProfile::GOOD).unwrap(); enc.set_profile(QualityProfile::GOOD).unwrap();

View File

@@ -85,8 +85,13 @@ pub fn dred_duration_for(codec: CodecId) -> u8 {
// offsets, so the extra window costs only ~1-2 kbps additional overhead // offsets, so the extra window costs only ~1-2 kbps additional overhead
// while buying substantially better burst resilience (up from 500 ms). // while buying substantially better burst resilience (up from 500 ms).
CodecId::Opus6k => 104, CodecId::Opus6k => 104,
// Non-Opus (Codec2 / CN): DRED is N/A. // Non-Opus (Codec2 / CN / video): DRED is N/A.
CodecId::Codec2_1200 | CodecId::Codec2_3200 | CodecId::ComfortNoise => 0, CodecId::Codec2_1200
| CodecId::Codec2_3200
| CodecId::ComfortNoise
| CodecId::H264Baseline
| CodecId::H265Main
| CodecId::Av1Main => 0,
} }
} }
@@ -96,7 +101,7 @@ pub fn dred_duration_for(codec: CodecId) -> u8 {
/// mode; unset or empty leaves DRED enabled. /// mode; unset or empty leaves DRED enabled.
fn read_legacy_fec_env() -> bool { fn read_legacy_fec_env() -> bool {
match std::env::var(LEGACY_FEC_ENV) { match std::env::var(LEGACY_FEC_ENV) {
Ok(v) => !v.is_empty() && v != "0" && v.to_ascii_lowercase() != "false", Ok(v) => !v.is_empty() && v != "0" && !v.eq_ignore_ascii_case("false"),
Err(_) => false, Err(_) => false,
} }
} }
@@ -247,7 +252,7 @@ impl OpusEncoder {
let clamped = if self.legacy_fec_mode { let clamped = if self.legacy_fec_mode {
loss_pct.min(100) loss_pct.min(100)
} else { } else {
loss_pct.max(DRED_LOSS_FLOOR_PCT).min(100) loss_pct.clamp(DRED_LOSS_FLOOR_PCT, 100)
}; };
let _ = self.inner.set_packet_loss(clamped); let _ = self.inner.set_packet_loss(clamped);
} }
@@ -332,7 +337,11 @@ impl AudioEncoder for OpusEncoder {
); );
return; return;
} }
let mode = if enabled { InbandFec::Mode1 } else { InbandFec::Off }; let mode = if enabled {
InbandFec::Mode1
} else {
InbandFec::Off
};
let _ = self.inner.set_inband_fec(mode); let _ = self.inner.set_inband_fec(mode);
} }

View File

@@ -48,7 +48,7 @@ fn build_fir_kernel() -> [f64; FIR_TAPS] {
let fc = CUTOFF_HZ / SAMPLE_RATE; // normalised cutoff (0..0.5) let fc = CUTOFF_HZ / SAMPLE_RATE; // normalised cutoff (0..0.5)
let beta_denom = bessel_i0(KAISER_BETA); let beta_denom = bessel_i0(KAISER_BETA);
for i in 0..FIR_TAPS { for (i, slot) in kernel.iter_mut().enumerate() {
// Sinc // Sinc
let n = i as f64 - m / 2.0; let n = i as f64 - m / 2.0;
let sinc = if n.abs() < 1e-12 { let sinc = if n.abs() < 1e-12 {
@@ -61,7 +61,7 @@ fn build_fir_kernel() -> [f64; FIR_TAPS] {
let t = 2.0 * i as f64 / m - 1.0; // range [-1, 1] let t = 2.0 * i as f64 / m - 1.0; // range [-1, 1]
let kaiser = bessel_i0(KAISER_BETA * (1.0 - t * t).max(0.0).sqrt()) / beta_denom; let kaiser = bessel_i0(KAISER_BETA * (1.0 - t * t).max(0.0).sqrt()) / beta_denom;
kernel[i] = sinc * kaiser; *slot = sinc * kaiser;
} }
// Normalise to unity DC gain. // Normalise to unity DC gain.
@@ -129,8 +129,7 @@ impl Downsampler48to8 {
// Update history: keep the last (FIR_TAPS - 1) samples from work. // Update history: keep the last (FIR_TAPS - 1) samples from work.
if work.len() >= hist_len { if work.len() >= hist_len {
self.history self.history.copy_from_slice(&work[work.len() - hist_len..]);
.copy_from_slice(&work[work.len() - hist_len..]);
} else { } else {
// Input was shorter than history — shift. // Input was shorter than history — shift.
let shift = hist_len - work.len(); let shift = hist_len - work.len();
@@ -181,9 +180,7 @@ impl Upsampler8to48 {
work.extend_from_slice(&self.history); work.extend_from_slice(&self.history);
for &s in input { for &s in input {
work.push(s as f64); work.push(s as f64);
for _ in 1..RATIO { work.resize(work.len() + (RATIO - 1), 0.0f64);
work.push(0.0);
}
} }
let out_len = stuffed_len; let out_len = stuffed_len;
@@ -209,8 +206,7 @@ impl Upsampler8to48 {
// Update history. // Update history.
if work.len() >= hist_len { if work.len() >= hist_len {
self.history self.history.copy_from_slice(&work[work.len() - hist_len..]);
.copy_from_slice(&work[work.len() - hist_len..]);
} else { } else {
let shift = hist_len - work.len(); let shift = hist_len - work.len();
self.history.copy_within(shift.., 0); self.history.copy_within(shift.., 0);

View File

@@ -151,7 +151,10 @@ mod tests {
for _ in 0..4 { for _ in 0..4 {
det.is_silent(&silence); det.is_silent(&silence);
} }
assert!(det.is_silent(&silence), "should be suppressing after hangover"); assert!(
det.is_silent(&silence),
"should be suppressing after hangover"
);
// Speech arrives — should immediately stop suppressing. // Speech arrives — should immediately stop suppressing.
assert!(!det.is_silent(&speech)); assert!(!det.is_silent(&speech));
@@ -165,10 +168,16 @@ mod tests {
cn.generate(&mut pcm); cn.generate(&mut pcm);
// At least some samples should be non-zero. // At least some samples should be non-zero.
assert!(pcm.iter().any(|&s| s != 0), "CN output should not be all zeros"); assert!(
pcm.iter().any(|&s| s != 0),
"CN output should not be all zeros"
);
// All samples should be within [-50, 50]. // All samples should be within [-50, 50].
assert!(pcm.iter().all(|&s| s.abs() <= 50), "CN samples out of range"); assert!(
pcm.iter().all(|&s| s.abs() <= 50),
"CN samples out of range"
);
} }
#[test] #[test]
@@ -179,11 +188,17 @@ mod tests {
// Constant value: RMS of [v, v, v, ...] = |v|. // Constant value: RMS of [v, v, v, ...] = |v|.
let pcm = vec![100i16; 100]; let pcm = vec![100i16; 100];
let rms = SilenceDetector::rms(&pcm); let rms = SilenceDetector::rms(&pcm);
assert!((rms - 100.0).abs() < 0.01, "RMS of constant 100 should be 100, got {rms}"); assert!(
(rms - 100.0).abs() < 0.01,
"RMS of constant 100 should be 100, got {rms}"
);
// Known pattern: [3, 4] → sqrt((9+16)/2) = sqrt(12.5) ≈ 3.5355 // Known pattern: [3, 4] → sqrt((9+16)/2) = sqrt(12.5) ≈ 3.5355
let rms2 = SilenceDetector::rms(&[3, 4]); let rms2 = SilenceDetector::rms(&[3, 4]);
assert!((rms2 - 3.5355).abs() < 0.01, "RMS of [3,4] should be ~3.5355, got {rms2}"); assert!(
(rms2 - 3.5355).abs() < 0.01,
"RMS of [3,4] should be ~3.5355, got {rms2}"
);
// Empty buffer → 0. // Empty buffer → 0.
assert_eq!(SilenceDetector::rms(&[]), 0.0); assert_eq!(SilenceDetector::rms(&[]), 0.0);

View File

@@ -1,21 +1,20 @@
//! Sliding window replay protection. //! Sliding window replay protection.
//! //!
//! Tracks seen sequence numbers using a bitmap. Window size is 1024 packets. //! Tracks seen sequence numbers using a bitmap. Window size is configurable
//! Sequence numbers that are too old (more than WINDOW_SIZE behind the highest //! at construction time. Sequence numbers that are too old (more than
//! seen) are rejected. //! `window_size` behind the highest seen) are rejected.
use wzp_proto::CryptoError; use wzp_proto::CryptoError;
/// Window size in packets.
const WINDOW_SIZE: u16 = 1024;
/// Sliding window anti-replay detector. /// Sliding window anti-replay detector.
/// ///
/// Uses a bitmap to track which sequence numbers have been seen within /// Uses a bitmap to track which sequence numbers have been seen within
/// the current window. Handles u16 wrapping correctly. /// the current window. Handles `u32` wrapping correctly.
pub struct AntiReplayWindow { pub struct AntiReplayWindow {
/// Window size in packets.
window_size: u32,
/// Highest sequence number seen so far. /// Highest sequence number seen so far.
highest: u16, highest: u32,
/// Bitmap of seen packets. Bit i corresponds to (highest - i). /// Bitmap of seen packets. Bit i corresponds to (highest - i).
bitmap: Vec<u64>, bitmap: Vec<u64>,
/// Whether any packet has been received yet. /// Whether any packet has been received yet.
@@ -23,21 +22,26 @@ pub struct AntiReplayWindow {
} }
impl AntiReplayWindow { impl AntiReplayWindow {
/// Number of u64 words needed for the bitmap. /// Create a new anti-replay window with the default size of 1024 packets.
const BITMAP_WORDS: usize = (WINDOW_SIZE as usize + 63) / 64;
/// Create a new anti-replay window.
pub fn new() -> Self { pub fn new() -> Self {
Self::with_window(1024)
}
/// Create a new anti-replay window with a custom size.
pub fn with_window(size: usize) -> Self {
let window_size = size as u32;
let bitmap_words = (size + 63) / 64;
Self { Self {
window_size,
highest: 0, highest: 0,
bitmap: vec![0u64; Self::BITMAP_WORDS], bitmap: vec![0u64; bitmap_words],
initialized: false, initialized: false,
} }
} }
/// Check if a sequence number is valid (not a replay, not too old). /// Check if a sequence number is valid (not a replay, not too old).
/// If valid, marks it as seen. /// If valid, marks it as seen.
pub fn check_and_update(&mut self, seq: u16) -> Result<(), CryptoError> { pub fn check_and_update(&mut self, seq: u32) -> Result<(), CryptoError> {
if !self.initialized { if !self.initialized {
self.initialized = true; self.initialized = true;
self.highest = seq; self.highest = seq;
@@ -52,17 +56,17 @@ impl AntiReplayWindow {
return Err(CryptoError::ReplayDetected { seq }); return Err(CryptoError::ReplayDetected { seq });
} }
if diff < 0x8000 { if diff < 0x8000_0000 {
// seq is ahead of highest (wrapping-aware: diff in [1, 0x7FFF]) // seq is ahead of highest (wrapping-aware: diff in [1, 0x7FFF_FFFF])
let shift = diff as usize; let shift = diff as usize;
self.advance_window(shift); self.advance_window(shift);
self.highest = seq; self.highest = seq;
self.set_bit(0); self.set_bit(0);
Ok(()) Ok(())
} else { } else {
// seq is behind highest (wrapping-aware: diff in [0x8000, 0xFFFF]) // seq is behind highest (wrapping-aware: diff in [0x8000_0000, 0xFFFF_FFFF])
let behind = self.highest.wrapping_sub(seq) as usize; let behind = self.highest.wrapping_sub(seq) as usize;
if behind >= WINDOW_SIZE as usize { if behind >= self.window_size as usize {
return Err(CryptoError::ReplayDetected { seq }); return Err(CryptoError::ReplayDetected { seq });
} }
if self.get_bit(behind) { if self.get_bit(behind) {
@@ -75,7 +79,8 @@ impl AntiReplayWindow {
/// Advance the window by `shift` positions (shift left = new bits at position 0). /// Advance the window by `shift` positions (shift left = new bits at position 0).
fn advance_window(&mut self, shift: usize) { fn advance_window(&mut self, shift: usize) {
if shift >= WINDOW_SIZE as usize { let window_size = self.window_size as usize;
if shift >= window_size {
for word in &mut self.bitmap { for word in &mut self.bitmap {
*word = 0; *word = 0;
} }
@@ -156,7 +161,11 @@ mod tests {
fn sequential_accepted() { fn sequential_accepted() {
let mut w = AntiReplayWindow::new(); let mut w = AntiReplayWindow::new();
for i in 0..200 { for i in 0..200 {
assert!(w.check_and_update(i).is_ok(), "seq {} should be accepted", i); assert!(
w.check_and_update(i).is_ok(),
"seq {} should be accepted",
i
);
} }
} }
@@ -183,11 +192,11 @@ mod tests {
#[test] #[test]
fn wrapping_works() { fn wrapping_works() {
let mut w = AntiReplayWindow::new(); let mut w = AntiReplayWindow::new();
assert!(w.check_and_update(65530).is_ok()); assert!(w.check_and_update(0xFFFF_FFF0).is_ok());
assert!(w.check_and_update(65535).is_ok()); assert!(w.check_and_update(0xFFFF_FFFF).is_ok());
assert!(w.check_and_update(0).is_ok()); // wrapped assert!(w.check_and_update(0).is_ok()); // wrapped
assert!(w.check_and_update(1).is_ok()); assert!(w.check_and_update(1).is_ok());
assert!(w.check_and_update(65535).is_err()); // duplicate assert!(w.check_and_update(0xFFFF_FFFF).is_err()); // duplicate
} }
#[test] #[test]
@@ -201,4 +210,53 @@ mod tests {
// Now 0 is 1024 behind 1024, which is at the boundary limit // Now 0 is 1024 behind 1024, which is at the boundary limit
assert!(w.check_and_update(0).is_err()); // already seen or too old assert!(w.check_and_update(0).is_err()); // already seen or too old
} }
#[test]
fn custom_window_size() {
let mut w = AntiReplayWindow::with_window(64);
for i in 0..64 {
assert!(w.check_and_update(i).is_ok());
}
// seq 0 is now exactly at the boundary (64 behind 64)
assert!(w.check_and_update(0).is_err());
}
#[test]
fn video_burst_200_with_one_reorder() {
let mut w = AntiReplayWindow::with_window(1024);
// Simulate a 200-packet burst
for i in 0..200 {
assert!(
w.check_and_update(i).is_ok(),
"seq {} should be accepted",
i
);
}
// One packet reordered (arrives late)
assert!(w.check_and_update(50).is_err(), "seq 50 is a duplicate");
// But a packet just behind the window should still be ok
assert!(w.check_and_update(199).is_err(), "seq 199 is a duplicate");
// Continue the burst
for i in 200..400 {
assert!(
w.check_and_update(i).is_ok(),
"seq {} should be accepted",
i
);
}
}
#[test]
fn u32_high_range_works() {
let mut w = AntiReplayWindow::with_window(64);
let base = 1000u32;
assert!(w.check_and_update(base).is_ok());
assert!(w.check_and_update(base + 1).is_ok());
// 65 behind highest (base+1) is outside the 64-packet window
assert!(w.check_and_update(base.wrapping_sub(64)).is_err());
// 63 behind is inside
assert!(w.check_and_update(base.wrapping_sub(62)).is_ok());
// base itself is now a duplicate
assert!(w.check_and_update(base).is_err());
}
} }

View File

@@ -9,8 +9,8 @@ use ed25519_dalek::{Signer, SigningKey, Verifier, VerifyingKey};
use hkdf::Hkdf; use hkdf::Hkdf;
use rand::rngs::OsRng; use rand::rngs::OsRng;
use sha2::{Digest, Sha256}; use sha2::{Digest, Sha256};
use x25519_dalek::{PublicKey as X25519PublicKey, StaticSecret};
use wzp_proto::{CryptoError, CryptoSession, KeyExchange}; use wzp_proto::{CryptoError, CryptoSession, KeyExchange};
use x25519_dalek::{PublicKey as X25519PublicKey, StaticSecret};
use crate::session::ChaChaSession; use crate::session::ChaChaSession;
@@ -95,11 +95,10 @@ impl KeyExchange for WarzoneKeyExchange {
&self, &self,
peer_ephemeral_pub: &[u8; 32], peer_ephemeral_pub: &[u8; 32],
) -> Result<Box<dyn CryptoSession>, CryptoError> { ) -> Result<Box<dyn CryptoSession>, CryptoError> {
let secret = self let secret = self.ephemeral_secret.as_ref().ok_or_else(|| {
.ephemeral_secret CryptoError::Internal(
.as_ref() "no ephemeral key generated; call generate_ephemeral first".into(),
.ok_or_else(|| { )
CryptoError::Internal("no ephemeral key generated; call generate_ephemeral first".into())
})?; })?;
let peer_public = X25519PublicKey::from(*peer_ephemeral_pub); let peer_public = X25519PublicKey::from(*peer_ephemeral_pub);
@@ -210,18 +209,34 @@ mod tests {
let mut alice_session = alice.derive_session(&bob_eph_pub).unwrap(); let mut alice_session = alice.derive_session(&bob_eph_pub).unwrap();
let mut bob_session = bob.derive_session(&alice_eph_pub).unwrap(); let mut bob_session = bob.derive_session(&alice_eph_pub).unwrap();
// Verify they can communicate: Alice encrypts, Bob decrypts // Verify they can communicate: Alice encrypts, Bob decrypts.
let header = b"call-header"; // Use a valid v2 MediaHeader — encrypt/decrypt now derive the nonce from
// header.seq and will reject raw byte slices shorter than WIRE_SIZE.
use wzp_proto::{CodecId, MediaHeader, MediaType};
let header = MediaHeader {
version: 2,
flags: 0,
media_type: MediaType::Audio,
codec_id: CodecId::Opus24k,
stream_id: 0,
fec_ratio: 0,
seq: 0,
timestamp: 0,
fec_block: 0,
};
let mut header_bytes = Vec::new();
header.write_to(&mut header_bytes);
let plaintext = b"hello from alice"; let plaintext = b"hello from alice";
let mut ciphertext = Vec::new(); let mut ciphertext = Vec::new();
alice_session alice_session
.encrypt(header, plaintext, &mut ciphertext) .encrypt(&header_bytes, plaintext, &mut ciphertext)
.unwrap(); .unwrap();
let mut decrypted = Vec::new(); let mut decrypted = Vec::new();
bob_session bob_session
.decrypt(header, &ciphertext, &mut decrypted) .decrypt(&header_bytes, &ciphertext, &mut decrypted)
.unwrap(); .unwrap();
assert_eq!(&decrypted, plaintext); assert_eq!(&decrypted, plaintext);

View File

@@ -79,7 +79,9 @@ impl Seed {
/// ///
/// Mirrors: `warzone-protocol::mnemonic::mnemonic_to_seed` /// Mirrors: `warzone-protocol::mnemonic::mnemonic_to_seed`
pub fn from_mnemonic(words: &str) -> Result<Self, String> { pub fn from_mnemonic(words: &str) -> Result<Self, String> {
let mnemonic: bip39::Mnemonic = words.parse().map_err(|e| format!("invalid mnemonic: {e}"))?; let mnemonic: bip39::Mnemonic = words
.parse()
.map_err(|e| format!("invalid mnemonic: {e}"))?;
let entropy = mnemonic.to_entropy(); let entropy = mnemonic.to_entropy();
if entropy.len() != 32 { if entropy.len() != 32 {
return Err(format!("expected 32 bytes entropy, got {}", entropy.len())); return Err(format!("expected 32 bytes entropy, got {}", entropy.len()));

View File

@@ -16,8 +16,8 @@ pub mod session;
pub use anti_replay::AntiReplayWindow; pub use anti_replay::AntiReplayWindow;
pub use handshake::WarzoneKeyExchange; pub use handshake::WarzoneKeyExchange;
pub use identity::{hash_room_name, Fingerprint, IdentityKeyPair, PublicIdentity, Seed}; pub use identity::{Fingerprint, IdentityKeyPair, PublicIdentity, Seed, hash_room_name};
pub use nonce::{build_nonce, Direction}; pub use nonce::{Direction, build_nonce};
pub use rekey::RekeyManager; pub use rekey::RekeyManager;
pub use session::ChaChaSession; pub use session::ChaChaSession;

View File

@@ -36,6 +36,10 @@ impl RekeyManager {
/// ///
/// The old key is zeroized after the new key is derived. /// The old key is zeroized after the new key is derived.
/// Returns the new 32-byte symmetric key. /// Returns the new 32-byte symmetric key.
///
/// NOTE: Rekeying changes **only** the symmetric key material. Sequence
/// numbers and timestamps in the media framing layer (e.g. `MediaHeader`)
/// are untouched — they continue monotonically across the rekey boundary.
pub fn perform_rekey( pub fn perform_rekey(
&mut self, &mut self,
new_peer_pub: &[u8; 32], new_peer_pub: &[u8; 32],

View File

@@ -3,12 +3,15 @@
//! Implements the `CryptoSession` trait for per-call media encryption. //! Implements the `CryptoSession` trait for per-call media encryption.
//! Nonces are derived deterministically from session_id + sequence counter + direction. //! Nonces are derived deterministically from session_id + sequence counter + direction.
use std::collections::HashMap;
use chacha20poly1305::aead::Aead; use chacha20poly1305::aead::Aead;
use chacha20poly1305::{ChaCha20Poly1305, KeyInit, Nonce}; use chacha20poly1305::{ChaCha20Poly1305, KeyInit, Nonce};
use x25519_dalek::{PublicKey, StaticSecret};
use rand::rngs::OsRng; use rand::rngs::OsRng;
use wzp_proto::{CryptoError, CryptoSession}; use wzp_proto::{CryptoError, CryptoSession, MediaHeader, MediaType};
use x25519_dalek::{PublicKey, StaticSecret};
use crate::anti_replay::AntiReplayWindow;
use crate::nonce::{self, Direction}; use crate::nonce::{self, Direction};
use crate::rekey::RekeyManager; use crate::rekey::RekeyManager;
@@ -28,6 +31,10 @@ pub struct ChaChaSession {
pending_rekey_secret: Option<StaticSecret>, pending_rekey_secret: Option<StaticSecret>,
/// Short Authentication String (4-digit code for verbal verification). /// Short Authentication String (4-digit code for verbal verification).
sas_code: Option<u32>, sas_code: Option<u32>,
/// Per-stream anti-replay windows, keyed by (stream_id, media_type).
anti_replay: HashMap<(u8, MediaType), AntiReplayWindow>,
/// Last timestamp seen in encrypt() — used to assert monotonicity across rekeys.
last_encrypt_timestamp: Option<u32>,
} }
impl ChaChaSession { impl ChaChaSession {
@@ -49,6 +56,8 @@ impl ChaChaSession {
rekey_mgr: RekeyManager::new(shared_secret), rekey_mgr: RekeyManager::new(shared_secret),
pending_rekey_secret: None, pending_rekey_secret: None,
sas_code: None, sas_code: None,
anti_replay: HashMap::new(),
last_encrypt_timestamp: None,
} }
} }
@@ -67,6 +76,27 @@ impl ChaChaSession {
} }
} }
/// Parse a v2 `MediaHeader` from raw bytes.
/// Returns `None` if the buffer is too short or not a valid v2 header.
fn parse_header(header_bytes: &[u8]) -> Option<MediaHeader> {
if header_bytes.len() < MediaHeader::WIRE_SIZE {
return None;
}
let mut cursor = std::io::Cursor::new(header_bytes);
MediaHeader::read_from(&mut cursor)
}
/// Return the default anti-replay window size for a given media type.
fn default_window_for_media_type(media_type: MediaType) -> AntiReplayWindow {
let size = match media_type {
MediaType::Audio => 64,
MediaType::Video => 1024,
MediaType::Data => 256,
MediaType::Control => 32,
};
AntiReplayWindow::with_window(size)
}
impl CryptoSession for ChaChaSession { impl CryptoSession for ChaChaSession {
fn encrypt( fn encrypt(
&mut self, &mut self,
@@ -74,10 +104,14 @@ impl CryptoSession for ChaChaSession {
plaintext: &[u8], plaintext: &[u8],
out: &mut Vec<u8>, out: &mut Vec<u8>,
) -> Result<(), CryptoError> { ) -> Result<(), CryptoError> {
let nonce_bytes = nonce::build_nonce(&self.session_id, self.send_seq, Direction::Send); // Derive nonce from the wire-level seq in the header, not from an
// internal counter. This ensures the receiver can reconstruct the
// same nonce using the header it receives, regardless of delivery order.
let header = parse_header(header_bytes)
.ok_or_else(|| CryptoError::Internal("header too short to derive nonce".into()))?;
let nonce_bytes = nonce::build_nonce(&self.session_id, header.seq, Direction::Send);
let nonce = Nonce::from_slice(&nonce_bytes); let nonce = Nonce::from_slice(&nonce_bytes);
// Encrypt with AAD
use chacha20poly1305::aead::Payload; use chacha20poly1305::aead::Payload;
let payload = Payload { let payload = Payload {
msg: plaintext, msg: plaintext,
@@ -90,7 +124,19 @@ impl CryptoSession for ChaChaSession {
.map_err(|_| CryptoError::Internal("encryption failed".into()))?; .map_err(|_| CryptoError::Internal("encryption failed".into()))?;
out.extend_from_slice(&ciphertext); out.extend_from_slice(&ciphertext);
self.send_seq = self.send_seq.wrapping_add(1); self.send_seq = self.send_seq.wrapping_add(1); // packet counter for rekey trigger only
// M5: assert timestamp_ms is non-decreasing across calls (including post-rekey).
// Timestamps are u32 and wrap at 2^32 ms (~49 days); allow wrapping.
debug_assert!(
self.last_encrypt_timestamp
.map_or(true, |last| header.timestamp.wrapping_sub(last) < u32::MAX / 2),
"encrypt: timestamp must not decrease (last={:?}, now={})",
self.last_encrypt_timestamp,
header.timestamp,
);
self.last_encrypt_timestamp = Some(header.timestamp);
Ok(()) Ok(())
} }
@@ -100,9 +146,14 @@ impl CryptoSession for ChaChaSession {
ciphertext: &[u8], ciphertext: &[u8],
out: &mut Vec<u8>, out: &mut Vec<u8>,
) -> Result<(), CryptoError> { ) -> Result<(), CryptoError> {
// Use Direction::Send to match the sender's nonce construction. // Parse header before decryption — needed for nonce derivation.
// The recv_seq counter tracks which packet from the peer we're decrypting. // Using header.seq (not recv_seq) means the nonce is always derived
let nonce_bytes = nonce::build_nonce(&self.session_id, self.recv_seq, Direction::Send); // from the same wire field as the sender, surviving out-of-order delivery.
// A recv_seq counter diverges from the sender's send_seq on any reorder,
// causing every subsequent decryption to fail for the rest of the session.
let header = parse_header(header_bytes)
.ok_or_else(|| CryptoError::Internal("header too short to derive nonce".into()))?;
let nonce_bytes = nonce::build_nonce(&self.session_id, header.seq, Direction::Send);
let nonce = Nonce::from_slice(&nonce_bytes); let nonce = Nonce::from_slice(&nonce_bytes);
use chacha20poly1305::aead::Payload; use chacha20poly1305::aead::Payload;
@@ -116,8 +167,21 @@ impl CryptoSession for ChaChaSession {
.decrypt(nonce, payload) .decrypt(nonce, payload)
.map_err(|_| CryptoError::DecryptionFailed)?; .map_err(|_| CryptoError::DecryptionFailed)?;
let plaintext_len = plaintext.len();
out.extend_from_slice(&plaintext); out.extend_from_slice(&plaintext);
self.recv_seq = self.recv_seq.wrapping_add(1); self.recv_seq = self.recv_seq.wrapping_add(1); // packet counter for rekey trigger only
// Anti-replay check: header already parsed above.
let window = self
.anti_replay
.entry((header.stream_id, header.media_type))
.or_insert_with(|| default_window_for_media_type(header.media_type));
if let Err(e) = window.check_and_update(header.seq) {
// Roll back the plaintext we just appended.
out.truncate(out.len() - plaintext_len);
return Err(e);
}
Ok(()) Ok(())
} }
@@ -135,10 +199,14 @@ impl CryptoSession for ChaChaSession {
.ok_or_else(|| CryptoError::RekeyFailed("no pending rekey".into()))?; .ok_or_else(|| CryptoError::RekeyFailed("no pending rekey".into()))?;
let total_packets = self.send_seq as u64 + self.recv_seq as u64; let total_packets = self.send_seq as u64 + self.recv_seq as u64;
let new_key = self.rekey_mgr.perform_rekey(peer_ephemeral_pub, secret, total_packets); let new_key = self
.rekey_mgr
.perform_rekey(peer_ephemeral_pub, secret, total_packets);
self.install_key(new_key); self.install_key(new_key);
// Reset sequence counters after rekey for nonce uniqueness // Reset sequence counters after rekey for nonce uniqueness.
// last_encrypt_timestamp is intentionally NOT reset — spec requires
// timestamp_ms to be monotonic across rekeys.
self.send_seq = 0; self.send_seq = 0;
self.recv_seq = 0; self.recv_seq = 0;
@@ -153,24 +221,42 @@ impl CryptoSession for ChaChaSession {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use wzp_proto::{CodecId, MediaType};
fn make_session_pair() -> (ChaChaSession, ChaChaSession) { fn make_session_pair() -> (ChaChaSession, ChaChaSession) {
let key = [0x42u8; 32]; let key = [0x42u8; 32];
(ChaChaSession::new(key), ChaChaSession::new(key)) (ChaChaSession::new(key), ChaChaSession::new(key))
} }
/// Build a minimal valid v2 MediaHeader serialised to bytes.
fn make_header_bytes(seq: u32) -> Vec<u8> {
let header = MediaHeader {
version: 2,
flags: 0,
media_type: MediaType::Audio,
codec_id: CodecId::Opus24k,
stream_id: 0,
fec_ratio: 0,
seq,
timestamp: seq.wrapping_mul(20),
fec_block: 0,
};
let mut bytes = Vec::new();
header.write_to(&mut bytes);
bytes
}
#[test] #[test]
fn encrypt_decrypt_roundtrip() { fn encrypt_decrypt_roundtrip() {
let (mut alice, mut bob) = make_session_pair(); let (mut alice, mut bob) = make_session_pair();
let header = b"test-header"; let header = make_header_bytes(0);
let plaintext = b"hello warzone"; let plaintext = b"hello warzone";
let mut ciphertext = Vec::new(); let mut ciphertext = Vec::new();
alice.encrypt(header, plaintext, &mut ciphertext).unwrap(); alice.encrypt(&header, plaintext, &mut ciphertext).unwrap();
// Bob decrypts (his recv matches Alice's send)
let mut decrypted = Vec::new(); let mut decrypted = Vec::new();
bob.decrypt(header, &ciphertext, &mut decrypted).unwrap(); bob.decrypt(&header, &ciphertext, &mut decrypted).unwrap();
assert_eq!(&decrypted, plaintext); assert_eq!(&decrypted, plaintext);
} }
@@ -178,14 +264,18 @@ mod tests {
#[test] #[test]
fn decrypt_wrong_aad_fails() { fn decrypt_wrong_aad_fails() {
let (mut alice, mut bob) = make_session_pair(); let (mut alice, mut bob) = make_session_pair();
let header = b"correct-header"; let correct_header = make_header_bytes(0);
// Different seq → different nonce AND different AAD bytes: decryption must fail.
let wrong_header = make_header_bytes(1);
let plaintext = b"secret data"; let plaintext = b"secret data";
let mut ciphertext = Vec::new(); let mut ciphertext = Vec::new();
alice.encrypt(header, plaintext, &mut ciphertext).unwrap(); alice
.encrypt(&correct_header, plaintext, &mut ciphertext)
.unwrap();
let mut decrypted = Vec::new(); let mut decrypted = Vec::new();
let result = bob.decrypt(b"wrong-header", &ciphertext, &mut decrypted); let result = bob.decrypt(&wrong_header, &ciphertext, &mut decrypted);
assert!(result.is_err()); assert!(result.is_err());
} }
@@ -194,29 +284,29 @@ mod tests {
let mut alice = ChaChaSession::new([0xAA; 32]); let mut alice = ChaChaSession::new([0xAA; 32]);
let mut eve = ChaChaSession::new([0xBB; 32]); let mut eve = ChaChaSession::new([0xBB; 32]);
let header = b"hdr"; let header = make_header_bytes(0);
let plaintext = b"secret"; let plaintext = b"secret";
let mut ciphertext = Vec::new(); let mut ciphertext = Vec::new();
alice.encrypt(header, plaintext, &mut ciphertext).unwrap(); alice.encrypt(&header, plaintext, &mut ciphertext).unwrap();
let mut decrypted = Vec::new(); let mut decrypted = Vec::new();
let result = eve.decrypt(header, &ciphertext, &mut decrypted); let result = eve.decrypt(&header, &ciphertext, &mut decrypted);
assert!(result.is_err()); assert!(result.is_err());
} }
#[test] #[test]
fn multiple_packets_roundtrip() { fn multiple_packets_roundtrip() {
let (mut alice, mut bob) = make_session_pair(); let (mut alice, mut bob) = make_session_pair();
let header = b"hdr";
for i in 0..100 { for i in 0..100u32 {
let header = make_header_bytes(i);
let msg = format!("message {}", i); let msg = format!("message {}", i);
let mut ct = Vec::new(); let mut ct = Vec::new();
alice.encrypt(header, msg.as_bytes(), &mut ct).unwrap(); alice.encrypt(&header, msg.as_bytes(), &mut ct).unwrap();
let mut pt = Vec::new(); let mut pt = Vec::new();
bob.decrypt(header, &ct, &mut pt).unwrap(); bob.decrypt(&header, &ct, &mut pt).unwrap();
assert_eq!(pt, msg.as_bytes()); assert_eq!(pt, msg.as_bytes());
} }
} }
@@ -235,4 +325,140 @@ mod tests {
// Session is now rekeyed - counters reset // Session is now rekeyed - counters reset
assert_eq!(alice.send_seq, 0); assert_eq!(alice.send_seq, 0);
} }
#[test]
fn decrypt_survives_out_of_order_delivery() {
// Regression test for nonce derivation using recv_seq instead of
// MediaHeader.seq. If nonces are tied to a local counter, any reorder
// causes the counter to diverge from the sender's seq and every
// subsequent packet fails decryption permanently.
use wzp_proto::{CodecId, MediaType};
let key = [0x55u8; 32];
let mut alice = ChaChaSession::new(key);
let mut bob = ChaChaSession::new(key);
let plaintext = b"audio payload";
// Encrypt 5 packets in order (seqs 10, 11, 12, 13, 14).
let seqs = [10u32, 11, 12, 13, 14];
let mut ciphertexts: Vec<(Vec<u8>, Vec<u8>)> = Vec::new();
for &seq in &seqs {
let header = MediaHeader {
version: 2,
flags: 0,
media_type: MediaType::Audio,
codec_id: CodecId::Opus24k,
stream_id: 0,
fec_ratio: 0,
seq,
timestamp: seq * 20,
fec_block: 0,
};
let mut header_bytes = Vec::new();
header.write_to(&mut header_bytes);
let mut ct = Vec::new();
alice.encrypt(&header_bytes, plaintext, &mut ct).unwrap();
ciphertexts.push((header_bytes, ct));
}
// Bob receives them out of order: 0, 2, 1, 4, 3
let delivery_order = [0usize, 2, 1, 4, 3];
for &idx in &delivery_order {
let (ref hdr, ref ct) = ciphertexts[idx];
let mut pt = Vec::new();
let result = bob.decrypt(hdr, ct, &mut pt);
assert!(
result.is_ok(),
"out-of-order packet (original idx={idx}, seq={}) must decrypt successfully",
seqs[idx]
);
assert_eq!(&pt, plaintext);
}
}
#[test]
fn per_stream_anti_replay_rejects_duplicate() {
use wzp_proto::{CodecId, MediaType};
let (mut alice, mut bob) = make_session_pair();
let header = MediaHeader {
version: 2,
flags: 0,
media_type: MediaType::Audio,
codec_id: CodecId::Opus24k,
stream_id: 0,
fec_ratio: 10,
seq: 42,
timestamp: 1000,
fec_block: 0,
};
let mut header_bytes = Vec::new();
header.write_to(&mut header_bytes);
let plaintext = b"audio frame";
// First packet decrypts successfully
let mut ct = Vec::new();
alice.encrypt(&header_bytes, plaintext, &mut ct).unwrap();
let mut pt = Vec::new();
bob.decrypt(&header_bytes, &ct, &mut pt).unwrap();
assert_eq!(&pt, plaintext);
// Exact duplicate is rejected by anti-replay
let mut pt2 = Vec::new();
let result = bob.decrypt(&header_bytes, &ct, &mut pt2);
assert!(
result.is_err(),
"duplicate packet with same seq must be rejected"
);
assert!(pt2.is_empty(), "plaintext must be rolled back on replay");
}
#[test]
fn per_stream_anti_replay_video_burst_200_with_reorder() {
use wzp_proto::{CodecId, MediaType};
let (mut alice, mut bob) = make_session_pair();
let header = MediaHeader {
version: 2,
flags: 0,
media_type: MediaType::Video,
codec_id: CodecId::Opus24k,
stream_id: 1,
fec_ratio: 10,
seq: 0,
timestamp: 0,
fec_block: 0,
};
let plaintext = b"video frame";
// Send 200 packets in order
for i in 0..200 {
let mut h = header;
h.seq = i;
let mut header_bytes = Vec::new();
h.write_to(&mut header_bytes);
let mut ct = Vec::new();
alice.encrypt(&header_bytes, plaintext, &mut ct).unwrap();
let mut pt = Vec::new();
bob.decrypt(&header_bytes, &ct, &mut pt).unwrap();
}
// Re-send packet 50 — should be rejected as replay
let mut h = header;
h.seq = 50;
let mut header_bytes = Vec::new();
h.write_to(&mut header_bytes);
let mut ct = Vec::new();
alice.encrypt(&header_bytes, plaintext, &mut ct).unwrap();
let mut pt = Vec::new();
let result = bob.decrypt(&header_bytes, &ct, &mut pt);
assert!(result.is_err(), "reordered duplicate must be rejected");
}
} }

View File

@@ -6,7 +6,7 @@
//! 3. Auth: WZP auth module request/response matches FC's /v1/auth/validate contract //! 3. Auth: WZP auth module request/response matches FC's /v1/auth/validate contract
//! 4. Mnemonic: BIP39 interop between both implementations //! 4. Mnemonic: BIP39 interop between both implementations
use wzp_proto::KeyExchange; use wzp_proto::{KeyExchange, default_signal_version};
// ─── Identity Compatibility (WZP-FC-8) ────────────────────────────────────── // ─── Identity Compatibility (WZP-FC-8) ──────────────────────────────────────
@@ -52,7 +52,10 @@ fn wzp_identity_module_matches_featherchat() {
assert_eq!(wzp_pub.signing.as_bytes(), fc_pub.signing.as_bytes()); assert_eq!(wzp_pub.signing.as_bytes(), fc_pub.signing.as_bytes());
assert_eq!(wzp_pub.encryption.as_bytes(), fc_pub.encryption.as_bytes()); assert_eq!(wzp_pub.encryption.as_bytes(), fc_pub.encryption.as_bytes());
assert_eq!(wzp_pub.fingerprint.0, fc_pub.fingerprint.0); assert_eq!(wzp_pub.fingerprint.0, fc_pub.fingerprint.0);
assert_eq!(wzp_pub.fingerprint.to_string(), fc_pub.fingerprint.to_string()); assert_eq!(
wzp_pub.fingerprint.to_string(),
fc_pub.fingerprint.to_string()
);
} }
#[test] #[test]
@@ -111,11 +114,15 @@ fn mnemonic_strings_identical() {
fn wzp_signal_serializes_into_fc_callsignal_payload() { fn wzp_signal_serializes_into_fc_callsignal_payload() {
// WZP creates a CallOffer SignalMessage // WZP creates a CallOffer SignalMessage
let offer = wzp_proto::SignalMessage::CallOffer { let offer = wzp_proto::SignalMessage::CallOffer {
version: default_signal_version(),
identity_pub: [1u8; 32], identity_pub: [1u8; 32],
ephemeral_pub: [2u8; 32], ephemeral_pub: [2u8; 32],
signature: vec![3u8; 64], signature: vec![3u8; 64],
supported_profiles: vec![wzp_proto::QualityProfile::GOOD], supported_profiles: vec![wzp_proto::QualityProfile::GOOD],
alias: None, alias: None,
protocol_version: 2,
supported_versions: vec![2],
video_codecs: vec![],
}; };
// Encode as featherChat CallSignal payload // Encode as featherChat CallSignal payload
@@ -148,16 +155,25 @@ fn wzp_signal_serializes_into_fc_callsignal_payload() {
// And deserializes back // And deserializes back
let decoded: warzone_protocol::message::WireMessage = bincode::deserialize(&encoded).unwrap(); let decoded: warzone_protocol::message::WireMessage = bincode::deserialize(&encoded).unwrap();
if let warzone_protocol::message::WireMessage::CallSignal { if let warzone_protocol::message::WireMessage::CallSignal {
id, payload: p, signal_type, .. id,
payload: p,
signal_type,
..
} = decoded } = decoded
{ {
assert_eq!(id, "call-123"); assert_eq!(id, "call-123");
assert!(matches!(signal_type, warzone_protocol::message::CallSignalType::Offer)); assert!(matches!(
signal_type,
warzone_protocol::message::CallSignalType::Offer
));
// Decode the WZP payload back // Decode the WZP payload back
let wzp_payload = wzp_client::featherchat::decode_call_payload(&p).unwrap(); let wzp_payload = wzp_client::featherchat::decode_call_payload(&p).unwrap();
assert_eq!(wzp_payload.relay_addr.unwrap(), "relay.example.com:4433"); assert_eq!(wzp_payload.relay_addr.unwrap(), "relay.example.com:4433");
assert!(matches!(wzp_payload.signal, wzp_proto::SignalMessage::CallOffer { .. })); assert!(matches!(
wzp_payload.signal,
wzp_proto::SignalMessage::CallOffer { .. }
));
} else { } else {
panic!("expected CallSignal"); panic!("expected CallSignal");
} }
@@ -166,10 +182,12 @@ fn wzp_signal_serializes_into_fc_callsignal_payload() {
#[test] #[test]
fn wzp_answer_round_trips_through_fc_callsignal() { fn wzp_answer_round_trips_through_fc_callsignal() {
let answer = wzp_proto::SignalMessage::CallAnswer { let answer = wzp_proto::SignalMessage::CallAnswer {
version: default_signal_version(),
identity_pub: [10u8; 32], identity_pub: [10u8; 32],
ephemeral_pub: [20u8; 32], ephemeral_pub: [20u8; 32],
signature: vec![30u8; 64], signature: vec![30u8; 64],
chosen_profile: wzp_proto::QualityProfile::DEGRADED, chosen_profile: wzp_proto::QualityProfile::DEGRADED,
video_codec: None,
}; };
let payload = wzp_client::featherchat::encode_call_payload(&answer, None, None); let payload = wzp_client::featherchat::encode_call_payload(&answer, None, None);
@@ -198,13 +216,17 @@ fn wzp_answer_round_trips_through_fc_callsignal() {
#[test] #[test]
fn wzp_hangup_round_trips_through_fc_callsignal() { fn wzp_hangup_round_trips_through_fc_callsignal() {
let hangup = wzp_proto::SignalMessage::Hangup { let hangup = wzp_proto::SignalMessage::Hangup {
version: default_signal_version(),
reason: wzp_proto::HangupReason::Normal, reason: wzp_proto::HangupReason::Normal,
call_id: None, call_id: None,
}; };
let payload = wzp_client::featherchat::encode_call_payload(&hangup, None, None); let payload = wzp_client::featherchat::encode_call_payload(&hangup, None, None);
let signal_type = wzp_client::featherchat::signal_to_call_type(&hangup); let signal_type = wzp_client::featherchat::signal_to_call_type(&hangup);
assert!(matches!(signal_type, wzp_client::featherchat::CallSignalType::Hangup)); assert!(matches!(
signal_type,
wzp_client::featherchat::CallSignalType::Hangup
));
let fc_msg = warzone_protocol::message::WireMessage::CallSignal { let fc_msg = warzone_protocol::message::WireMessage::CallSignal {
id: "call-789".to_string(), id: "call-789".to_string(),
@@ -219,7 +241,10 @@ fn wzp_hangup_round_trips_through_fc_callsignal() {
if let warzone_protocol::message::WireMessage::CallSignal { payload, .. } = decoded { if let warzone_protocol::message::WireMessage::CallSignal { payload, .. } = decoded {
let wzp = wzp_client::featherchat::decode_call_payload(&payload).unwrap(); let wzp = wzp_client::featherchat::decode_call_payload(&payload).unwrap();
assert!(matches!(wzp.signal, wzp_proto::SignalMessage::Hangup { .. })); assert!(matches!(
wzp.signal,
wzp_proto::SignalMessage::Hangup { .. }
));
} }
} }
@@ -252,8 +277,7 @@ fn auth_validate_response_matches_wzp_expectations() {
"eth_address": null "eth_address": null
}); });
let wzp_resp: wzp_relay::auth::ValidateResponse = let wzp_resp: wzp_relay::auth::ValidateResponse = serde_json::from_value(fc_response).unwrap();
serde_json::from_value(fc_response).unwrap();
assert!(wzp_resp.valid); assert!(wzp_resp.valid);
assert_eq!( assert_eq!(
wzp_resp.fingerprint.unwrap(), wzp_resp.fingerprint.unwrap(),
@@ -265,8 +289,7 @@ fn auth_validate_response_matches_wzp_expectations() {
#[test] #[test]
fn auth_invalid_response_matches() { fn auth_invalid_response_matches() {
let fc_response = serde_json::json!({ "valid": false }); let fc_response = serde_json::json!({ "valid": false });
let wzp_resp: wzp_relay::auth::ValidateResponse = let wzp_resp: wzp_relay::auth::ValidateResponse = serde_json::from_value(fc_response).unwrap();
serde_json::from_value(fc_response).unwrap();
assert!(!wzp_resp.valid); assert!(!wzp_resp.valid);
assert!(wzp_resp.fingerprint.is_none()); assert!(wzp_resp.fingerprint.is_none());
} }
@@ -280,28 +303,39 @@ fn all_signal_types_map_correctly() {
let cases: Vec<(wzp_proto::SignalMessage, &str)> = vec![ let cases: Vec<(wzp_proto::SignalMessage, &str)> = vec![
( (
wzp_proto::SignalMessage::CallOffer { wzp_proto::SignalMessage::CallOffer {
identity_pub: [0; 32], ephemeral_pub: [0; 32], version: default_signal_version(),
signature: vec![], supported_profiles: vec![], identity_pub: [0; 32],
ephemeral_pub: [0; 32],
signature: vec![],
supported_profiles: vec![],
alias: None, alias: None,
protocol_version: 2,
supported_versions: vec![2],
video_codecs: vec![],
}, },
"Offer", "Offer",
), ),
( (
wzp_proto::SignalMessage::CallAnswer { wzp_proto::SignalMessage::CallAnswer {
identity_pub: [0; 32], ephemeral_pub: [0; 32], version: default_signal_version(),
identity_pub: [0; 32],
ephemeral_pub: [0; 32],
signature: vec![], signature: vec![],
chosen_profile: wzp_proto::QualityProfile::GOOD, chosen_profile: wzp_proto::QualityProfile::GOOD,
video_codec: None,
}, },
"Answer", "Answer",
), ),
( (
wzp_proto::SignalMessage::IceCandidate { wzp_proto::SignalMessage::IceCandidate {
version: default_signal_version(),
candidate: "candidate:1".to_string(), candidate: "candidate:1".to_string(),
}, },
"IceCandidate", "IceCandidate",
), ),
( (
wzp_proto::SignalMessage::Hangup { wzp_proto::SignalMessage::Hangup {
version: default_signal_version(),
reason: wzp_proto::HangupReason::Normal, reason: wzp_proto::HangupReason::Normal,
call_id: None, call_id: None,
}, },
@@ -312,7 +346,10 @@ fn all_signal_types_map_correctly() {
for (signal, expected_name) in cases { for (signal, expected_name) in cases {
let ct = signal_to_call_type(&signal); let ct = signal_to_call_type(&signal);
let name = format!("{ct:?}"); let name = format!("{ct:?}");
assert_eq!(name, expected_name, "signal type mapping for {expected_name}"); assert_eq!(
name, expected_name,
"signal type mapping for {expected_name}"
);
} }
} }
@@ -426,8 +463,7 @@ fn auth_response_with_eth_address() {
"alias": "vitalik", "alias": "vitalik",
"eth_address": "0x1234567890abcdef1234567890abcdef12345678" "eth_address": "0x1234567890abcdef1234567890abcdef12345678"
}); });
let resp: wzp_relay::auth::ValidateResponse = let resp: wzp_relay::auth::ValidateResponse = serde_json::from_value(with_eth).unwrap();
serde_json::from_value(with_eth).unwrap();
assert!(resp.valid); assert!(resp.valid);
assert_eq!( assert_eq!(
resp.fingerprint.unwrap(), resp.fingerprint.unwrap(),
@@ -442,8 +478,7 @@ fn auth_response_with_eth_address() {
"alias": "anon", "alias": "anon",
"eth_address": null "eth_address": null
}); });
let resp2: wzp_relay::auth::ValidateResponse = let resp2: wzp_relay::auth::ValidateResponse = serde_json::from_value(with_null_eth).unwrap();
serde_json::from_value(with_null_eth).unwrap();
assert!(resp2.valid); assert!(resp2.valid);
assert_eq!( assert_eq!(
resp2.fingerprint.unwrap(), resp2.fingerprint.unwrap(),
@@ -454,15 +489,15 @@ fn auth_response_with_eth_address() {
let without_eth = serde_json::json!({ let without_eth = serde_json::json!({
"valid": false "valid": false
}); });
let resp3: wzp_relay::auth::ValidateResponse = let resp3: wzp_relay::auth::ValidateResponse = serde_json::from_value(without_eth).unwrap();
serde_json::from_value(without_eth).unwrap();
assert!(!resp3.valid); assert!(!resp3.valid);
} }
/// WZP-S-7: SignalMessage::AuthToken { token } exists and round-trips via serde. /// WZP-S-7: SignalMessage::AuthToken { version: default_signal_version(), token } exists and round-trips via serde.
#[test] #[test]
fn wzp_proto_has_auth_token_variant() { fn wzp_proto_has_auth_token_variant() {
let msg = wzp_proto::SignalMessage::AuthToken { let msg = wzp_proto::SignalMessage::AuthToken {
version: default_signal_version(),
token: "fc-bearer-token-xyz".to_string(), token: "fc-bearer-token-xyz".to_string(),
}; };
@@ -473,7 +508,7 @@ fn wzp_proto_has_auth_token_variant() {
// Deserialize back // Deserialize back
let decoded: wzp_proto::SignalMessage = serde_json::from_str(&json).unwrap(); let decoded: wzp_proto::SignalMessage = serde_json::from_str(&json).unwrap();
if let wzp_proto::SignalMessage::AuthToken { token } = decoded { if let wzp_proto::SignalMessage::AuthToken { token, .. } = decoded {
assert_eq!(token, "fc-bearer-token-xyz"); assert_eq!(token, "fc-bearer-token-xyz");
} else { } else {
panic!("expected AuthToken variant, got: {decoded:?}"); panic!("expected AuthToken variant, got: {decoded:?}");
@@ -496,7 +531,11 @@ fn all_fc_call_signal_types_representable() {
(CallSignalType::Busy, "Busy"), (CallSignalType::Busy, "Busy"),
]; ];
assert_eq!(variants.len(), 7, "featherChat defines exactly 7 call signal types"); assert_eq!(
variants.len(),
7,
"featherChat defines exactly 7 call signal types"
);
for (variant, expected_name) in &variants { for (variant, expected_name) in &variants {
let name = format!("{variant:?}"); let name = format!("{variant:?}");
@@ -550,10 +589,7 @@ fn hash_room_name_used_as_sni_is_valid() {
#[test] #[test]
fn wzp_proto_cargo_toml_is_standalone() { fn wzp_proto_cargo_toml_is_standalone() {
// Try both paths (run from workspace root or from crate directory) // Try both paths (run from workspace root or from crate directory)
let candidates = [ let candidates = ["crates/wzp-proto/Cargo.toml", "../wzp-proto/Cargo.toml"];
"crates/wzp-proto/Cargo.toml",
"../wzp-proto/Cargo.toml",
];
let contents = candidates let contents = candidates
.iter() .iter()

View File

@@ -13,11 +13,17 @@ pub struct AdaptiveFec {
pub repair_ratio: f32, pub repair_ratio: f32,
/// Symbol size in bytes. /// Symbol size in bytes.
pub symbol_size: u16, pub symbol_size: u16,
/// Repair ratio to use when the block contains a keyframe.
/// Default 0.5 (50% overhead) — keyframes are critical and worth
/// the extra bandwidth.
pub keyframe_repair_ratio: f32,
} }
impl AdaptiveFec { impl AdaptiveFec {
/// Default symbol size for adaptive configuration. /// Default symbol size for adaptive configuration.
const DEFAULT_SYMBOL_SIZE: u16 = 256; const DEFAULT_SYMBOL_SIZE: u16 = 256;
/// Default keyframe repair ratio (PRD-video-v1 T4.5).
const DEFAULT_KEYFRAME_REPAIR_RATIO: f32 = 0.5;
/// Create an adaptive FEC configuration from a quality profile. /// Create an adaptive FEC configuration from a quality profile.
/// ///
@@ -30,12 +36,15 @@ impl AdaptiveFec {
frames_per_block: profile.frames_per_block as usize, frames_per_block: profile.frames_per_block as usize,
repair_ratio: profile.fec_ratio, repair_ratio: profile.fec_ratio,
symbol_size: Self::DEFAULT_SYMBOL_SIZE, symbol_size: Self::DEFAULT_SYMBOL_SIZE,
keyframe_repair_ratio: Self::DEFAULT_KEYFRAME_REPAIR_RATIO,
} }
} }
/// Build a configured FEC encoder from this adaptive configuration. /// Build a configured FEC encoder from this adaptive configuration.
pub fn build_encoder(&self) -> RaptorQFecEncoder { pub fn build_encoder(&self) -> RaptorQFecEncoder {
RaptorQFecEncoder::new(self.frames_per_block, self.symbol_size) let mut enc = RaptorQFecEncoder::new(self.frames_per_block, self.symbol_size);
enc.set_keyframe_ratio(self.keyframe_repair_ratio);
enc
} }
/// Get the repair ratio for use with `FecEncoder::generate_repair()`. /// Get the repair ratio for use with `FecEncoder::generate_repair()`.
@@ -59,6 +68,7 @@ mod tests {
let cfg = AdaptiveFec::from_profile(&QualityProfile::GOOD); let cfg = AdaptiveFec::from_profile(&QualityProfile::GOOD);
assert_eq!(cfg.frames_per_block, 5); assert_eq!(cfg.frames_per_block, 5);
assert!((cfg.repair_ratio - 0.2).abs() < f32::EPSILON); assert!((cfg.repair_ratio - 0.2).abs() < f32::EPSILON);
assert!((cfg.keyframe_repair_ratio - 0.5).abs() < f32::EPSILON);
} }
#[test] #[test]

View File

@@ -29,9 +29,9 @@ pub enum DecoderBlockState {
/// Manages encoder-side block tracking. /// Manages encoder-side block tracking.
pub struct EncoderBlockManager { pub struct EncoderBlockManager {
/// Current block ID being built. /// Current block ID being built.
current_id: u8, current_id: u16,
/// State of known blocks. /// State of known blocks.
blocks: HashMap<u8, EncoderBlockState>, blocks: HashMap<u16, EncoderBlockState>,
} }
impl EncoderBlockManager { impl EncoderBlockManager {
@@ -45,7 +45,7 @@ impl EncoderBlockManager {
} }
/// Get the next block ID (advances the current building block). /// Get the next block ID (advances the current building block).
pub fn next_block_id(&mut self) -> u8 { pub fn next_block_id(&mut self) -> u16 {
let old = self.current_id; let old = self.current_id;
// Mark old block as pending. // Mark old block as pending.
self.blocks.insert(old, EncoderBlockState::Pending); self.blocks.insert(old, EncoderBlockState::Pending);
@@ -57,23 +57,23 @@ impl EncoderBlockManager {
} }
/// Current block ID being built. /// Current block ID being built.
pub fn current_id(&self) -> u8 { pub fn current_id(&self) -> u16 {
self.current_id self.current_id
} }
/// Mark a block as fully sent. /// Mark a block as fully sent.
pub fn mark_sent(&mut self, block_id: u8) { pub fn mark_sent(&mut self, block_id: u16) {
self.blocks.insert(block_id, EncoderBlockState::Sent); self.blocks.insert(block_id, EncoderBlockState::Sent);
} }
/// Mark a block as acknowledged by the peer. /// Mark a block as acknowledged by the peer.
pub fn mark_acknowledged(&mut self, block_id: u8) { pub fn mark_acknowledged(&mut self, block_id: u16) {
self.blocks self.blocks
.insert(block_id, EncoderBlockState::Acknowledged); .insert(block_id, EncoderBlockState::Acknowledged);
} }
/// Get the state of a block. /// Get the state of a block.
pub fn state(&self, block_id: u8) -> Option<EncoderBlockState> { pub fn state(&self, block_id: u16) -> Option<EncoderBlockState> {
self.blocks.get(&block_id).copied() self.blocks.get(&block_id).copied()
} }
@@ -93,9 +93,9 @@ impl Default for EncoderBlockManager {
/// Manages decoder-side block tracking. /// Manages decoder-side block tracking.
pub struct DecoderBlockManager { pub struct DecoderBlockManager {
/// State of known blocks. /// State of known blocks.
blocks: HashMap<u8, DecoderBlockState>, blocks: HashMap<u16, DecoderBlockState>,
/// Set of completed block IDs. /// Set of completed block IDs.
completed: HashSet<u8>, completed: HashSet<u16>,
} }
impl DecoderBlockManager { impl DecoderBlockManager {
@@ -107,43 +107,43 @@ impl DecoderBlockManager {
} }
/// Register that we are receiving symbols for a block. /// Register that we are receiving symbols for a block.
pub fn touch(&mut self, block_id: u8) { pub fn touch(&mut self, block_id: u16) {
self.blocks self.blocks
.entry(block_id) .entry(block_id)
.or_insert(DecoderBlockState::Assembling); .or_insert(DecoderBlockState::Assembling);
} }
/// Mark a block as successfully decoded. /// Mark a block as successfully decoded.
pub fn mark_complete(&mut self, block_id: u8) { pub fn mark_complete(&mut self, block_id: u16) {
self.blocks.insert(block_id, DecoderBlockState::Complete); self.blocks.insert(block_id, DecoderBlockState::Complete);
self.completed.insert(block_id); self.completed.insert(block_id);
} }
/// Mark a block as expired. /// Mark a block as expired.
pub fn mark_expired(&mut self, block_id: u8) { pub fn mark_expired(&mut self, block_id: u16) {
self.blocks.insert(block_id, DecoderBlockState::Expired); self.blocks.insert(block_id, DecoderBlockState::Expired);
self.completed.remove(&block_id); self.completed.remove(&block_id);
} }
/// Check if a block has been fully decoded. /// Check if a block has been fully decoded.
pub fn is_block_complete(&self, block_id: u8) -> bool { pub fn is_block_complete(&self, block_id: u16) -> bool {
self.completed.contains(&block_id) self.completed.contains(&block_id)
} }
/// Get the state of a block. /// Get the state of a block.
pub fn state(&self, block_id: u8) -> Option<DecoderBlockState> { pub fn state(&self, block_id: u16) -> Option<DecoderBlockState> {
self.blocks.get(&block_id).copied() self.blocks.get(&block_id).copied()
} }
/// Expire all blocks older than the given block_id (using wrapping distance). /// Expire all blocks older than the given block_id (using wrapping distance).
pub fn expire_before(&mut self, block_id: u8) { pub fn expire_before(&mut self, block_id: u16) {
let to_expire: Vec<u8> = self let to_expire: Vec<u16> = self
.blocks .blocks
.keys() .keys()
.copied() .copied()
.filter(|&id| { .filter(|&id| {
let distance = block_id.wrapping_sub(id); let distance = block_id.wrapping_sub(id);
distance > 0 && distance <= 128 distance > 0 && distance <= 32768
}) })
.collect(); .collect();
@@ -207,7 +207,7 @@ mod tests {
#[test] #[test]
fn decoder_expire_before() { fn decoder_expire_before() {
let mut mgr = DecoderBlockManager::new(); let mut mgr = DecoderBlockManager::new();
for i in 0..5u8 { for i in 0..5u16 {
mgr.touch(i); mgr.touch(i);
} }
mgr.mark_complete(1); mgr.mark_complete(1);
@@ -231,11 +231,11 @@ mod tests {
#[test] #[test]
fn next_block_id_wraps() { fn next_block_id_wraps() {
let mut mgr = EncoderBlockManager::new(); let mut mgr = EncoderBlockManager::new();
// Start at 0, advance to 255 then wrap // Start at 0, advance to u16::MAX then wrap
for _ in 0..255 { for _ in 0..65535 {
mgr.next_block_id(); mgr.next_block_id();
} }
assert_eq!(mgr.current_id(), 255); assert_eq!(mgr.current_id(), u16::MAX);
let next = mgr.next_block_id(); let next = mgr.next_block_id();
assert_eq!(next, 0); assert_eq!(next, 0);
} }

View File

@@ -4,8 +4,8 @@ use std::collections::HashMap;
use std::time::Instant; use std::time::Instant;
use raptorq::{EncodingPacket, ObjectTransmissionInformation, PayloadId, SourceBlockDecoder}; use raptorq::{EncodingPacket, ObjectTransmissionInformation, PayloadId, SourceBlockDecoder};
use wzp_proto::error::FecError;
use wzp_proto::FecDecoder; use wzp_proto::FecDecoder;
use wzp_proto::error::FecError;
/// Length prefix size (u16 little-endian), must match encoder. /// Length prefix size (u16 little-endian), must match encoder.
const LEN_PREFIX: usize = 2; const LEN_PREFIX: usize = 2;
@@ -32,7 +32,7 @@ struct BlockState {
/// RaptorQ-based FEC decoder that handles multiple concurrent blocks. /// RaptorQ-based FEC decoder that handles multiple concurrent blocks.
pub struct RaptorQFecDecoder { pub struct RaptorQFecDecoder {
/// Per-block decoder state, keyed by block_id. /// Per-block decoder state, keyed by block_id.
blocks: HashMap<u8, BlockState>, blocks: HashMap<u16, BlockState>,
/// Symbol size (must match encoder). /// Symbol size (must match encoder).
symbol_size: u16, symbol_size: u16,
/// Number of source symbols per block (from encoder config). /// Number of source symbols per block (from encoder config).
@@ -57,7 +57,7 @@ impl RaptorQFecDecoder {
Self::new(frames_per_block, 256) Self::new(frames_per_block, 256)
} }
fn get_or_create_block(&mut self, block_id: u8) -> &mut BlockState { fn get_or_create_block(&mut self, block_id: u16) -> &mut BlockState {
self.blocks.entry(block_id).or_insert_with(|| BlockState { self.blocks.entry(block_id).or_insert_with(|| BlockState {
num_source_symbols: Some(self.frames_per_block), num_source_symbols: Some(self.frames_per_block),
packets: Vec::new(), packets: Vec::new(),
@@ -72,8 +72,8 @@ impl RaptorQFecDecoder {
impl FecDecoder for RaptorQFecDecoder { impl FecDecoder for RaptorQFecDecoder {
fn add_symbol( fn add_symbol(
&mut self, &mut self,
block_id: u8, block_id: u16,
symbol_index: u8, symbol_index: u16,
_is_repair: bool, _is_repair: bool,
data: &[u8], data: &[u8],
) -> Result<(), FecError> { ) -> Result<(), FecError> {
@@ -104,13 +104,13 @@ impl FecDecoder for RaptorQFecDecoder {
padded[..len].copy_from_slice(&data[..len]); padded[..len].copy_from_slice(&data[..len]);
let esi = symbol_index as u32; let esi = symbol_index as u32;
let packet = EncodingPacket::new(PayloadId::new(block_id, esi), padded); let packet = EncodingPacket::new(PayloadId::new((block_id & 0xFF) as u8, esi), padded);
block.packets.push(packet); block.packets.push(packet);
Ok(()) Ok(())
} }
fn try_decode(&mut self, block_id: u8) -> Result<Option<Vec<Vec<u8>>>, FecError> { fn try_decode(&mut self, block_id: u16) -> Result<Option<Vec<Vec<u8>>>, FecError> {
let frames_per_block = self.frames_per_block; let frames_per_block = self.frames_per_block;
let block = match self.blocks.get_mut(&block_id) { let block = match self.blocks.get_mut(&block_id) {
Some(b) => b, Some(b) => b,
@@ -125,7 +125,7 @@ impl FecDecoder for RaptorQFecDecoder {
let block_length = (num_source as u64) * (block.symbol_size as u64); let block_length = (num_source as u64) * (block.symbol_size as u64);
let config = ObjectTransmissionInformation::with_defaults(block_length, block.symbol_size); let config = ObjectTransmissionInformation::with_defaults(block_length, block.symbol_size);
let mut decoder = SourceBlockDecoder::new(block_id, &config, block_length); let mut decoder = SourceBlockDecoder::new((block_id & 0xFF) as u8, &config, block_length);
let decoded = decoder.decode(block.packets.clone()); let decoded = decoder.decode(block.packets.clone());
@@ -140,10 +140,7 @@ impl FecDecoder for RaptorQFecDecoder {
frames.push(Vec::new()); frames.push(Vec::new());
continue; continue;
} }
let payload_len = u16::from_le_bytes([ let payload_len = u16::from_le_bytes([data[offset], data[offset + 1]]) as usize;
data[offset],
data[offset + 1],
]) as usize;
let payload_start = offset + LEN_PREFIX; let payload_start = offset + LEN_PREFIX;
let payload_end = (payload_start + payload_len).min(data.len()); let payload_end = (payload_start + payload_len).min(data.len());
frames.push(data[payload_start..payload_end].to_vec()); frames.push(data[payload_start..payload_end].to_vec());
@@ -159,15 +156,15 @@ impl FecDecoder for RaptorQFecDecoder {
} }
} }
fn expire_before(&mut self, block_id: u8) { fn expire_before(&mut self, block_id: u16) {
// Remove blocks with IDs "older" than block_id. // Remove blocks with IDs "older" than block_id.
// With wrapping u8 IDs, we consider a block old if its distance // With wrapping u16 IDs, we consider a block old if its distance
// (in the forward direction) to block_id is > 128. // (in the forward direction) to block_id is > 32768.
self.blocks.retain(|&id, _| { self.blocks.retain(|&id, _| {
let distance = block_id.wrapping_sub(id); let distance = block_id.wrapping_sub(id);
// If distance is 0 or > 128, the block is current or "ahead" — keep it. // If distance is 0 or > 32768, the block is current or "ahead" — keep it.
// If distance is 1..=128, the block is behind — remove it. // If distance is 1..=32768, the block is behind — remove it.
distance == 0 || distance > 128 distance == 0 || distance > 32768
}); });
} }
} }
@@ -198,9 +195,7 @@ mod tests {
// Feed all source symbols (using the length-prefixed padded data). // Feed all source symbols (using the length-prefixed padded data).
for (i, pkt) in source_pkts.iter().enumerate() { for (i, pkt) in source_pkts.iter().enumerate() {
decoder decoder.add_symbol(0, i as u16, false, pkt.data()).unwrap();
.add_symbol(0, i as u8, false, pkt.data())
.unwrap();
} }
let result = decoder.try_decode(0).unwrap(); let result = decoder.try_decode(0).unwrap();
@@ -233,7 +228,11 @@ mod tests {
let config = ObjectTransmissionInformation::new(block_len, SYMBOL_SIZE, 1, 1, 1); let config = ObjectTransmissionInformation::new(block_len, SYMBOL_SIZE, 1, 1, 1);
let mut dec = SourceBlockDecoder::new(0, &config, block_len); let mut dec = SourceBlockDecoder::new(0, &config, block_len);
let decoded = dec.decode(all); let decoded = dec.decode(all);
assert!(decoded.is_some(), "Should recover with {:.0}% loss", drop_fraction * 100.0); assert!(
decoded.is_some(),
"Should recover with {:.0}% loss",
drop_fraction * 100.0
);
let data = decoded.unwrap(); let data = decoded.unwrap();
let ss = SYMBOL_SIZE as usize; let ss = SYMBOL_SIZE as usize;
@@ -245,22 +244,28 @@ mod tests {
} }
#[test] #[test]
fn decode_with_30pct_loss() { run_loss_test(FRAMES_PER_BLOCK, 0.5, 0.3); } fn decode_with_30pct_loss() {
run_loss_test(FRAMES_PER_BLOCK, 0.5, 0.3);
}
#[test] #[test]
fn decode_with_50pct_loss() { run_loss_test(FRAMES_PER_BLOCK, 1.0, 0.5); } fn decode_with_50pct_loss() {
run_loss_test(FRAMES_PER_BLOCK, 1.0, 0.5);
}
#[test] #[test]
fn decode_with_70pct_source_loss_heavy_repair() { run_loss_test(8, 2.0, 0.5); } fn decode_with_70pct_source_loss_heavy_repair() {
run_loss_test(8, 2.0, 0.5);
}
#[test] #[test]
fn expire_removes_old_blocks() { fn expire_removes_old_blocks() {
let mut decoder = RaptorQFecDecoder::new(FRAMES_PER_BLOCK, SYMBOL_SIZE); let mut decoder = RaptorQFecDecoder::new(FRAMES_PER_BLOCK, SYMBOL_SIZE);
// Add symbols to blocks 0, 1, 2 // Add symbols to blocks 0, 1, 2
for block_id in 0..3u8 { for block_id in 0..3u16 {
decoder decoder
.add_symbol(block_id, 0, false, &[block_id; 50]) .add_symbol(block_id, 0, false, &[block_id as u8; 50])
.unwrap(); .unwrap();
} }
@@ -288,10 +293,10 @@ mod tests {
// Interleave symbols from block 0 and block 1 // Interleave symbols from block 0 and block 1
for i in 0..FRAMES_PER_BLOCK { for i in 0..FRAMES_PER_BLOCK {
decoder decoder
.add_symbol(0, i as u8, false, pkts_a[i].data()) .add_symbol(0, i as u16, false, pkts_a[i].data())
.unwrap(); .unwrap();
decoder decoder
.add_symbol(1, i as u8, false, pkts_b[i].data()) .add_symbol(1, i as u16, false, pkts_b[i].data())
.unwrap(); .unwrap();
} }

View File

@@ -1,8 +1,8 @@
//! RaptorQ FEC encoder — accumulates source symbols into blocks and generates repair symbols. //! RaptorQ FEC encoder — accumulates source symbols into blocks and generates repair symbols.
use raptorq::{EncodingPacket, ObjectTransmissionInformation, PayloadId, SourceBlockEncoder}; use raptorq::{EncodingPacket, ObjectTransmissionInformation, PayloadId, SourceBlockEncoder};
use wzp_proto::error::FecError;
use wzp_proto::FecEncoder; use wzp_proto::FecEncoder;
use wzp_proto::error::FecError;
/// Maximum symbol size in bytes. Audio frames are typically < 200 bytes, /// Maximum symbol size in bytes. Audio frames are typically < 200 bytes,
/// but we pad to a uniform size within a block. /// but we pad to a uniform size within a block.
@@ -15,14 +15,19 @@ const LEN_PREFIX: usize = 2;
/// RaptorQ-based FEC encoder that groups audio frames into blocks /// RaptorQ-based FEC encoder that groups audio frames into blocks
/// and generates fountain-code repair symbols. /// and generates fountain-code repair symbols.
pub struct RaptorQFecEncoder { pub struct RaptorQFecEncoder {
/// Current block ID (wraps at u8). /// Current block ID (wraps at u16).
block_id: u8, block_id: u16,
/// Maximum source symbols per block. /// Maximum source symbols per block.
frames_per_block: usize, frames_per_block: usize,
/// Accumulated source symbols for the current block. /// Accumulated source symbols for the current block.
source_symbols: Vec<Vec<u8>>, source_symbols: Vec<Vec<u8>>,
/// Symbol size used for encoding (all symbols padded to this size). /// Symbol size used for encoding (all symbols padded to this size).
symbol_size: u16, symbol_size: u16,
/// True if at least one source symbol in the current block is a keyframe.
has_keyframe: bool,
/// Repair ratio to use when the block contains a keyframe.
/// If zero, the nominal ratio passed to [`generate_repair`] is used.
keyframe_ratio: f32,
} }
impl RaptorQFecEncoder { impl RaptorQFecEncoder {
@@ -36,9 +41,26 @@ impl RaptorQFecEncoder {
frames_per_block, frames_per_block,
source_symbols: Vec::with_capacity(frames_per_block), source_symbols: Vec::with_capacity(frames_per_block),
symbol_size, symbol_size,
has_keyframe: false,
keyframe_ratio: 0.0,
} }
} }
/// Set the repair ratio to use for blocks that contain at least one
/// keyframe source symbol.
///
/// When `keyframe_ratio > 0.0` and [`has_keyframe`](Self::has_keyframe)
/// is true, [`generate_repair`](FecEncoder::generate_repair) uses this
/// ratio instead of the nominal ratio passed by the caller.
pub fn set_keyframe_ratio(&mut self, ratio: f32) {
self.keyframe_ratio = ratio.max(0.0);
}
/// Returns true if the current block contains a keyframe source symbol.
pub fn has_keyframe(&self) -> bool {
self.has_keyframe
}
/// Create with default symbol size (256 bytes). /// Create with default symbol size (256 bytes).
pub fn with_defaults(frames_per_block: usize) -> Self { pub fn with_defaults(frames_per_block: usize) -> Self {
Self::new(frames_per_block, DEFAULT_MAX_SYMBOL_SIZE) Self::new(frames_per_block, DEFAULT_MAX_SYMBOL_SIZE)
@@ -54,8 +76,7 @@ impl RaptorQFecEncoder {
let payload_len = sym.len().min(max_payload); let payload_len = sym.len().min(max_payload);
let offset = i * ss; let offset = i * ss;
// Write 2-byte little-endian length prefix. // Write 2-byte little-endian length prefix.
data[offset..offset + LEN_PREFIX] data[offset..offset + LEN_PREFIX].copy_from_slice(&(payload_len as u16).to_le_bytes());
.copy_from_slice(&(payload_len as u16).to_le_bytes());
// Write payload after prefix. // Write payload after prefix.
data[offset + LEN_PREFIX..offset + LEN_PREFIX + payload_len] data[offset + LEN_PREFIX..offset + LEN_PREFIX + payload_len]
.copy_from_slice(&sym[..payload_len]); .copy_from_slice(&sym[..payload_len]);
@@ -75,17 +96,36 @@ impl FecEncoder for RaptorQFecEncoder {
Ok(()) Ok(())
} }
fn generate_repair(&mut self, ratio: f32) -> Result<Vec<(u8, Vec<u8>)>, FecError> { fn add_source_symbol_with_keyframe(
&mut self,
data: &[u8],
is_keyframe: bool,
) -> Result<(), FecError> {
self.add_source_symbol(data)?;
if is_keyframe {
self.has_keyframe = true;
}
Ok(())
}
fn generate_repair(&mut self, ratio: f32) -> Result<Vec<(u16, Vec<u8>)>, FecError> {
if self.source_symbols.is_empty() { if self.source_symbols.is_empty() {
return Ok(vec![]); return Ok(vec![]);
} }
let effective_ratio = if self.has_keyframe && self.keyframe_ratio > 0.0 {
self.keyframe_ratio
} else {
ratio
};
let block_data = self.build_block_data(); let block_data = self.build_block_data();
let config = ObjectTransmissionInformation::with_defaults(block_data.len() as u64, self.symbol_size); let config =
let encoder = SourceBlockEncoder::new(self.block_id, &config, &block_data); ObjectTransmissionInformation::with_defaults(block_data.len() as u64, self.symbol_size);
let encoder = SourceBlockEncoder::new((self.block_id & 0xFF) as u8, &config, &block_data);
let num_source = self.source_symbols.len() as u32; let num_source = self.source_symbols.len() as u32;
let num_repair = ((num_source as f32) * ratio).ceil() as u32; let num_repair = ((num_source as f32) * effective_ratio).ceil() as u32;
if num_repair == 0 { if num_repair == 0 {
return Ok(vec![]); return Ok(vec![]);
} }
@@ -93,11 +133,11 @@ impl FecEncoder for RaptorQFecEncoder {
// Generate repair packets starting from offset 0 (ESIs begin at num_source). // Generate repair packets starting from offset 0 (ESIs begin at num_source).
let repair_packets: Vec<EncodingPacket> = encoder.repair_packets(0, num_repair); let repair_packets: Vec<EncodingPacket> = encoder.repair_packets(0, num_repair);
let result: Vec<(u8, Vec<u8>)> = repair_packets let result: Vec<(u16, Vec<u8>)> = repair_packets
.into_iter() .into_iter()
.enumerate() .enumerate()
.map(|(i, pkt): (usize, EncodingPacket)| { .map(|(i, pkt): (usize, EncodingPacket)| {
let idx = (num_source as u8).wrapping_add(i as u8); let idx = (num_source as u16).wrapping_add(i as u16);
(idx, pkt.data().to_vec()) (idx, pkt.data().to_vec())
}) })
.collect(); .collect();
@@ -105,14 +145,15 @@ impl FecEncoder for RaptorQFecEncoder {
Ok(result) Ok(result)
} }
fn finalize_block(&mut self) -> Result<u8, FecError> { fn finalize_block(&mut self) -> Result<u16, FecError> {
let completed = self.block_id; let completed = self.block_id;
self.block_id = self.block_id.wrapping_add(1); self.block_id = self.block_id.wrapping_add(1);
self.source_symbols.clear(); self.source_symbols.clear();
self.has_keyframe = false;
Ok(completed) Ok(completed)
} }
fn current_block_id(&self) -> u8 { fn current_block_id(&self) -> u16 {
self.block_id self.block_id
} }
@@ -130,8 +171,7 @@ fn build_prefixed_block_data(symbols: &[Vec<u8>], symbol_size: u16) -> Vec<u8> {
let max_payload = ss - LEN_PREFIX; let max_payload = ss - LEN_PREFIX;
let payload_len = sym.len().min(max_payload); let payload_len = sym.len().min(max_payload);
let offset = i * ss; let offset = i * ss;
data[offset..offset + LEN_PREFIX] data[offset..offset + LEN_PREFIX].copy_from_slice(&(payload_len as u16).to_le_bytes());
.copy_from_slice(&(payload_len as u16).to_le_bytes());
data[offset + LEN_PREFIX..offset + LEN_PREFIX + payload_len] data[offset + LEN_PREFIX..offset + LEN_PREFIX + payload_len]
.copy_from_slice(&sym[..payload_len]); .copy_from_slice(&sym[..payload_len]);
} }
@@ -141,7 +181,7 @@ fn build_prefixed_block_data(symbols: &[Vec<u8>], symbol_size: u16) -> Vec<u8> {
/// Helper: build source `EncodingPacket`s for a given block. Useful for /// Helper: build source `EncodingPacket`s for a given block. Useful for
/// the decoder tests and interleaving. /// the decoder tests and interleaving.
pub fn source_packets_for_block( pub fn source_packets_for_block(
block_id: u8, block_id: u16,
symbols: &[Vec<u8>], symbols: &[Vec<u8>],
symbol_size: u16, symbol_size: u16,
) -> Vec<EncodingPacket> { ) -> Vec<EncodingPacket> {
@@ -151,21 +191,21 @@ pub fn source_packets_for_block(
.map(|i| { .map(|i| {
let offset = i * ss; let offset = i * ss;
let sym_data = data[offset..offset + ss].to_vec(); let sym_data = data[offset..offset + ss].to_vec();
EncodingPacket::new(PayloadId::new(block_id, i as u32), sym_data) EncodingPacket::new(PayloadId::new((block_id & 0xFF) as u8, i as u32), sym_data)
}) })
.collect() .collect()
} }
/// Helper: generate repair packets for the given source symbols. /// Helper: generate repair packets for the given source symbols.
pub fn repair_packets_for_block( pub fn repair_packets_for_block(
block_id: u8, block_id: u16,
symbols: &[Vec<u8>], symbols: &[Vec<u8>],
symbol_size: u16, symbol_size: u16,
ratio: f32, ratio: f32,
) -> Vec<EncodingPacket> { ) -> Vec<EncodingPacket> {
let data = build_prefixed_block_data(symbols, symbol_size); let data = build_prefixed_block_data(symbols, symbol_size);
let config = ObjectTransmissionInformation::with_defaults(data.len() as u64, symbol_size); let config = ObjectTransmissionInformation::with_defaults(data.len() as u64, symbol_size);
let encoder = SourceBlockEncoder::new(block_id, &config, &data); let encoder = SourceBlockEncoder::new((block_id & 0xFF) as u8, &config, &data);
let num_source = symbols.len() as u32; let num_source = symbols.len() as u32;
let num_repair = ((num_source as f32) * ratio).ceil() as u32; let num_repair = ((num_source as f32) * ratio).ceil() as u32;
encoder.repair_packets(0, num_repair) encoder.repair_packets(0, num_repair)
@@ -201,14 +241,70 @@ mod tests {
} }
#[test] #[test]
fn block_id_wraps() { fn block_id_wraps_u16() {
let mut enc = RaptorQFecEncoder::with_defaults(1); let mut enc = RaptorQFecEncoder::with_defaults(1);
for expected in 0..=255u8 { // Advance 300 blocks and verify no panic + monotonic increment.
for expected in 0..300u16 {
assert_eq!(enc.current_block_id(), expected); assert_eq!(enc.current_block_id(), expected);
enc.add_source_symbol(&[expected; 10]).unwrap(); enc.add_source_symbol(&[0u8; 10]).unwrap();
enc.finalize_block().unwrap(); enc.finalize_block().unwrap();
} }
// After 256 blocks, wraps back to 0 // Explicitly test wrap at u16 boundary.
assert_eq!(enc.current_block_id(), 0); let mut enc2 = RaptorQFecEncoder::with_defaults(1);
enc2.block_id = u16::MAX;
enc2.add_source_symbol(&[0u8; 10]).unwrap();
let id = enc2.finalize_block().unwrap();
assert_eq!(id, u16::MAX);
assert_eq!(enc2.current_block_id(), 0);
}
#[test]
fn keyframe_boost_uses_higher_ratio() {
// Non-keyframe block with nominal ratio 0.2 → ceil(5 * 0.2) = 1 repair.
let mut enc_normal = RaptorQFecEncoder::with_defaults(5);
enc_normal.set_keyframe_ratio(0.8);
for i in 0..5 {
enc_normal
.add_source_symbol_with_keyframe(&[i as u8; 100], false)
.unwrap();
}
let normal_repair = enc_normal.generate_repair(0.2).unwrap();
assert_eq!(normal_repair.len(), 1);
// Keyframe block with same nominal ratio but boost to 0.8 → ceil(5 * 0.8) = 4 repairs.
let mut enc_key = RaptorQFecEncoder::with_defaults(5);
enc_key.set_keyframe_ratio(0.8);
for i in 0..5 {
enc_key
.add_source_symbol_with_keyframe(&[i as u8; 100], i == 2)
.unwrap();
}
let keyframe_repair = enc_key.generate_repair(0.2).unwrap();
assert_eq!(keyframe_repair.len(), 4);
}
#[test]
fn non_keyframe_block_uses_nominal_ratio() {
let mut enc = RaptorQFecEncoder::with_defaults(5);
enc.set_keyframe_ratio(0.8);
for i in 0..5 {
enc.add_source_symbol_with_keyframe(&[i as u8; 100], false)
.unwrap();
}
let repair = enc.generate_repair(0.2).unwrap();
assert_eq!(repair.len(), 1); // ceil(5 * 0.2) = 1
}
#[test]
fn finalize_clears_keyframe_flag() {
let mut enc = RaptorQFecEncoder::with_defaults(2);
enc.add_source_symbol_with_keyframe(&[0u8; 10], true)
.unwrap();
assert!(enc.has_keyframe());
enc.finalize_block().unwrap();
assert!(!enc.has_keyframe());
} }
} }

View File

@@ -3,7 +3,7 @@
//! rather than one block fatally. //! rather than one block fatally.
/// A symbol ready for transmission: (block_id, symbol_index, is_repair, data). /// A symbol ready for transmission: (block_id, symbol_index, is_repair, data).
pub type Symbol = (u8, u8, bool, Vec<u8>); pub type Symbol = (u16, u16, bool, Vec<u8>);
/// Temporal interleaver that mixes symbols across multiple FEC blocks. /// Temporal interleaver that mixes symbols across multiple FEC blocks.
pub struct Interleaver { pub struct Interleaver {
@@ -64,13 +64,13 @@ mod tests {
let interleaver = Interleaver::with_default_depth(); let interleaver = Interleaver::with_default_depth();
let block_a: Vec<Symbol> = (0..3) let block_a: Vec<Symbol> = (0..3)
.map(|i| (0u8, i as u8, false, vec![0xA0 + i as u8])) .map(|i| (0u16, i as u16, false, vec![0xA0 + i as u8]))
.collect(); .collect();
let block_b: Vec<Symbol> = (0..3) let block_b: Vec<Symbol> = (0..3)
.map(|i| (1u8, i as u8, false, vec![0xB0 + i as u8])) .map(|i| (1u16, i as u16, false, vec![0xB0 + i as u8]))
.collect(); .collect();
let block_c: Vec<Symbol> = (0..3) let block_c: Vec<Symbol> = (0..3)
.map(|i| (2u8, i as u8, false, vec![0xC0 + i as u8])) .map(|i| (2u16, i as u16, false, vec![0xC0 + i as u8]))
.collect(); .collect();
let result = interleaver.interleave(&[block_a, block_b, block_c]); let result = interleaver.interleave(&[block_a, block_b, block_c]);
@@ -96,10 +96,10 @@ mod tests {
let interleaver = Interleaver::new(2); let interleaver = Interleaver::new(2);
let block_a: Vec<Symbol> = (0..3) let block_a: Vec<Symbol> = (0..3)
.map(|i| (0u8, i as u8, false, vec![0xA0 + i as u8])) .map(|i| (0u16, i as u16, false, vec![0xA0 + i as u8]))
.collect(); .collect();
let block_b: Vec<Symbol> = (0..1) let block_b: Vec<Symbol> = (0..1)
.map(|i| (1u8, i as u8, false, vec![0xB0 + i as u8])) .map(|i| (1u16, i as u16, false, vec![0xB0 + i as u8]))
.collect(); .collect();
let result = interleaver.interleave(&[block_a, block_b]); let result = interleaver.interleave(&[block_a, block_b]);
@@ -128,7 +128,7 @@ mod tests {
let blocks: Vec<Vec<Symbol>> = (0..3) let blocks: Vec<Vec<Symbol>> = (0..3)
.map(|b| { .map(|b| {
(0..6) (0..6)
.map(|i| (b as u8, i as u8, false, vec![b as u8; 10])) .map(|i| (b as u16, i as u16, false, vec![b as u8; 10]))
.collect() .collect()
}) })
.collect(); .collect();
@@ -146,7 +146,10 @@ mod tests {
// Each block should lose exactly 2 (6 losses / 3 blocks) // Each block should lose exactly 2 (6 losses / 3 blocks)
for &loss in &losses_per_block { for &loss in &losses_per_block {
assert_eq!(loss, 2, "Each block should lose at most 2 symbols from a burst of 6"); assert_eq!(
loss, 2,
"Each block should lose at most 2 symbols from a burst of 6"
);
} }
} }
} }

View File

@@ -16,7 +16,9 @@ pub mod encoder;
pub mod interleave; pub mod interleave;
pub use adaptive::AdaptiveFec; pub use adaptive::AdaptiveFec;
pub use block_manager::{DecoderBlockManager, DecoderBlockState, EncoderBlockManager, EncoderBlockState}; pub use block_manager::{
DecoderBlockManager, DecoderBlockState, EncoderBlockManager, EncoderBlockState,
};
pub use decoder::RaptorQFecDecoder; pub use decoder::RaptorQFecDecoder;
pub use encoder::RaptorQFecEncoder; pub use encoder::RaptorQFecEncoder;
pub use interleave::Interleaver; pub use interleave::Interleaver;
@@ -24,9 +26,7 @@ pub use interleave::Interleaver;
pub use wzp_proto::{FecDecoder, FecEncoder, QualityProfile}; pub use wzp_proto::{FecDecoder, FecEncoder, QualityProfile};
/// Create an encoder/decoder pair configured for the given quality profile. /// Create an encoder/decoder pair configured for the given quality profile.
pub fn create_fec_pair( pub fn create_fec_pair(profile: &QualityProfile) -> (RaptorQFecEncoder, RaptorQFecDecoder) {
profile: &QualityProfile,
) -> (RaptorQFecEncoder, RaptorQFecDecoder) {
let cfg = AdaptiveFec::from_profile(profile); let cfg = AdaptiveFec::from_profile(profile);
let encoder = cfg.build_encoder(); let encoder = cfg.build_encoder();
let decoder = RaptorQFecDecoder::new(cfg.frames_per_block, cfg.symbol_size); let decoder = RaptorQFecDecoder::new(cfg.frames_per_block, cfg.symbol_size);

View File

@@ -24,7 +24,10 @@ fn main() {
let oboe_dir = fetch_oboe(); let oboe_dir = fetch_oboe();
match oboe_dir { match oboe_dir {
Some(oboe_path) => { Some(oboe_path) => {
println!("cargo:warning=wzp-native: building with Oboe from {:?}", oboe_path); println!(
"cargo:warning=wzp-native: building with Oboe from {:?}",
oboe_path
);
let mut build = cc::Build::new(); let mut build = cc::Build::new();
build build
.cpp(true) .cpp(true)
@@ -96,7 +99,12 @@ fn fetch_oboe() -> Option<PathBuf> {
let out_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap()); let out_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap());
let oboe_dir = out_dir.join("oboe"); let oboe_dir = out_dir.join("oboe");
if oboe_dir.join("include").join("oboe").join("Oboe.h").exists() { if oboe_dir
.join("include")
.join("oboe")
.join("Oboe.h")
.exists()
{
return Some(oboe_dir); return Some(oboe_dir);
} }
@@ -111,7 +119,14 @@ fn fetch_oboe() -> Option<PathBuf> {
.status(); .status();
match status { match status {
Ok(s) if s.success() && oboe_dir.join("include").join("oboe").join("Oboe.h").exists() => { Ok(s)
if s.success()
&& oboe_dir
.join("include")
.join("oboe")
.join("Oboe.h")
.exists() =>
{
Some(oboe_dir) Some(oboe_dir)
} }
_ => None, _ => None,

View File

@@ -404,12 +404,14 @@ int wzp_oboe_start(const WzpOboeConfig* config, const WzpOboeRings* rings) {
{ {
auto deadline = std::chrono::steady_clock::now() + std::chrono::milliseconds(2000); auto deadline = std::chrono::steady_clock::now() + std::chrono::milliseconds(2000);
int poll_count = 0; int poll_count = 0;
bool streams_started = false;
while (std::chrono::steady_clock::now() < deadline) { while (std::chrono::steady_clock::now() < deadline) {
auto cap_state = g_capture_stream->getState(); auto cap_state = g_capture_stream->getState();
auto play_state = g_playout_stream->getState(); auto play_state = g_playout_stream->getState();
if (cap_state == oboe::StreamState::Started && if (cap_state == oboe::StreamState::Started &&
play_state == oboe::StreamState::Started) { play_state == oboe::StreamState::Started) {
LOGI("both streams Started after %d polls", poll_count); LOGI("both streams Started after %d polls", poll_count);
streams_started = true;
break; break;
} }
poll_count++; poll_count++;
@@ -420,6 +422,18 @@ int wzp_oboe_start(const WzpOboeConfig* config, const WzpOboeRings* rings) {
(int)g_capture_stream->getState(), (int)g_capture_stream->getState(),
(int)g_playout_stream->getState(), (int)g_playout_stream->getState(),
poll_count); poll_count);
if (!streams_started) {
LOGE("Timed out waiting for Oboe streams to reach Started state");
g_running.store(false, std::memory_order_release);
g_rings_valid.store(false, std::memory_order_release);
g_capture_stream->requestStop();
g_playout_stream->requestStop();
g_capture_stream->close();
g_playout_stream->close();
g_capture_stream.reset();
g_playout_stream.reset();
return -6;
}
} }
LOGI("Oboe started: sr=%d burst=%d ch=%d", LOGI("Oboe started: sr=%d burst=%d ch=%d",

View File

@@ -116,7 +116,11 @@ impl RingBuffer {
let w = self.write_idx.load(Ordering::Acquire); let w = self.write_idx.load(Ordering::Acquire);
let r = self.read_idx.load(Ordering::Relaxed); let r = self.read_idx.load(Ordering::Relaxed);
let avail = w - r; let avail = w - r;
if avail < 0 { (avail + self.capacity as i32) as usize } else { avail as usize } if avail < 0 {
(avail + self.capacity as i32) as usize
} else {
avail as usize
}
} }
fn available_write(&self) -> usize { fn available_write(&self) -> usize {
@@ -132,9 +136,13 @@ impl RingBuffer {
let cap = self.capacity; let cap = self.capacity;
let buf_ptr = self.buf.as_ptr() as *mut i16; let buf_ptr = self.buf.as_ptr() as *mut i16;
for sample in &data[..count] { for sample in &data[..count] {
unsafe { *buf_ptr.add(w) = *sample; } unsafe {
*buf_ptr.add(w) = *sample;
}
w += 1; w += 1;
if w >= cap { w = 0; } if w >= cap {
w = 0;
}
} }
self.write_idx.store(w as i32, Ordering::Release); self.write_idx.store(w as i32, Ordering::Release);
count count
@@ -149,9 +157,13 @@ impl RingBuffer {
let cap = self.capacity; let cap = self.capacity;
let buf_ptr = self.buf.as_ptr(); let buf_ptr = self.buf.as_ptr();
for slot in &mut out[..count] { for slot in &mut out[..count] {
unsafe { *slot = *buf_ptr.add(r); } unsafe {
*slot = *buf_ptr.add(r);
}
r += 1; r += 1;
if r >= cap { r = 0; } if r >= cap {
r = 0;
}
} }
self.read_idx.store(r as i32, Ordering::Release); self.read_idx.store(r as i32, Ordering::Release);
count count
@@ -316,17 +328,27 @@ pub unsafe extern "C" fn wzp_native_audio_write_playout(input: *const i16, in_le
// has stopped firing → restart the streams. This is the // has stopped firing → restart the streams. This is the
// self-healing behavior that makes rejoin work: teardown + // self-healing behavior that makes rejoin work: teardown +
// rebuild clears whatever HAL state locked up the callback. // rebuild clears whatever HAL state locked up the callback.
let current_read_idx = b.playout.read_idx.load(std::sync::atomic::Ordering::Relaxed); let current_read_idx = b
let last_read_idx = b.playout_last_read_idx.load(std::sync::atomic::Ordering::Relaxed); .playout
.read_idx
.load(std::sync::atomic::Ordering::Relaxed);
let last_read_idx = b
.playout_last_read_idx
.load(std::sync::atomic::Ordering::Relaxed);
if current_read_idx == last_read_idx { if current_read_idx == last_read_idx {
let stall = b.playout_stall_writes.fetch_add(1, std::sync::atomic::Ordering::Relaxed); let stall = b
.playout_stall_writes
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
if stall >= 50 { if stall >= 50 {
// Callback hasn't drained anything in ~1 second. // Callback hasn't drained anything in ~1 second.
// Force a stream restart. // Force a stream restart.
unsafe { unsafe {
android_log("playout STALL detected (50 writes, read_idx unchanged) — restarting Oboe streams"); android_log(
"playout STALL detected (50 writes, read_idx unchanged) — restarting Oboe streams",
);
} }
b.playout_stall_writes.store(0, std::sync::atomic::Ordering::Relaxed); b.playout_stall_writes
.store(0, std::sync::atomic::Ordering::Relaxed);
// Release the started lock, stop, re-start. // Release the started lock, stop, re-start.
// This is the same logic as the Rust-side // This is the same logic as the Rust-side
// audio_stop() + audio_start() but done inline // audio_stop() + audio_start() but done inline
@@ -341,10 +363,18 @@ pub unsafe extern "C" fn wzp_native_audio_write_playout(input: *const i16, in_le
} }
} }
// Clear the rings so the restart doesn't read stale data // Clear the rings so the restart doesn't read stale data
b.playout.write_idx.store(0, std::sync::atomic::Ordering::Relaxed); b.playout
b.playout.read_idx.store(0, std::sync::atomic::Ordering::Relaxed); .write_idx
b.capture.write_idx.store(0, std::sync::atomic::Ordering::Relaxed); .store(0, std::sync::atomic::Ordering::Relaxed);
b.capture.read_idx.store(0, std::sync::atomic::Ordering::Relaxed); b.playout
.read_idx
.store(0, std::sync::atomic::Ordering::Relaxed);
b.capture
.write_idx
.store(0, std::sync::atomic::Ordering::Relaxed);
b.capture
.read_idx
.store(0, std::sync::atomic::Ordering::Relaxed);
// Re-start (stall detector — always non-BT mode) // Re-start (stall detector — always non-BT mode)
let config = WzpOboeConfig { let config = WzpOboeConfig {
sample_rate: 48_000, sample_rate: 48_000,
@@ -367,30 +397,49 @@ pub unsafe extern "C" fn wzp_native_audio_write_playout(input: *const i16, in_le
if let Ok(mut started) = b.started.lock() { if let Ok(mut started) = b.started.lock() {
*started = true; *started = true;
} }
unsafe { android_log("playout restart OK — Oboe streams rebuilt"); } unsafe {
} else { android_log("playout restart OK — Oboe streams rebuilt");
unsafe { android_log(&format!("playout restart FAILED: {ret}")); }
} }
b.playout_last_read_idx.store(0, std::sync::atomic::Ordering::Relaxed); } else {
unsafe {
android_log(&format!("playout restart FAILED: {ret}"));
}
}
b.playout_last_read_idx
.store(0, std::sync::atomic::Ordering::Relaxed);
return 0; // caller will retry on next frame return 0; // caller will retry on next frame
} }
} else { } else {
// read_idx advanced — callback is alive, reset counter // read_idx advanced — callback is alive, reset counter
b.playout_stall_writes.store(0, std::sync::atomic::Ordering::Relaxed); b.playout_stall_writes
b.playout_last_read_idx.store(current_read_idx, std::sync::atomic::Ordering::Relaxed); .store(0, std::sync::atomic::Ordering::Relaxed);
b.playout_last_read_idx
.store(current_read_idx, std::sync::atomic::Ordering::Relaxed);
} }
let before_w = b.playout.write_idx.load(std::sync::atomic::Ordering::Relaxed); let before_w = b
let before_r = b.playout.read_idx.load(std::sync::atomic::Ordering::Relaxed); .playout
.write_idx
.load(std::sync::atomic::Ordering::Relaxed);
let before_r = b
.playout
.read_idx
.load(std::sync::atomic::Ordering::Relaxed);
let written = b.playout.write(slice); let written = b.playout.write(slice);
// First few writes: log ring state + sample range so we can compare what // First few writes: log ring state + sample range so we can compare what
// engine.rs hands us to what the C++ playout callback reads. // engine.rs hands us to what the C++ playout callback reads.
let first_writes = b.playout_write_log_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed); let first_writes = b
.playout_write_log_count
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
if first_writes < 3 || first_writes % 50 == 0 { if first_writes < 3 || first_writes % 50 == 0 {
let (mut lo, mut hi, mut sumsq) = (i16::MAX, i16::MIN, 0i64); let (mut lo, mut hi, mut sumsq) = (i16::MAX, i16::MIN, 0i64);
for &s in slice.iter() { for &s in slice.iter() {
if s < lo { lo = s; } if s < lo {
if s > hi { hi = s; } lo = s;
}
if s > hi {
hi = s;
}
sumsq += (s as i64) * (s as i64); sumsq += (s as i64) * (s as i64);
} }
let rms = (sumsq as f64 / slice.len() as f64).sqrt() as i32; let rms = (sumsq as f64 / slice.len() as f64).sqrt() as i32;
@@ -398,7 +447,8 @@ pub unsafe extern "C" fn wzp_native_audio_write_playout(input: *const i16, in_le
let avail_r_after = b.playout.available_read(); let avail_r_after = b.playout.available_read();
let msg = format!( let msg = format!(
"playout WRITE #{first_writes}: in_len={} written={} range=[{lo}..{hi}] rms={rms} before_w={before_w} before_r={before_r} avail_read_after={avail_r_after} avail_write_after={avail_w_after}", "playout WRITE #{first_writes}: in_len={} written={} range=[{lo}..{hi}] rms={rms} before_w={before_w} before_r={before_r} avail_read_after={avail_r_after} avail_write_after={avail_w_after}",
slice.len(), written slice.len(),
written
); );
unsafe { unsafe {
android_log(msg.as_str()); android_log(msg.as_str());
@@ -422,7 +472,9 @@ unsafe fn android_log(msg: &str) {
let mut buf = Vec::with_capacity(msg.len() + 1); let mut buf = Vec::with_capacity(msg.len() + 1);
buf.extend_from_slice(msg.as_bytes()); buf.extend_from_slice(msg.as_bytes());
buf.push(0); buf.push(0);
unsafe { __android_log_write(4, tag.as_ptr(), buf.as_ptr()); } unsafe {
__android_log_write(4, tag.as_ptr(), buf.as_ptr());
}
} }
#[cfg(not(target_os = "android"))] #[cfg(not(target_os = "android"))]

View File

@@ -20,3 +20,4 @@ tracing = "0.1"
[dev-dependencies] [dev-dependencies]
tokio = { version = "1", features = ["full"] } tokio = { version = "1", features = ["full"] }
serde_json = "1" serde_json = "1"
bincode = "1"

View File

@@ -7,10 +7,11 @@
//! Control (GCC). //! Control (GCC).
use std::collections::VecDeque; use std::collections::VecDeque;
use std::time::Instant; use std::sync::atomic::{AtomicU64, Ordering::Relaxed};
use std::time::{Instant, SystemTime, UNIX_EPOCH};
use crate::packet::QualityReport;
use crate::QualityProfile; use crate::QualityProfile;
use crate::packet::QualityReport;
/// Network congestion state derived from delay and loss signals. /// Network congestion state derived from delay and loss signals.
#[derive(Clone, Copy, Debug, PartialEq, Eq)] #[derive(Clone, Copy, Debug, PartialEq, Eq)]
@@ -158,6 +159,16 @@ pub struct BandwidthEstimator {
loss_detector: LossBasedDetector, loss_detector: LossBasedDetector,
/// Last update timestamp. /// Last update timestamp.
last_update: Option<Instant>, last_update: Option<Instant>,
// ── Transport-feedback BWE (T2.2) ──
/// Congestion-window-derived bandwidth estimate in bits per second.
cwnd_bps: AtomicU64,
/// Peer REMB (Receiver Estimated Maximum Bitrate) in bits per second.
peer_remb_bps: AtomicU64,
/// EWMA-smoothed bandwidth estimate in bits per second.
smoothed_bps: AtomicU64,
/// Last time `smoothed_bps` was updated (UNIX epoch millis).
last_smoothed_ms: AtomicU64,
} }
/// Multiplicative decrease factor applied on congestion (15% reduction). /// Multiplicative decrease factor applied on congestion (15% reduction).
@@ -179,6 +190,10 @@ impl BandwidthEstimator {
delay_detector: DelayBasedDetector::new(), delay_detector: DelayBasedDetector::new(),
loss_detector: LossBasedDetector::new(), loss_detector: LossBasedDetector::new(),
last_update: None, last_update: None,
cwnd_bps: AtomicU64::new(0),
peer_remb_bps: AtomicU64::new(u64::MAX),
smoothed_bps: AtomicU64::new(0),
last_smoothed_ms: AtomicU64::new(0),
} }
} }
@@ -250,6 +265,64 @@ impl BandwidthEstimator {
QualityProfile::CATASTROPHIC QualityProfile::CATASTROPHIC
} }
} }
// ── Transport-feedback BWE (T2.2) ──
/// Update from QUIC path stats.
///
/// Computes `cwnd_bps = cwnd_bytes * 8 / rtt_s` and feeds it into the
/// smoothed estimate.
pub fn update_from_path(&self, cwnd_bytes: u64, _bytes_in_flight: u64, rtt_ms: u32) {
let rtt_s = rtt_ms.max(1) as f64 / 1000.0;
let cwnd_bps = ((cwnd_bytes * 8) as f64 / rtt_s) as u64;
self.cwnd_bps.store(cwnd_bps, Relaxed);
self.update_smoothed(cwnd_bps);
}
/// Update from a peer's `TransportFeedback` REMB value.
pub fn update_from_peer(&self, fb_remb_bps: u32) {
let remb = fb_remb_bps as u64;
self.peer_remb_bps.store(remb, Relaxed);
self.update_smoothed(remb);
}
/// Target sending bitrate in bits per second.
///
/// Returns 90% of the minimum between the congestion-window estimate
/// and the peer REMB estimate.
pub fn target_send_bps(&self) -> u64 {
let cwnd = self.cwnd_bps.load(Relaxed);
let remb = self.peer_remb_bps.load(Relaxed);
let m = cwnd.min(remb);
(m as f64 * 0.9) as u64
}
/// EWMA-smoothed bandwidth estimate in bits per second.
pub fn smoothed_bps(&self) -> u64 {
self.smoothed_bps.load(Relaxed)
}
/// Apply EWMA smoothing with a 2-second half-life.
fn update_smoothed(&self, new_bps: u64) {
let now_ms = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_millis() as u64;
let last_ms = self.last_smoothed_ms.load(Relaxed);
let dt_ms = now_ms.saturating_sub(last_ms);
let current = self.smoothed_bps.load(Relaxed);
let updated = if current == 0 || dt_ms == 0 {
new_bps
} else {
let alpha = 1.0 - 0.5_f64.powf(dt_ms as f64 / 2000.0);
let s = current as f64 * (1.0 - alpha) + new_bps as f64 * alpha;
s as u64
};
self.smoothed_bps.store(updated, Relaxed);
self.last_smoothed_ms.store(now_ms, Relaxed);
}
} }
#[cfg(test)] #[cfg(test)]
@@ -396,10 +469,7 @@ mod tests {
// Below 8 => CATASTROPHIC // Below 8 => CATASTROPHIC
let bwe_cat = BandwidthEstimator::new(7.9, 2.0, 100.0); let bwe_cat = BandwidthEstimator::new(7.9, 2.0, 100.0);
assert_eq!( assert_eq!(bwe_cat.recommended_profile(), QualityProfile::CATASTROPHIC);
bwe_cat.recommended_profile(),
QualityProfile::CATASTROPHIC
);
// High bandwidth // High bandwidth
let bwe_high = BandwidthEstimator::new(80.0, 2.0, 100.0); let bwe_high = BandwidthEstimator::new(80.0, 2.0, 100.0);
@@ -451,4 +521,46 @@ mod tests {
} }
assert!(det.is_congested()); assert!(det.is_congested());
} }
#[test]
fn target_send_bps_uses_min_of_cwnd_and_remb() {
let bwe = BandwidthEstimator::new(50.0, 2.0, 100.0);
// cwnd_bps = 100_000, remb = 200_000 → min = 100_000 → 90%
bwe.update_from_path(1250, 0, 100); // 1250*8 / 0.1 = 100_000
bwe.update_from_peer(200_000);
assert_eq!(bwe.target_send_bps(), 90_000);
}
#[test]
fn target_send_bps_with_zero_cwnd_uses_remb() {
let bwe = BandwidthEstimator::new(50.0, 2.0, 100.0);
// Default cwnd is 0, remb is u64::MAX (default).
// 0.min(u64::MAX) = 0 → 90% = 0
assert_eq!(bwe.target_send_bps(), 0);
bwe.update_from_peer(100_000);
// cwnd still 0
assert_eq!(bwe.target_send_bps(), 0);
}
#[test]
fn smoothed_bps_ewma_converges() {
let bwe = BandwidthEstimator::new(50.0, 2.0, 100.0);
bwe.update_from_path(1250, 0, 100); // 100_000 bps
let s1 = bwe.smoothed_bps();
assert_eq!(s1, 100_000);
// Immediately update with same value — dt ≈ 0, so should stay at 100_000
bwe.update_from_path(1250, 0, 100);
let s2 = bwe.smoothed_bps();
assert_eq!(s2, 100_000);
// Sleep a bit so dt is non-zero, then update with a much higher value.
std::thread::sleep(std::time::Duration::from_millis(100));
bwe.update_from_path(12500, 0, 100); // 1_000_000 bps
let s3 = bwe.smoothed_bps();
assert!(s3 > 100_000, "smoothed should increase toward 1M: {s3}");
// With 100ms dt, alpha ≈ 0.03, so smoothed should be ~100k * 0.97 + 1M * 0.03 ≈ 127k
assert!(s3 < 500_000, "smoothed should not jump too far: {s3}");
}
} }

View File

@@ -2,7 +2,8 @@ use serde::{Deserialize, Serialize};
/// Identifies the audio codec and bitrate configuration. /// Identifies the audio codec and bitrate configuration.
/// ///
/// Encoded as 4 bits in the media packet header. /// Encoded as 4 bits in the v1 media packet header, and as a full 8-bit
/// value in the v2 [`MediaHeaderV2`](crate::MediaHeaderV2).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[repr(u8)] #[repr(u8)]
pub enum CodecId { pub enum CodecId {
@@ -24,6 +25,16 @@ pub enum CodecId {
Opus48k = 7, Opus48k = 7,
/// Opus at 64kbps (studio high) /// Opus at 64kbps (studio high)
Opus64k = 8, Opus64k = 8,
/// H.264 baseline profile (video).
H264Baseline = 9,
// Reserved for video codecs; implementations land in PRD-video-multicodec.
// 10 => H264 main
// 11 => H265 main
// 13 => VP9
/// AV1 main profile (video).
Av1Main = 12,
/// H.265 main profile (video).
H265Main = 11,
} }
impl CodecId { impl CodecId {
@@ -39,6 +50,7 @@ impl CodecId {
Self::Codec2_3200 => 3_200, Self::Codec2_3200 => 3_200,
Self::Codec2_1200 => 1_200, Self::Codec2_1200 => 1_200,
Self::ComfortNoise => 0, Self::ComfortNoise => 0,
Self::H264Baseline | Self::H265Main | Self::Av1Main => 2_000_000,
} }
} }
@@ -50,16 +62,22 @@ impl CodecId {
Self::Codec2_3200 => 20, Self::Codec2_3200 => 20,
Self::Codec2_1200 => 40, Self::Codec2_1200 => 40,
Self::ComfortNoise => 20, Self::ComfortNoise => 20,
Self::H264Baseline | Self::H265Main | Self::Av1Main => 33,
} }
} }
/// Sample rate expected by this codec. /// Sample rate expected by this codec.
pub const fn sample_rate_hz(self) -> u32 { pub const fn sample_rate_hz(self) -> u32 {
match self { match self {
Self::Opus24k | Self::Opus16k | Self::Opus6k Self::Opus24k
| Self::Opus32k | Self::Opus48k | Self::Opus64k => 48_000, | Self::Opus16k
| Self::Opus6k
| Self::Opus32k
| Self::Opus48k
| Self::Opus64k => 48_000,
Self::Codec2_3200 | Self::Codec2_1200 => 8_000, Self::Codec2_3200 | Self::Codec2_1200 => 8_000,
Self::ComfortNoise => 48_000, Self::ComfortNoise => 48_000,
Self::H264Baseline | Self::H265Main | Self::Av1Main => 48_000,
} }
} }
@@ -75,6 +93,9 @@ impl CodecId {
6 => Some(Self::Opus32k), 6 => Some(Self::Opus32k),
7 => Some(Self::Opus48k), 7 => Some(Self::Opus48k),
8 => Some(Self::Opus64k), 8 => Some(Self::Opus64k),
9 => Some(Self::H264Baseline),
11 => Some(Self::H265Main),
12 => Some(Self::Av1Main),
_ => None, _ => None,
} }
} }
@@ -84,10 +105,22 @@ impl CodecId {
self as u8 self as u8
} }
/// Returns true if this is a video codec variant.
pub const fn is_video(self) -> bool {
matches!(self, Self::H264Baseline | Self::H265Main | Self::Av1Main)
}
/// Returns true if this is an Opus variant. /// Returns true if this is an Opus variant.
pub const fn is_opus(self) -> bool { pub const fn is_opus(self) -> bool {
matches!(self, Self::Opus6k | Self::Opus16k | Self::Opus24k matches!(
| Self::Opus32k | Self::Opus48k | Self::Opus64k) self,
Self::Opus6k
| Self::Opus16k
| Self::Opus24k
| Self::Opus32k
| Self::Opus48k
| Self::Opus64k
)
} }
} }
@@ -102,6 +135,18 @@ pub struct QualityProfile {
pub frame_duration_ms: u8, pub frame_duration_ms: u8,
/// Number of source frames per FEC block. /// Number of source frames per FEC block.
pub frames_per_block: u8, pub frames_per_block: u8,
/// Bandwidth-allocation priority between audio and video.
#[serde(default)]
pub priority_mode: crate::PriorityMode,
/// Target video bitrate in kbps (set by quality controller, not handshake).
#[serde(default)]
pub video_bitrate_kbps: Option<u32>,
/// Target video resolution as (width, height).
#[serde(default)]
pub video_resolution: Option<(u16, u16)>,
/// Target video frame rate.
#[serde(default)]
pub video_fps: Option<u8>,
} }
impl QualityProfile { impl QualityProfile {
@@ -111,6 +156,10 @@ impl QualityProfile {
fec_ratio: 0.2, fec_ratio: 0.2,
frame_duration_ms: 20, frame_duration_ms: 20,
frames_per_block: 5, frames_per_block: 5,
priority_mode: crate::PriorityMode::AudioFirst,
video_bitrate_kbps: None,
video_resolution: None,
video_fps: None,
}; };
/// Degraded conditions: Opus 6kbps, moderate FEC. /// Degraded conditions: Opus 6kbps, moderate FEC.
@@ -119,6 +168,10 @@ impl QualityProfile {
fec_ratio: 0.5, fec_ratio: 0.5,
frame_duration_ms: 40, frame_duration_ms: 40,
frames_per_block: 10, frames_per_block: 10,
priority_mode: crate::PriorityMode::AudioFirst,
video_bitrate_kbps: None,
video_resolution: None,
video_fps: None,
}; };
/// Catastrophic conditions: Codec2 1.2kbps, heavy FEC. /// Catastrophic conditions: Codec2 1.2kbps, heavy FEC.
@@ -127,6 +180,10 @@ impl QualityProfile {
fec_ratio: 1.0, fec_ratio: 1.0,
frame_duration_ms: 40, frame_duration_ms: 40,
frames_per_block: 8, frames_per_block: 8,
priority_mode: crate::PriorityMode::AudioFirst,
video_bitrate_kbps: None,
video_resolution: None,
video_fps: None,
}; };
/// Studio low: Opus 32kbps, minimal FEC. /// Studio low: Opus 32kbps, minimal FEC.
@@ -135,6 +192,10 @@ impl QualityProfile {
fec_ratio: 0.1, fec_ratio: 0.1,
frame_duration_ms: 20, frame_duration_ms: 20,
frames_per_block: 5, frames_per_block: 5,
priority_mode: crate::PriorityMode::AudioFirst,
video_bitrate_kbps: None,
video_resolution: None,
video_fps: None,
}; };
/// Studio: Opus 48kbps, minimal FEC. /// Studio: Opus 48kbps, minimal FEC.
@@ -143,6 +204,10 @@ impl QualityProfile {
fec_ratio: 0.1, fec_ratio: 0.1,
frame_duration_ms: 20, frame_duration_ms: 20,
frames_per_block: 5, frames_per_block: 5,
priority_mode: crate::PriorityMode::AudioFirst,
video_bitrate_kbps: None,
video_resolution: None,
video_fps: None,
}; };
/// Studio high: Opus 64kbps, minimal FEC. /// Studio high: Opus 64kbps, minimal FEC.
@@ -151,6 +216,10 @@ impl QualityProfile {
fec_ratio: 0.1, fec_ratio: 0.1,
frame_duration_ms: 20, frame_duration_ms: 20,
frames_per_block: 5, frames_per_block: 5,
priority_mode: crate::PriorityMode::AudioFirst,
video_bitrate_kbps: None,
video_resolution: None,
video_fps: None,
}; };
/// Estimated total bandwidth in kbps including FEC overhead. /// Estimated total bandwidth in kbps including FEC overhead.
@@ -159,3 +228,46 @@ impl QualityProfile {
base * (1.0 + self.fec_ratio) base * (1.0 + self.fec_ratio)
} }
} }
#[cfg(test)]
mod tests {
use super::{CodecId, QualityProfile};
use crate::PriorityMode;
#[test]
fn codec_id_unknown_values_rejected() {
for v in [10u8, 13].iter().copied().chain(14u8..=255) {
assert!(CodecId::from_wire(v).is_none(), "v={v}");
}
}
#[test]
fn h265_main_roundtrips() {
assert_eq!(CodecId::H265Main.to_wire(), 11);
assert_eq!(CodecId::from_wire(11), Some(CodecId::H265Main));
assert!(CodecId::H265Main.is_video());
assert_eq!(CodecId::H265Main.bitrate_bps(), 2_000_000);
assert_eq!(CodecId::H265Main.frame_duration_ms(), 33);
}
#[test]
fn av1_main_roundtrips() {
assert_eq!(CodecId::Av1Main.to_wire(), 12);
assert_eq!(CodecId::from_wire(12), Some(CodecId::Av1Main));
assert!(CodecId::Av1Main.is_video());
assert_eq!(CodecId::Av1Main.bitrate_bps(), 2_000_000);
assert_eq!(CodecId::Av1Main.frame_duration_ms(), 33);
}
#[test]
fn quality_profile_backward_compat_old_json() {
// Old JSON emitted before T5.1 has no priority_mode or video fields.
let old_json =
r#"{"codec":"Opus24k","fec_ratio":0.2,"frame_duration_ms":20,"frames_per_block":5}"#;
let parsed: QualityProfile = serde_json::from_str(old_json).unwrap();
assert_eq!(parsed.priority_mode, PriorityMode::AudioFirst);
assert_eq!(parsed.video_bitrate_kbps, None);
assert_eq!(parsed.video_resolution, None);
assert_eq!(parsed.video_fps, None);
}
}

View File

@@ -128,7 +128,11 @@ impl DredTuner {
self.initialized = true; self.initialized = true;
} else { } else {
// Fast-up (alpha=0.3), slow-down (alpha=0.05) asymmetric EWMA // Fast-up (alpha=0.3), slow-down (alpha=0.05) asymmetric EWMA
let alpha = if jitter_f > self.jitter_ewma { 0.3 } else { 0.05 }; let alpha = if jitter_f > self.jitter_ewma {
0.3
} else {
0.05
};
self.jitter_ewma = alpha * jitter_f + (1.0 - alpha) * self.jitter_ewma; self.jitter_ewma = alpha * jitter_f + (1.0 - alpha) * self.jitter_ewma;
} }

View File

@@ -37,7 +37,7 @@ pub enum CryptoError {
#[error("rekey failed: {0}")] #[error("rekey failed: {0}")]
RekeyFailed(String), RekeyFailed(String),
#[error("anti-replay: duplicate or old packet (seq={seq})")] #[error("anti-replay: duplicate or old packet (seq={seq})")]
ReplayDetected { seq: u16 }, ReplayDetected { seq: u32 },
#[error("internal crypto error: {0}")] #[error("internal crypto error: {0}")]
Internal(String), Internal(String),
} }

View File

@@ -81,9 +81,7 @@ impl AdaptivePlayoutDelay {
let jitter = (actual_delta - expected_delta).abs(); let jitter = (actual_delta - expected_delta).abs();
// Spike detection: check before EMA update // Spike detection: check before EMA update
if self.jitter_ema > 0.0 if self.jitter_ema > 0.0 && jitter > self.jitter_ema * self.spike_threshold_multiplier {
&& jitter > self.jitter_ema * self.spike_threshold_multiplier
{
self.spike_detected_at = Some(Instant::now()); self.spike_detected_at = Some(Instant::now());
} }
@@ -107,10 +105,8 @@ impl AdaptivePlayoutDelay {
self.target_delay = self.max_delay; self.target_delay = self.max_delay;
} else { } else {
// Convert jitter estimate to target delay in packets // Convert jitter estimate to target delay in packets
let raw_target = let raw_target = (self.jitter_ema / FRAME_DURATION_MS).ceil() + self.safety_margin;
(self.jitter_ema / FRAME_DURATION_MS).ceil() + self.safety_margin; self.target_delay = (raw_target as usize).clamp(self.min_delay, self.max_delay);
self.target_delay =
(raw_target as usize).clamp(self.min_delay, self.max_delay);
} }
} }
@@ -162,9 +158,9 @@ impl AdaptivePlayoutDelay {
/// Manages packet reordering, gap detection, and signals when PLC is needed. /// Manages packet reordering, gap detection, and signals when PLC is needed.
pub struct JitterBuffer { pub struct JitterBuffer {
/// Packets waiting to be consumed, ordered by sequence number. /// Packets waiting to be consumed, ordered by sequence number.
buffer: BTreeMap<u16, MediaPacket>, buffer: BTreeMap<u32, MediaPacket>,
/// Next sequence number expected for playout. /// Next sequence number expected for playout.
next_playout_seq: u16, next_playout_seq: u32,
/// Maximum buffer depth in number of packets. /// Maximum buffer depth in number of packets.
max_depth: usize, max_depth: usize,
/// Target buffer depth (adaptive, based on jitter). /// Target buffer depth (adaptive, based on jitter).
@@ -204,7 +200,7 @@ pub enum PlayoutResult {
/// A packet is available for playout. /// A packet is available for playout.
Packet(MediaPacket), Packet(MediaPacket),
/// The expected packet is missing — decoder should generate PLC. /// The expected packet is missing — decoder should generate PLC.
Missing { seq: u16 }, Missing { seq: u32 },
/// Buffer is empty or not yet filled to target depth. /// Buffer is empty or not yet filled to target depth.
NotReady, NotReady,
} }
@@ -278,9 +274,18 @@ impl JitterBuffer {
// federation room — reset instead of dropping. // federation room — reset instead of dropping.
if self.stats.packets_played > 0 && seq_before(seq, self.next_playout_seq) { if self.stats.packets_played > 0 && seq_before(seq, self.next_playout_seq) {
let backward_distance = self.next_playout_seq.wrapping_sub(seq); let backward_distance = self.next_playout_seq.wrapping_sub(seq);
tracing::warn!(seq, next = self.next_playout_seq, backward_distance, "jitter: backward seq detected"); tracing::warn!(
seq,
next = self.next_playout_seq,
backward_distance,
"jitter: backward seq detected"
);
if backward_distance > 100 { if backward_distance > 100 {
tracing::info!(seq, next = self.next_playout_seq, "jitter: RESET — new sender detected"); tracing::info!(
seq,
next = self.next_playout_seq,
"jitter: RESET — new sender detected"
);
self.buffer.clear(); self.buffer.clear();
self.next_playout_seq = seq; self.next_playout_seq = seq;
self.stats.packets_late = 0; self.stats.packets_late = 0;
@@ -428,9 +433,18 @@ impl JitterBuffer {
// federation room — reset instead of dropping. // federation room — reset instead of dropping.
if self.stats.packets_played > 0 && seq_before(seq, self.next_playout_seq) { if self.stats.packets_played > 0 && seq_before(seq, self.next_playout_seq) {
let backward_distance = self.next_playout_seq.wrapping_sub(seq); let backward_distance = self.next_playout_seq.wrapping_sub(seq);
tracing::warn!(seq, next = self.next_playout_seq, backward_distance, "jitter: backward seq detected"); tracing::warn!(
seq,
next = self.next_playout_seq,
backward_distance,
"jitter: backward seq detected"
);
if backward_distance > 100 { if backward_distance > 100 {
tracing::info!(seq, next = self.next_playout_seq, "jitter: RESET — new sender detected"); tracing::info!(
seq,
next = self.next_playout_seq,
"jitter: RESET — new sender detected"
);
self.buffer.clear(); self.buffer.clear();
self.next_playout_seq = seq; self.next_playout_seq = seq;
self.stats.packets_late = 0; self.stats.packets_late = 0;
@@ -489,7 +503,7 @@ impl JitterBuffer {
/// Sequence number comparison with wrapping (RFC 1982 serial number arithmetic). /// Sequence number comparison with wrapping (RFC 1982 serial number arithmetic).
/// Returns true if `a` comes before `b` in sequence space. /// Returns true if `a` comes before `b` in sequence space.
fn seq_before(a: u16, b: u16) -> bool { fn seq_before(a: u32, b: u32) -> bool {
let diff = b.wrapping_sub(a); let diff = b.wrapping_sub(a);
diff > 0 && diff < 0x8000 diff > 0 && diff < 0x8000
} }
@@ -497,24 +511,23 @@ fn seq_before(a: u16, b: u16) -> bool {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::CodecId;
use crate::MediaType;
use crate::packet::{MediaHeader, MediaPacket}; use crate::packet::{MediaHeader, MediaPacket};
use bytes::Bytes; use bytes::Bytes;
use crate::CodecId;
fn make_packet(seq: u16) -> MediaPacket { fn make_packet(seq: u32) -> MediaPacket {
MediaPacket { MediaPacket {
header: MediaHeader { header: MediaHeader {
version: 0, version: 2,
is_repair: false, flags: 0,
media_type: MediaType::Audio,
codec_id: CodecId::Opus24k, codec_id: CodecId::Opus24k,
has_quality_report: false, stream_id: 0,
fec_ratio_encoded: 0, fec_ratio: 0,
seq, seq,
timestamp: seq as u32 * 20, timestamp: seq * 20,
fec_block: 0, fec_block: 0,
fec_symbol: 0,
reserved: 0,
csrc_count: 0,
}, },
payload: Bytes::from(vec![0u8; 60]), payload: Bytes::from(vec![0u8; 60]),
quality_report: None, quality_report: None,
@@ -598,7 +611,7 @@ mod tests {
fn seq_before_wrapping() { fn seq_before_wrapping() {
assert!(seq_before(0, 1)); assert!(seq_before(0, 1));
assert!(seq_before(65534, 65535)); assert!(seq_before(65534, 65535));
assert!(seq_before(65535, 0)); // wrap assert!(seq_before(u32::MAX, 0)); // wrap
assert!(!seq_before(1, 0)); assert!(!seq_before(1, 0));
assert!(!seq_before(5, 5)); // equal assert!(!seq_before(5, 5)); // equal
} }
@@ -800,7 +813,7 @@ mod tests {
let mut jb = JitterBuffer::new_adaptive(3, 50); let mut jb = JitterBuffer::new_adaptive(3, 50);
// Push packets with consistent timing // Push packets with consistent timing
for i in 0u16..20 { for i in 0u32..20 {
let pkt = make_packet(i); let pkt = make_packet(i);
let arrival_ms = i as u64 * 20; let arrival_ms = i as u64 * 20;
jb.push_with_arrival(pkt, arrival_ms); jb.push_with_arrival(pkt, arrival_ms);

View File

@@ -17,21 +17,25 @@ pub mod codec_id;
pub mod dred_tuner; pub mod dred_tuner;
pub mod error; pub mod error;
pub mod jitter; pub mod jitter;
pub mod media_type;
pub mod packet; pub mod packet;
pub mod priority_mode;
pub mod quality; pub mod quality;
pub mod session; pub mod session;
pub mod traits; pub mod traits;
// Re-export key types at crate root for convenience. // Re-export key types at crate root for convenience.
pub use codec_id::{CodecId, QualityProfile};
pub use error::*;
pub use packet::{
CallAcceptMode, HangupReason, MediaHeader, MediaPacket, MiniFrameContext, MiniHeader,
PresenceUser, QualityReport, RoomParticipant, SignalMessage, TrunkEntry, TrunkFrame, FRAME_TYPE_FULL,
FRAME_TYPE_MINI,
};
pub use bandwidth::{BandwidthEstimator, CongestionState}; pub use bandwidth::{BandwidthEstimator, CongestionState};
pub use codec_id::{CodecId, QualityProfile};
pub use dred_tuner::{DredTuner, DredTuning}; pub use dred_tuner::{DredTuner, DredTuning};
pub use error::*;
pub use media_type::MediaType;
pub use packet::{
CallAcceptMode, FRAME_TYPE_FULL, FRAME_TYPE_MINI, HangupReason, MediaHeader, MediaHeaderV2,
MediaPacket, MiniFrameContext, MiniFrameContextV2, MiniHeader, MiniHeaderV2, PresenceUser,
QualityReport, RoomParticipant, SignalMessage, TrunkEntry, TrunkFrame, default_signal_version,
};
pub use priority_mode::PriorityMode;
pub use quality::{AdaptiveQualityController, NetworkContext, Tier}; pub use quality::{AdaptiveQualityController, NetworkContext, Tier};
pub use session::{Session, SessionEvent, SessionState}; pub use session::{Session, SessionEvent, SessionState};
pub use traits::*; pub use traits::*;

View File

@@ -0,0 +1,57 @@
use serde::{Deserialize, Serialize};
/// Media stream type carried in a v2 [`MediaHeaderV2`](crate::MediaHeaderV2).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[repr(u8)]
pub enum MediaType {
/// Encoded speech / music (Opus, Codec2, ComfortNoise).
Audio = 0,
/// Encoded video access unit (H.264, H.265, AV1; PRD-video-multicodec).
Video = 1,
/// Opaque payload not interpreted by the relay (reserved).
Data = 2,
/// In-band control message carried on the media plane (reserved).
Control = 3,
}
impl MediaType {
/// Encode to the wire byte representation (`self as u8`).
pub const fn to_wire(self) -> u8 {
self as u8
}
/// Decode from a wire byte. Returns `None` for values outside 0..=3.
pub const fn from_wire(v: u8) -> Option<Self> {
match v {
0 => Some(Self::Audio),
1 => Some(Self::Video),
2 => Some(Self::Data),
3 => Some(Self::Control),
_ => None,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn media_type_roundtrip() {
for mt in [
MediaType::Audio,
MediaType::Video,
MediaType::Data,
MediaType::Control,
] {
assert_eq!(MediaType::from_wire(mt.to_wire()), Some(mt));
}
}
#[test]
fn media_type_unknown_rejected() {
for v in 4u8..=255 {
assert!(MediaType::from_wire(v).is_none(), "v={v}");
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,34 @@
//! Priority mode for bandwidth allocation between audio and video.
//!
//! See `docs/PRD/PRD-video-quality-priority.md` for the full design.
use serde::{Deserialize, Serialize};
/// Bandwidth-allocation policy between audio and video.
///
/// Carried on [`QualityProfile`](crate::QualityProfile) and mutable at
/// runtime via [`SignalMessage::SetPriorityMode`](crate::SignalMessage).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum PriorityMode {
/// Audio gets its floor first; video gets the remainder.
/// Default for voice/video calls.
#[default]
AudioFirst,
/// Video gets its floor first; audio degrades to Opus 16k floor.
VideoFirst,
/// Audio clamped to 16 kbps (intelligible speech); video gets remainder.
/// Falls back to slide mode when bandwidth drops below SD floor.
ScreenShare,
/// Proportional split (~15 % audio, ~85 % video).
Balanced,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn priority_mode_default_is_audio_first() {
assert_eq!(PriorityMode::default(), PriorityMode::AudioFirst);
}
}

View File

@@ -1,11 +1,13 @@
//! See also: [`crate::dred_tuner`] for continuous DRED tuning within a tier. //! See also: [`crate::dred_tuner`] for continuous DRED tuning within a tier.
use std::collections::VecDeque; use std::collections::VecDeque;
use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use crate::BandwidthEstimator;
use crate::QualityProfile;
use crate::packet::QualityReport; use crate::packet::QualityReport;
use crate::traits::QualityController; use crate::traits::QualityController;
use crate::QualityProfile;
/// Network quality tier — drives codec and FEC selection. /// Network quality tier — drives codec and FEC selection.
/// ///
@@ -99,21 +101,16 @@ impl Tier {
} }
/// Describes the network transport type for context-aware quality decisions. /// Describes the network transport type for context-aware quality decisions.
#[derive(Clone, Copy, Debug, PartialEq, Eq)] #[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum NetworkContext { pub enum NetworkContext {
WiFi, WiFi,
CellularLte, CellularLte,
Cellular5g, Cellular5g,
Cellular3g, Cellular3g,
#[default]
Unknown, Unknown,
} }
impl Default for NetworkContext {
fn default() -> Self {
Self::Unknown
}
}
/// Adaptive quality controller with hysteresis to prevent tier flapping. /// Adaptive quality controller with hysteresis to prevent tier flapping.
/// ///
/// - Downgrade: 3 consecutive reports in a worse tier (2 on cellular) /// - Downgrade: 3 consecutive reports in a worse tier (2 on cellular)
@@ -139,6 +136,8 @@ pub struct AdaptiveQualityController {
probe: Option<ProbeState>, probe: Option<ProbeState>,
/// Time spent stable at the current tier (for probe trigger). /// Time spent stable at the current tier (for probe trigger).
stable_since: Option<Instant>, stable_since: Option<Instant>,
/// Optional bandwidth estimator for BWE-guarded upgrades.
bwe: Option<Arc<BandwidthEstimator>>,
} }
/// Threshold for downgrading (fast reaction to degradation). /// Threshold for downgrading (fast reaction to degradation).
@@ -192,6 +191,7 @@ impl AdaptiveQualityController {
fec_boost_amount: DEFAULT_FEC_BOOST, fec_boost_amount: DEFAULT_FEC_BOOST,
probe: None, probe: None,
stable_since: None, stable_since: None,
bwe: None,
} }
} }
@@ -259,6 +259,17 @@ impl AdaptiveQualityController {
self.stable_since = None; self.stable_since = None;
} }
/// Attach a bandwidth estimator for BWE-guarded tier transitions.
pub fn set_bandwidth_estimator(&mut self, bwe: Arc<BandwidthEstimator>) {
self.bwe = Some(bwe);
}
/// Return the bitrate ceiling (in bps) for a given tier, including FEC overhead.
fn tier_ceiling_bps(tier: Tier) -> u64 {
let kbps = tier.profile().total_bitrate_kbps();
(kbps * 1000.0) as u64
}
/// Get the effective downgrade threshold based on network context. /// Get the effective downgrade threshold based on network context.
fn downgrade_threshold(&self) -> u32 { fn downgrade_threshold(&self) -> u32 {
match self.network_context { match self.network_context {
@@ -301,6 +312,15 @@ impl AdaptiveQualityController {
if self.consecutive_up >= threshold { if self.consecutive_up >= threshold {
// Only upgrade one step at a time // Only upgrade one step at a time
if let Some(next_tier) = self.upgrade_one_step() { if let Some(next_tier) = self.upgrade_one_step() {
// BWE guard: require 130% headroom over target tier bitrate
if let Some(ref bwe) = self.bwe {
let required = (Self::tier_ceiling_bps(next_tier) * 130) / 100;
if bwe.target_send_bps() < required {
// Insufficient bandwidth — reset counter to prevent flapping
self.consecutive_up = 0;
return None;
}
}
self.current_tier = next_tier; self.current_tier = next_tier;
self.current_profile = next_tier.profile(); self.current_profile = next_tier.profile();
self.consecutive_up = 0; self.consecutive_up = 0;
@@ -340,8 +360,7 @@ impl AdaptiveQualityController {
if probe.bad_reports > PROBE_MAX_BAD { if probe.bad_reports > PROBE_MAX_BAD {
let _failed_probe = self.probe.take(); let _failed_probe = self.probe.take();
// Reset stable_since to trigger cooldown // Reset stable_since to trigger cooldown
self.stable_since = self.stable_since = Some(Instant::now() + Duration::from_secs(PROBE_COOLDOWN_SECS));
Some(Instant::now() + Duration::from_secs(PROBE_COOLDOWN_SECS));
return None; // stay at current tier return None; // stay at current tier
} }
@@ -535,6 +554,53 @@ mod tests {
} }
} }
#[test]
fn bwe_guard_blocks_upgrade_when_bandwidth_insufficient() {
let mut ctrl = AdaptiveQualityController::new();
// Force to catastrophic
let bad = make_report(50.0, 300);
for _ in 0..3 {
ctrl.observe(&bad);
}
assert_eq!(ctrl.tier(), Tier::Catastrophic);
// Attach a BWE with very low headroom.
// Degraded tier needs 6kbps * 1.5 FEC = 9kbps → 130% = 11.7kbps.
// Set target_send_bps ≈ 9_000 (below 11_700 threshold).
let bwe = Arc::new(BandwidthEstimator::new(1000.0, 1.0, 100_000.0));
bwe.update_from_path(1_000_000, 0, 10); // high cwnd
bwe.update_from_peer(10_000); // low remb → target = 9_000
ctrl.set_bandwidth_estimator(bwe.clone());
let good = make_report(0.5, 20);
for _ in 0..5 {
assert!(
ctrl.observe(&good).is_none(),
"upgrade should be blocked by low BWE"
);
}
assert_eq!(
ctrl.tier(),
Tier::Catastrophic,
"should remain at Catastrophic"
);
// Raise BWE well above the 130% threshold
bwe.update_from_peer(100_000); // target ≈ 90_000 bps
// Counter was reset, need another 5 good reports
for _ in 0..4 {
assert!(ctrl.observe(&good).is_none());
}
let result = ctrl.observe(&good);
assert!(
result.is_some(),
"upgrade should proceed with sufficient BWE"
);
assert_eq!(ctrl.tier(), Tier::Degraded);
}
#[test] #[test]
fn tier_classification() { fn tier_classification() {
// Studio tiers // Studio tiers
@@ -746,7 +812,10 @@ mod tests {
ctrl.observe(&degraded); // second bad — exceeds PROBE_MAX_BAD (1) ctrl.observe(&degraded); // second bad — exceeds PROBE_MAX_BAD (1)
// Probe should be cancelled // Probe should be cancelled
assert!(ctrl.probe.is_none(), "probe should be cancelled after bad reports"); assert!(
ctrl.probe.is_none(),
"probe should be cancelled after bad reports"
);
// Should still be at Studio32k (not upgraded) // Should still be at Studio32k (not upgraded)
assert_eq!(ctrl.current_tier, Tier::Studio32k); assert_eq!(ctrl.current_tier, Tier::Studio32k);
} }
@@ -775,6 +844,9 @@ mod tests {
let excellent = make_report(0.1, 10); let excellent = make_report(0.1, 10);
let result = ctrl.observe(&excellent); let result = ctrl.observe(&excellent);
assert!(result.is_none(), "should not probe when already at Studio64k"); assert!(
result.is_none(),
"should not probe when already at Studio64k"
);
} }
} }

View File

@@ -61,18 +61,34 @@ pub trait FecEncoder: Send + Sync {
/// Add a source symbol (one audio frame) to the current block. /// Add a source symbol (one audio frame) to the current block.
fn add_source_symbol(&mut self, data: &[u8]) -> Result<(), FecError>; fn add_source_symbol(&mut self, data: &[u8]) -> Result<(), FecError>;
/// Add a source symbol and mark whether it belongs to a keyframe.
///
/// When the block contains at least one keyframe source symbol,
/// [`generate_repair`] uses the configured keyframe ratio instead of the
/// nominal ratio.
///
/// Default implementation delegates to [`add_source_symbol`] and ignores
/// the keyframe flag.
fn add_source_symbol_with_keyframe(
&mut self,
data: &[u8],
_is_keyframe: bool,
) -> Result<(), FecError> {
self.add_source_symbol(data)
}
/// Generate repair symbols for the current block. /// Generate repair symbols for the current block.
/// ///
/// `ratio` is the repair overhead (e.g., 0.5 = 50% more symbols than source). /// `ratio` is the repair overhead (e.g., 0.5 = 50% more symbols than source).
/// Returns `(fec_symbol_index, repair_data)` pairs. /// Returns `(fec_symbol_index, repair_data)` pairs.
fn generate_repair(&mut self, ratio: f32) -> Result<Vec<(u8, Vec<u8>)>, FecError>; fn generate_repair(&mut self, ratio: f32) -> Result<Vec<(u16, Vec<u8>)>, FecError>;
/// Finalize the current block and start a new one. /// Finalize the current block and start a new one.
/// Returns the block ID of the finalized block. /// Returns the block ID of the finalized block.
fn finalize_block(&mut self) -> Result<u8, FecError>; fn finalize_block(&mut self) -> Result<u16, FecError>;
/// Current block ID being built. /// Current block ID being built.
fn current_block_id(&self) -> u8; fn current_block_id(&self) -> u16;
/// Number of source symbols in the current block. /// Number of source symbols in the current block.
fn current_block_size(&self) -> usize; fn current_block_size(&self) -> usize;
@@ -83,8 +99,8 @@ pub trait FecDecoder: Send + Sync {
/// Feed a received symbol (source or repair) into the decoder. /// Feed a received symbol (source or repair) into the decoder.
fn add_symbol( fn add_symbol(
&mut self, &mut self,
block_id: u8, block_id: u16,
symbol_index: u8, symbol_index: u16,
is_repair: bool, is_repair: bool,
data: &[u8], data: &[u8],
) -> Result<(), FecError>; ) -> Result<(), FecError>;
@@ -93,10 +109,10 @@ pub trait FecDecoder: Send + Sync {
/// ///
/// Returns `None` if not yet decodable (insufficient symbols). /// Returns `None` if not yet decodable (insufficient symbols).
/// Returns `Some(Vec<source_frames>)` on success. /// Returns `Some(Vec<source_frames>)` on success.
fn try_decode(&mut self, block_id: u8) -> Result<Option<Vec<Vec<u8>>>, FecError>; fn try_decode(&mut self, block_id: u16) -> Result<Option<Vec<Vec<u8>>>, FecError>;
/// Drop state for blocks older than `block_id`. /// Drop state for blocks older than `block_id`.
fn expire_before(&mut self, block_id: u8); fn expire_before(&mut self, block_id: u16);
} }
// ─── Crypto Traits ─────────────────────────────────────────────────────────── // ─── Crypto Traits ───────────────────────────────────────────────────────────

View File

@@ -7,9 +7,7 @@ fn main() {
.output(); .output();
let hash = match output { let hash = match output {
Ok(o) if o.status.success() => { Ok(o) if o.status.success() => String::from_utf8_lossy(&o.stdout).trim().to_string(),
String::from_utf8_lossy(&o.stdout).trim().to_string()
}
_ => "unknown".to_string(), _ => "unknown".to_string(),
}; };

View File

@@ -0,0 +1,467 @@
//! Tier F audio scorer — behavioural entropy detection for abuse mitigation.
//!
//! Computes a `legitimacy ∈ [0, 1]` score over a 1030 s observation window.
//! Features: IAT CoV, payload-size bimodality, silence fraction, bitrate
//! deviation, and Q-flag cadence.
use std::collections::VecDeque;
use std::time::{Duration, Instant};
use wzp_proto::{CodecId, MediaHeader, MediaType};
use crate::verdict::Verdict;
/// Maximum samples kept in rolling windows.
const MAX_IAT_SAMPLES: usize = 200;
const MAX_SIZE_SAMPLES: usize = 200;
const MAX_Q_INTERVALS: usize = 32;
/// Silence threshold: payload below this many bytes is treated as silence / CN.
const SILENCE_SIZE_THRESHOLD: usize = 16;
/// Observation window for bitrate tracking.
const BITRATE_WINDOW_SECS: u64 = 30;
// Number of payload-size histogram bins.
// (SIZE_BINS reserved for future histogram-based bimodality)
/// Audio-specific behavioural scorer (Tier F).
pub struct AudioScorer {
/// Rolling inter-arrival times.
iat_samples: VecDeque<Duration>,
last_arrival: Option<Instant>,
/// Rolling payload sizes.
size_samples: VecDeque<usize>,
/// Count of packets below silence threshold.
silence_packets: u32,
/// Total packets observed in current window.
total_packets: u32,
/// Bitrate window.
window_start: Instant,
window_bytes: u64,
/// Q-flag arrival intervals.
q_intervals: VecDeque<Duration>,
last_q_flag: Option<Instant>,
/// Codec declared at first packet (used for nominal bitrate baseline).
declared_codec: Option<CodecId>,
}
impl AudioScorer {
pub fn new() -> Self {
Self {
iat_samples: VecDeque::with_capacity(MAX_IAT_SAMPLES),
last_arrival: None,
size_samples: VecDeque::with_capacity(MAX_SIZE_SAMPLES),
silence_packets: 0,
total_packets: 0,
window_start: Instant::now(),
window_bytes: 0,
q_intervals: VecDeque::with_capacity(MAX_Q_INTERVALS),
last_q_flag: None,
declared_codec: None,
}
}
/// Feed one packet into the scorer.
pub fn observe(&mut self, header: &MediaHeader, payload_len: usize, now: Instant) {
// Ignore non-audio traffic.
if header.media_type != MediaType::Audio {
return;
}
if self.declared_codec.is_none() {
self.declared_codec = Some(header.codec_id);
}
// IAT
if let Some(last) = self.last_arrival {
let iat = now.saturating_duration_since(last);
self.iat_samples.push_back(iat);
if self.iat_samples.len() > MAX_IAT_SAMPLES {
self.iat_samples.pop_front();
}
}
self.last_arrival = Some(now);
// Payload size
self.size_samples.push_back(payload_len);
if self.size_samples.len() > MAX_SIZE_SAMPLES {
self.size_samples.pop_front();
}
// Silence fraction
self.total_packets += 1;
if payload_len <= SILENCE_SIZE_THRESHOLD {
self.silence_packets += 1;
}
// Bitrate window
if now.duration_since(self.window_start) >= Duration::from_secs(BITRATE_WINDOW_SECS) {
self.window_start = now;
self.window_bytes = 0;
}
self.window_bytes += (MediaHeader::WIRE_SIZE + payload_len) as u64;
// Q-flag cadence
if header.has_quality() {
if let Some(last) = self.last_q_flag {
let interval = now.saturating_duration_since(last);
self.q_intervals.push_back(interval);
if self.q_intervals.len() > MAX_Q_INTERVALS {
self.q_intervals.pop_front();
}
}
self.last_q_flag = Some(now);
}
}
/// Compute legitimacy score ∈ [0, 1].
///
/// Higher = more legitimate. Returns `None` when insufficient samples
/// have been collected (< 20 packets).
pub fn legitimacy(&self) -> Option<f32> {
if self.total_packets < 20 {
return None;
}
let mut score = 1.0f32;
// 1. IAT CoV penalty
if let Some(cov) = self.iat_cov() {
if cov > 0.4 {
let penalty = ((cov - 0.4) / 0.6).min(1.0) * 0.25;
score -= penalty as f32;
}
}
// 2. Silence fraction penalty
let silence_fraction = self.silence_fraction();
if silence_fraction < 0.02 {
let penalty = ((0.02 - silence_fraction) / 0.02).min(1.0) * 0.25;
score -= penalty as f32;
} else if silence_fraction > 0.60 {
// Too much silence can also be suspicious (stuffed payloads)
let penalty = ((silence_fraction - 0.60) / 0.40).min(1.0) * 0.15;
score -= penalty as f32;
}
// 3. Bitrate deviation penalty
if let Some(ratio) = self.bitrate_ratio() {
if ratio > 1.20 {
let penalty = ((ratio - 1.20) / 0.80).min(1.0) * 0.25;
score -= penalty as f32;
}
}
// 4. Q-flag cadence penalty
if let Some(cv) = self.q_flag_cv() {
// High variability in Q-flag spacing = suspicious
if cv > 0.5 {
let penalty = ((cv - 0.5) / 0.5).min(1.0) * 0.15;
score -= penalty as f32;
}
} else {
// No Q flags seen at all — mildly suspicious after many packets
if self.total_packets > 100 {
score -= 0.10;
}
}
// 5. Payload-size bimodality bonus/penalty
if let Some(bimodality) = self.size_bimodality() {
// Bimodality score: 0 = unimodal, 1 = strongly bimodal
// Legitimate audio is bimodal (speech + silence)
if bimodality < 0.2 {
score -= 0.10;
}
}
Some(score.clamp(0.0, 1.0))
}
/// Map legitimacy score to a [`Verdict`].
pub fn verdict(&self) -> Option<Verdict> {
self.legitimacy().map(|s| {
if s >= 0.7 {
Verdict::Legitimate
} else if s >= 0.3 {
Verdict::Suspect
} else {
Verdict::Abusive
}
})
}
// ------------------------------------------------------------------
// Feature extractors
// ------------------------------------------------------------------
/// Coefficient of variation of inter-arrival times.
fn iat_cov(&self) -> Option<f64> {
if self.iat_samples.len() < 10 {
return None;
}
let mean = self
.iat_samples
.iter()
.map(|d| d.as_secs_f64())
.sum::<f64>()
/ self.iat_samples.len() as f64;
if mean == 0.0 {
return None;
}
let variance = self
.iat_samples
.iter()
.map(|d| {
let diff = d.as_secs_f64() - mean;
diff * diff
})
.sum::<f64>()
/ self.iat_samples.len() as f64;
let std = variance.sqrt();
Some(std / mean)
}
/// Fraction of packets that are silence / comfort-noise sized.
fn silence_fraction(&self) -> f64 {
if self.total_packets == 0 {
return 0.0;
}
self.silence_packets as f64 / self.total_packets as f64
}
/// Ratio of observed bitrate to nominal bitrate over the 30 s window.
fn bitrate_ratio(&self) -> Option<f64> {
let codec = self.declared_codec?;
let nominal_bps = codec.bitrate_bps() as f64;
if nominal_bps == 0.0 {
return None;
}
let observed_bps = self.window_bytes as f64 * 8.0 / BITRATE_WINDOW_SECS as f64;
Some(observed_bps / nominal_bps)
}
/// Coefficient of variation of Q-flag intervals.
fn q_flag_cv(&self) -> Option<f64> {
if self.q_intervals.len() < 3 {
return None;
}
let mean = self
.q_intervals
.iter()
.map(|d| d.as_secs_f64())
.sum::<f64>()
/ self.q_intervals.len() as f64;
if mean == 0.0 {
return None;
}
let variance = self
.q_intervals
.iter()
.map(|d| {
let diff = d.as_secs_f64() - mean;
diff * diff
})
.sum::<f64>()
/ self.q_intervals.len() as f64;
let std = variance.sqrt();
Some(std / mean)
}
/// Simple bimodality score based on a 2-bin histogram.
///
/// Splits payload sizes into "small" (≤ threshold) and "large" bins.
/// Returns a score in [0, 1] where 1 = strongly bimodal.
fn size_bimodality(&self) -> Option<f64> {
if self.size_samples.len() < 20 {
return None;
}
let small = self
.size_samples
.iter()
.filter(|&&s| s <= SILENCE_SIZE_THRESHOLD)
.count();
let large = self.size_samples.len() - small;
let total = self.size_samples.len() as f64;
let p_small = small as f64 / total;
let _p_large = large as f64 / total;
// Max bimodality when both bins are equally populated (~0.5 each)
let bimodality = 1.0 - (p_small - 0.5).abs() * 2.0;
Some(bimodality)
}
}
impl Default for AudioScorer {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
fn audio_header(payload_len: usize, has_quality: bool) -> MediaHeader {
MediaHeader {
version: 2,
flags: if has_quality { 0x40 } else { 0 },
media_type: MediaType::Audio,
codec_id: CodecId::Opus24k,
stream_id: 0,
fec_ratio: 0,
seq: 0,
timestamp: 0,
fec_block: 0,
}
}
#[test]
fn audio_scorer_ignores_video() {
let mut scorer = AudioScorer::new();
let mut h = audio_header(100, false);
h.media_type = MediaType::Video;
scorer.observe(&h, 100, Instant::now());
assert_eq!(scorer.total_packets, 0);
}
#[test]
fn audio_scorer_counts_packets() {
let mut scorer = AudioScorer::new();
for i in 0..25 {
let h = audio_header(100, false);
scorer.observe(&h, 100, Instant::now() + Duration::from_millis(i * 20));
}
assert_eq!(scorer.total_packets, 25);
assert!(scorer.legitimacy().is_some());
}
#[test]
fn audio_scorer_legitimate_traffic() {
let mut scorer = AudioScorer::new();
let base = Instant::now();
// Simulate 200 packets of legitimate audio:
// ~20 ms IAT, mixed speech (100 B) and silence (8 B), periodic Q flags.
for i in 0..200 {
let payload = if i % 3 == 0 { 8 } else { 100 };
let has_q = i % 10 == 0;
let h = audio_header(payload, has_q);
scorer.observe(&h, payload, base + Duration::from_millis(i * 20));
}
let leg = scorer.legitimacy().unwrap();
assert!(
leg >= 0.7,
"legitimate traffic should score ≥ 0.7, got {leg}"
);
assert_eq!(scorer.verdict(), Some(Verdict::Legitimate));
}
#[test]
fn audio_scorer_abusive_uniform_iat() {
let mut scorer = AudioScorer::new();
let base = Instant::now();
// Uniform IAT (no jitter), all same size, no Q flags — tunnel-like
for i in 0..200 {
let h = audio_header(200, false);
scorer.observe(&h, 200, base + Duration::from_millis(i * 20));
}
let leg = scorer.legitimacy().unwrap();
assert!(
leg < 0.6,
"uniform tunnel-like traffic should score < 0.6, got {leg}"
);
}
#[test]
fn audio_scorer_abusive_no_silence() {
let mut scorer = AudioScorer::new();
let base = Instant::now();
// No silence packets at all, very regular IAT
for i in 0..200 {
let h = audio_header(150, false);
scorer.observe(&h, 150, base + Duration::from_millis(i * 20));
}
let leg = scorer.legitimacy().unwrap();
assert!(
leg < 0.6,
"no-silence traffic should score < 0.6, got {leg}"
);
}
#[test]
fn audio_scorer_insufficient_samples() {
let scorer = AudioScorer::new();
assert_eq!(scorer.legitimacy(), None);
assert_eq!(scorer.verdict(), None);
}
#[test]
fn silence_fraction_computed_correctly() {
let mut scorer = AudioScorer::new();
let base = Instant::now();
for i in 0..100 {
let payload = if i < 30 { 8 } else { 100 };
let h = audio_header(payload, false);
scorer.observe(&h, payload, base + Duration::from_millis(i * 20));
}
assert!((scorer.silence_fraction() - 0.30).abs() < 0.01);
}
#[test]
fn bitrate_ratio_saturates_when_no_codec() {
let scorer = AudioScorer::new();
assert_eq!(scorer.bitrate_ratio(), None);
}
#[test]
fn q_flag_cv_regular_spacing() {
let mut scorer = AudioScorer::new();
let base = Instant::now();
for i in 0..50 {
let has_q = i % 5 == 0;
let h = audio_header(100, has_q);
scorer.observe(&h, 100, base + Duration::from_millis(i * 20));
}
let cv = scorer.q_flag_cv().unwrap();
assert!(
cv < 0.1,
"regular Q-flag spacing should have CV < 0.1, got {cv}"
);
}
#[test]
fn size_bimodality_for_mixed_traffic() {
let mut scorer = AudioScorer::new();
let base = Instant::now();
for i in 0..100 {
let payload = if i % 2 == 0 { 8 } else { 120 };
let h = audio_header(payload, false);
scorer.observe(&h, payload, base + Duration::from_millis(i * 20));
}
let bim = scorer.size_bimodality().unwrap();
assert!(
bim > 0.8,
"perfectly mixed small/large should be highly bimodal, got {bim}"
);
}
#[test]
fn size_bimodality_for_uniform_traffic() {
let mut scorer = AudioScorer::new();
let base = Instant::now();
for i in 0..100 {
let h = audio_header(100, false);
scorer.observe(&h, 100, base + Duration::from_millis(i * 20));
}
let bim = scorer.size_bimodality().unwrap();
assert!(
bim < 0.3,
"uniform size traffic should be unimodal, got {bim}"
);
}
}

View File

@@ -32,10 +32,7 @@ pub struct AuthenticatedClient {
/// ///
/// Calls `POST {auth_url}` with `{ "token": "..." }`. /// Calls `POST {auth_url}` with `{ "token": "..." }`.
/// Returns the client identity if valid, or an error string. /// Returns the client identity if valid, or an error string.
pub async fn validate_token( pub async fn validate_token(auth_url: &str, token: &str) -> Result<AuthenticatedClient, String> {
auth_url: &str,
token: &str,
) -> Result<AuthenticatedClient, String> {
let client = reqwest::Client::builder() let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(5)) .timeout(std::time::Duration::from_secs(5))
.build() .build()

View File

@@ -83,7 +83,12 @@ impl CallRegistry {
} }
/// Create a new pending call. Returns the call_id. /// Create a new pending call. Returns the call_id.
pub fn create_call(&mut self, call_id: String, caller_fp: String, callee_fp: String) -> &DirectCall { pub fn create_call(
&mut self,
call_id: String,
caller_fp: String,
callee_fp: String,
) -> &DirectCall {
let call = DirectCall { let call = DirectCall {
call_id: call_id.clone(), call_id: call_id.clone(),
caller_fingerprint: caller_fp, caller_fingerprint: caller_fp,
@@ -189,7 +194,12 @@ impl CallRegistry {
} }
/// Transition to Active state. /// Transition to Active state.
pub fn set_active(&mut self, call_id: &str, mode: wzp_proto::CallAcceptMode, room: String) -> bool { pub fn set_active(
&mut self,
call_id: &str,
mode: wzp_proto::CallAcceptMode,
room: String,
) -> bool {
if let Some(call) = self.calls.get_mut(call_id) { if let Some(call) = self.calls.get_mut(call_id) {
if call.state == DirectCallState::Pending || call.state == DirectCallState::Ringing { if call.state == DirectCallState::Pending || call.state == DirectCallState::Ringing {
call.state = DirectCallState::Active; call.state = DirectCallState::Active;
@@ -213,7 +223,8 @@ impl CallRegistry {
/// Find active/pending calls involving a fingerprint. /// Find active/pending calls involving a fingerprint.
pub fn calls_for_fingerprint(&self, fp: &str) -> Vec<&DirectCall> { pub fn calls_for_fingerprint(&self, fp: &str) -> Vec<&DirectCall> {
self.calls.values() self.calls
.values()
.filter(|c| { .filter(|c| {
c.state != DirectCallState::Ended c.state != DirectCallState::Ended
&& (c.caller_fingerprint == fp || c.callee_fingerprint == fp) && (c.caller_fingerprint == fp || c.callee_fingerprint == fp)
@@ -236,22 +247,25 @@ impl CallRegistry {
/// Returns call IDs of expired calls. /// Returns call IDs of expired calls.
pub fn expire_stale(&mut self, timeout: Duration) -> Vec<DirectCall> { pub fn expire_stale(&mut self, timeout: Duration) -> Vec<DirectCall> {
let now = Instant::now(); let now = Instant::now();
let expired: Vec<String> = self.calls.iter() let expired: Vec<String> = self
.calls
.iter()
.filter(|(_, c)| { .filter(|(_, c)| {
c.state == DirectCallState::Pending c.state == DirectCallState::Pending && now.duration_since(c.created_at) > timeout
&& now.duration_since(c.created_at) > timeout
}) })
.map(|(id, _)| id.clone()) .map(|(id, _)| id.clone())
.collect(); .collect();
expired.into_iter() expired
.into_iter()
.filter_map(|id| self.calls.remove(&id)) .filter_map(|id| self.calls.remove(&id))
.collect() .collect()
} }
/// Number of active (non-ended) calls. /// Number of active (non-ended) calls.
pub fn active_count(&self) -> usize { pub fn active_count(&self) -> usize {
self.calls.values() self.calls
.values()
.filter(|c| c.state != DirectCallState::Ended) .filter(|c| c.state != DirectCallState::Ended)
.count() .count()
} }
@@ -270,9 +284,16 @@ mod tests {
assert!(reg.set_ringing("c1")); assert!(reg.set_ringing("c1"));
assert_eq!(reg.get("c1").unwrap().state, DirectCallState::Ringing); assert_eq!(reg.get("c1").unwrap().state, DirectCallState::Ringing);
assert!(reg.set_active("c1", wzp_proto::CallAcceptMode::AcceptGeneric, "_call:c1".into())); assert!(reg.set_active(
"c1",
wzp_proto::CallAcceptMode::AcceptGeneric,
"_call:c1".into()
));
assert_eq!(reg.get("c1").unwrap().state, DirectCallState::Active); assert_eq!(reg.get("c1").unwrap().state, DirectCallState::Active);
assert_eq!(reg.get("c1").unwrap().room_name.as_deref(), Some("_call:c1")); assert_eq!(
reg.get("c1").unwrap().room_name.as_deref(),
Some("_call:c1")
);
let ended = reg.end_call("c1").unwrap(); let ended = reg.end_call("c1").unwrap();
assert_eq!(ended.state, DirectCallState::Ended); assert_eq!(ended.state, DirectCallState::Ended);
@@ -329,10 +350,7 @@ mod tests {
// Both addrs are independently readable — the relay uses // Both addrs are independently readable — the relay uses
// them to cross-wire peer_direct_addr in CallSetup. // them to cross-wire peer_direct_addr in CallSetup.
let c = reg.get("c1").unwrap(); let c = reg.get("c1").unwrap();
assert_eq!( assert_eq!(c.caller_reflexive_addr.as_deref(), Some("192.0.2.1:4433"));
c.caller_reflexive_addr.as_deref(),
Some("192.0.2.1:4433")
);
assert_eq!( assert_eq!(
c.callee_reflexive_addr.as_deref(), c.callee_reflexive_addr.as_deref(),
Some("198.51.100.9:4433") Some("198.51.100.9:4433")

View File

@@ -145,7 +145,10 @@ pub struct RelayInfo {
} }
/// Load config from path, or create a personalized example config if it doesn't exist. /// Load config from path, or create a personalized example config if it doesn't exist.
pub fn load_or_create_config(path: &str, info: Option<&RelayInfo>) -> Result<RelayConfig, anyhow::Error> { pub fn load_or_create_config(
path: &str,
info: Option<&RelayInfo>,
) -> Result<RelayConfig, anyhow::Error> {
let p = std::path::Path::new(path); let p = std::path::Path::new(path);
if p.exists() { if p.exists() {
return load_config(path); return load_config(path);
@@ -164,7 +167,9 @@ pub fn load_or_create_config(path: &str, info: Option<&RelayInfo>) -> Result<Rel
/// Generate an example TOML config, personalized with this relay's info if available. /// Generate an example TOML config, personalized with this relay's info if available.
fn generate_example_config(info: Option<&RelayInfo>) -> String { fn generate_example_config(info: Option<&RelayInfo>) -> String {
let listen = info.map(|i| i.listen_addr.as_str()).unwrap_or("0.0.0.0:4433"); let listen = info
.map(|i| i.listen_addr.as_str())
.unwrap_or("0.0.0.0:4433");
let peer_example = if let Some(i) = info { let peer_example = if let Some(i) = info {
let ip = i.public_ip.as_deref().unwrap_or("this-relay-ip"); let ip = i.public_ip.as_deref().unwrap_or("this-relay-ip");
format!( format!(

View File

@@ -0,0 +1,544 @@
//! Relay conformance metering — Tier A/B/C/D/E enforcement.
//!
//! Each participant gets a [`ConformanceMeter`] that tracks per-second
//! traffic against the declared codec's nominal bitrate ceiling.
//! Violations are logged and counted but do **not** drop packets
//! (observe-only mode).
use std::collections::VecDeque;
use std::time::{Duration, Instant};
use wzp_proto::{CodecId, MediaHeader};
/// Rolling window size for timestamp-drift detection (Tier C).
const DRIFT_WINDOW_SIZE: usize = 200;
/// Kinds of conformance violation detected by the relay.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Violation {
/// Cumulative bitrate in the current 1 s window exceeds the Tier A ceiling.
BitrateExceeded,
/// Packet rate exceeds the per-codec safety limit (Tier B).
PacketRateExceeded,
/// Timestamp jumped backwards or forwards suspiciously (Tier C).
TimestampDrift,
/// Sustained payload size exceeds 2× the typical bound for the declared codec (Tier D).
PayloadSizeExceeded,
/// Per-session token-bucket rate cap exceeded (Tier E).
RateCapExceeded,
}
/// Error type returned when a [`TokenBucket`] does not hold enough tokens.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TokenExhausted;
/// Simple token bucket for per-session rate capping (Tier E).
///
/// Tokens represent bytes. The bucket refills at `refill_per_sec` bytes per
/// second, up to `capacity`. A packet is allowed only if the bucket holds
/// enough tokens for its size.
pub struct TokenBucket {
capacity: u64,
tokens: f64,
refill_per_sec: u64,
last_refill: Instant,
}
impl TokenBucket {
/// Create a new bucket with the given byte capacity and refill rate.
pub fn new(capacity: u64, refill_per_sec: u64) -> Self {
Self {
capacity,
tokens: capacity as f64,
refill_per_sec,
last_refill: Instant::now(),
}
}
/// Per-session audio cap: 256 kbps with 30 s @ 2× burst.
/// Capacity = 30 s × 64 KB/s = 1_920_000 bytes.
pub fn for_audio_session() -> Self {
let refill_per_sec = 256_000 / 8; // 32_000 bytes/sec
let capacity = refill_per_sec * 30 * 2; // 1_920_000 bytes
Self::new(capacity, refill_per_sec)
}
/// Attempt to consume `bytes` from the bucket.
///
/// Refills based on elapsed time since the last call, then deducts the
/// cost. Returns `Ok(())` if enough tokens were available,
/// `Err(TokenExhausted)` otherwise.
pub fn try_consume(&mut self, bytes: u64, now: Instant) -> Result<(), TokenExhausted> {
let elapsed = now.duration_since(self.last_refill);
self.last_refill = now;
self.tokens += elapsed.as_secs_f64() * self.refill_per_sec as f64;
if self.tokens > self.capacity as f64 {
self.tokens = self.capacity as f64;
}
if self.tokens >= bytes as f64 {
self.tokens -= bytes as f64;
Ok(())
} else {
Err(TokenExhausted)
}
}
}
/// Per-participant traffic conformance meter.
pub struct ConformanceMeter {
window_start: Instant,
bytes_in_window: u64,
packets_in_window: u64,
/// Rolling (seq, timestamp) pairs for drift detection.
drift_window: VecDeque<(u32, u32)>,
/// EWMA of payload size for Tier D sanity checks.
ewma_payload_size: f64,
/// Optional token bucket for Tier E per-session rate cap.
token_bucket: Option<TokenBucket>,
}
impl ConformanceMeter {
pub fn new() -> Self {
Self {
window_start: Instant::now(),
bytes_in_window: 0,
packets_in_window: 0,
drift_window: VecDeque::with_capacity(DRIFT_WINDOW_SIZE),
ewma_payload_size: 0.0,
token_bucket: None,
}
}
/// Create a meter with a Tier E token bucket for per-session rate capping.
pub fn with_token_bucket(bucket: TokenBucket) -> Self {
let mut meter = Self::new();
meter.token_bucket = Some(bucket);
meter
}
/// Inspect an incoming media packet and accumulate it against the
/// current 1-second window. Returns [`Err(Violation)`] when a limit
/// is crossed.
pub fn observe(
&mut self,
header: &MediaHeader,
payload_len: usize,
now: Instant,
) -> Result<(), Violation> {
// Roll the window forward if a second has elapsed.
if now.duration_since(self.window_start) >= Duration::from_secs(1) {
self.window_start = now;
self.bytes_in_window = 0;
self.packets_in_window = 0;
}
let packet_size = (MediaHeader::WIRE_SIZE + payload_len) as u64;
self.bytes_in_window += packet_size;
self.packets_in_window += 1;
// Tier A — bitrate ceiling.
let ceiling = ceiling_bps(header.codec_id);
let max_bytes_per_sec = ceiling / 8;
if self.bytes_in_window > max_bytes_per_sec {
return Err(Violation::BitrateExceeded);
}
// Tier B — packet-rate ceiling.
let max_pps = max_pps(header.codec_id);
let pps_threshold = (max_pps as f32 * 1.5) as u64;
if self.packets_in_window > pps_threshold {
return Err(Violation::PacketRateExceeded);
}
// Tier C — timestamp drift.
self.drift_window.push_back((header.seq, header.timestamp));
if self.drift_window.len() > DRIFT_WINDOW_SIZE {
self.drift_window.pop_front();
}
if self.drift_window.len() >= 2 {
let (first_seq, first_ts) = self.drift_window.front().copied().unwrap();
let (last_seq, last_ts) = self.drift_window.back().copied().unwrap();
let ds = last_seq.wrapping_sub(first_seq) as f64;
let dt = last_ts.wrapping_sub(first_ts) as f64;
if ds > 0.0 {
let avg_ms_per_packet = dt / ds;
let frame_ms = header.codec_id.frame_duration_ms() as f64;
let min_ratio = frame_ms * 0.5;
let max_ratio = frame_ms * 2.0;
if avg_ms_per_packet < min_ratio || avg_ms_per_packet > max_ratio {
return Err(Violation::TimestampDrift);
}
}
}
// Tier D — payload-size sanity (EWMA).
let alpha = 0.05; // ~20-packet smoothing
self.ewma_payload_size =
alpha * payload_len as f64 + (1.0 - alpha) * self.ewma_payload_size;
let bound = payload_size_bound(header.codec_id);
if self.ewma_payload_size > (bound * 2) as f64 {
return Err(Violation::PayloadSizeExceeded);
}
// Tier E — per-session token-bucket rate cap.
if let Some(ref mut bucket) = self.token_bucket {
let packet_size = (MediaHeader::WIRE_SIZE + payload_len) as u64;
if bucket.try_consume(packet_size, now).is_err() {
return Err(Violation::RateCapExceeded);
}
}
Ok(())
}
}
impl Default for ConformanceMeter {
fn default() -> Self {
Self::new()
}
}
/// Compute the Tier A bitrate ceiling for a given codec.
///
/// Formula:
/// nominal_bitrate * 3 (FEC 2.0 overhead) * 115 / 100 (15% safety margin)
/// with a floor of 2 kbps.
pub fn ceiling_bps(codec: CodecId) -> u64 {
let nominal = codec.bitrate_bps() as u64;
(nominal * 3 * 115 / 100).max(2_000)
}
/// Compute the Tier B packet-rate ceiling for a given codec.
///
/// Formula:
/// 1000 / frame_duration_ms * 3 (FEC overhead factor)
pub fn max_pps(codec: CodecId) -> u32 {
let fd = codec.frame_duration_ms() as u32;
if fd == 0 {
return 0;
}
(1000 / fd) * 3
}
/// Typical per-codec payload size bound in bytes (Tier D).
///
/// These are empirical upper bounds for a single audio frame at the codec's
/// nominal configuration. The EWMA must not exceed 2× this value.
pub fn payload_size_bound(codec: CodecId) -> usize {
match codec {
CodecId::Opus64k => 320,
CodecId::Opus48k => 240,
CodecId::Opus32k => 200,
CodecId::Opus24k => 160,
CodecId::Opus16k => 100,
CodecId::Opus6k => 90,
CodecId::Codec2_3200 => 30,
CodecId::Codec2_1200 => 30,
CodecId::ComfortNoise => 16,
CodecId::H264Baseline | CodecId::H265Main | CodecId::Av1Main => 1400,
}
}
#[cfg(test)]
mod tests {
use super::*;
use wzp_proto::MediaType;
fn make_header(codec_id: CodecId) -> MediaHeader {
MediaHeader {
version: 2,
flags: 0,
media_type: MediaType::Audio,
codec_id,
seq: 0,
timestamp: 0,
fec_block: 0,
stream_id: 0,
fec_ratio: 0,
}
}
fn make_header_with_seq_ts(codec_id: CodecId, seq: u32, timestamp: u32) -> MediaHeader {
MediaHeader {
version: 2,
flags: 0,
media_type: MediaType::Audio,
codec_id,
seq,
timestamp,
fec_block: 0,
stream_id: 0,
fec_ratio: 0,
}
}
#[test]
fn bitrate_exceeded_for_opus24k() {
let mut meter = ConformanceMeter::new();
let header = make_header(CodecId::Opus24k);
// Ceiling for Opus24k = 24_000 * 3 * 115 / 100 = 82_800 bps
// = 10_350 bytes/sec. 1 MB/s = 125_000 bytes/packet will blow past
// that in a single packet.
let now = Instant::now();
let result = meter.observe(&header, 1_000_000, now);
assert_eq!(result, Err(Violation::BitrateExceeded));
}
#[test]
fn small_packets_stay_within_ceiling() {
let mut meter = ConformanceMeter::new();
let header = make_header(CodecId::Opus24k);
// Ceiling = 82_800 bps = 10_350 bytes/sec.
// Each packet = 16-byte header + 80 bytes = 96 bytes.
// 100 packets = 9_600 bytes < 10_350.
let now = Instant::now();
for _ in 0..100 {
assert!(meter.observe(&header, 80, now).is_ok());
}
}
#[test]
fn window_resets_after_one_second() {
let mut meter = ConformanceMeter::new();
let header = make_header(CodecId::Opus24k);
// Fill the window to just under the limit.
// Use 300-byte payloads (under Tier D 2× bound of 320 for Opus24k).
let t0 = Instant::now();
for _ in 0..32 {
assert!(meter.observe(&header, 300, t0).is_ok());
}
// 32 * (header wire size + 300) ≈ 32 * 316 = 10_112 bytes < 10_350
// Same packets 1.1 seconds later should be fine because the window
// rolls over.
let t1 = t0 + Duration::from_millis(1_100);
for _ in 0..32 {
assert!(meter.observe(&header, 300, t1).is_ok());
}
}
#[test]
fn ceiling_bps_floor() {
// ComfortNoise has 0 nominal bitrate, so the floor kicks in.
assert_eq!(ceiling_bps(CodecId::ComfortNoise), 2_000);
}
// ------------------------------------------------------------------
// Tier B — packet rate
// ------------------------------------------------------------------
#[test]
fn packet_rate_exceeded() {
let mut meter = ConformanceMeter::new();
// Opus24k: max_pps = 1000/20 * 3 = 150. Threshold = 150 * 1.5 = 225.
let header = make_header(CodecId::Opus24k);
let now = Instant::now();
for _ in 0..225 {
assert!(meter.observe(&header, 10, now).is_ok());
}
// 226th packet should trip the limit.
assert_eq!(
meter.observe(&header, 10, now),
Err(Violation::PacketRateExceeded)
);
}
#[test]
fn packet_rate_within_limit() {
let mut meter = ConformanceMeter::new();
// Opus6k: max_pps = 1000/40 * 3 = 75. Threshold = 75 * 1.5 = 112.
// Use 0-byte payload so bitrate ceiling (2_587 bytes/sec) is not the
// limiting factor. 112 packets × 16 bytes = 1_792 bytes < 2_587.
let header = make_header(CodecId::Opus6k);
let now = Instant::now();
for _ in 0..112 {
assert!(meter.observe(&header, 0, now).is_ok());
}
}
// ------------------------------------------------------------------
// Tier C — timestamp drift
// ------------------------------------------------------------------
#[test]
fn timestamp_drift_detected_when_too_fast() {
let mut meter = ConformanceMeter::new();
// Opus24k frame_duration = 20 ms.
// Acceptable range: [10, 40] ms per packet.
// Send packets with timestamp advancing by 5 ms each (too fast).
let now = Instant::now();
let mut drift_seen = false;
for i in 0..200 {
let header = make_header_with_seq_ts(CodecId::Opus24k, i, i * 5);
match meter.observe(&header, 10, now) {
Ok(()) => {}
Err(Violation::TimestampDrift) => drift_seen = true,
Err(other) => panic!("unexpected violation: {other:?}"),
}
}
assert!(drift_seen, "expected TimestampDrift to be detected");
}
#[test]
fn timestamp_drift_detected_when_too_slow() {
let mut meter = ConformanceMeter::new();
// Opus24k frame_duration = 20 ms.
// Acceptable range: [10, 40] ms per packet.
// Send packets with timestamp advancing by 50 ms each (too slow).
let now = Instant::now();
let mut drift_seen = false;
for i in 0..200 {
let header = make_header_with_seq_ts(CodecId::Opus24k, i, i * 50);
match meter.observe(&header, 10, now) {
Ok(()) => {}
Err(Violation::TimestampDrift) => drift_seen = true,
Err(other) => panic!("unexpected violation: {other:?}"),
}
}
assert!(drift_seen, "expected TimestampDrift to be detected");
}
#[test]
fn timestamp_normal_no_drift() {
let mut meter = ConformanceMeter::new();
// Opus24k frame_duration = 20 ms.
// Send 200 packets with timestamp advancing by exactly 20 ms each.
let now = Instant::now();
for i in 0..200 {
let header = make_header_with_seq_ts(CodecId::Opus24k, i, i * 20);
assert!(meter.observe(&header, 10, now).is_ok());
}
}
#[test]
fn timestamp_drift_not_checked_before_two_packets() {
let mut meter = ConformanceMeter::new();
let now = Instant::now();
// Single packet with wild timestamp — should not trigger drift.
let header = make_header_with_seq_ts(CodecId::Opus24k, 0, 999_999);
assert!(meter.observe(&header, 10, now).is_ok());
}
// ------------------------------------------------------------------
// Tier D — payload-size sanity
// ------------------------------------------------------------------
#[test]
fn conformance_tier_d() {
let mut meter = ConformanceMeter::new();
let header = make_header(CodecId::Codec2_1200);
let now = Instant::now();
// Codec2_1200 bound = 30 bytes. 2× bound = 60 bytes.
// Feed 1400-byte payloads — EWMA should cross 60 within a few packets.
let mut flagged = false;
for _ in 0..200 {
if meter.observe(&header, 1400, now).is_err() {
flagged = true;
break;
}
}
assert!(
flagged,
"expected PayloadSizeExceeded for 1400-byte Codec2_1200 payloads"
);
}
#[test]
fn payload_size_normal_stays_within_bound() {
let mut meter = ConformanceMeter::new();
let header = make_header(CodecId::Opus24k);
let now = Instant::now();
// Opus24k bound = 160 bytes. 2× bound = 320 bytes.
// Feed 150-byte payloads — well within the 2× limit.
// Limit to 10 packets so the 1-second bitrate window (10_350 bytes)
// is not exhausted: 10 * (16 + 150) = 1_660 < 10_350.
for _ in 0..10 {
assert!(
meter.observe(&header, 150, now).is_ok(),
"150-byte Opus24k payloads should stay within Tier D limit"
);
}
}
// ------------------------------------------------------------------
// Tier E — token-bucket rate cap
// ------------------------------------------------------------------
#[test]
fn token_bucket_small_burst_ok() {
let mut bucket = TokenBucket::new(100_000, 32_000);
let now = Instant::now();
// 50 KB burst fits inside 100 KB capacity.
assert!(bucket.try_consume(50_000, now).is_ok());
}
#[test]
fn token_bucket_large_burst_fails() {
let mut bucket = TokenBucket::new(100_000, 32_000);
let now = Instant::now();
// 1 MB exceeds 100 KB capacity.
assert!(bucket.try_consume(1_000_000, now).is_err());
}
#[test]
fn token_bucket_refills_over_time() {
let mut bucket = TokenBucket::new(100_000, 32_000);
let t0 = Instant::now();
// Drain the bucket.
assert!(bucket.try_consume(100_000, t0).is_ok());
// Immediately try again — should fail.
assert!(bucket.try_consume(10_000, t0).is_err());
// Wait 1 second — bucket refills 32_000 bytes.
let t1 = t0 + Duration::from_secs(1);
assert!(bucket.try_consume(30_000, t1).is_ok());
// 40_000 is more than the 32_000 refilled.
assert!(bucket.try_consume(40_000, t1).is_err());
}
#[test]
fn token_bucket_sustained_rate_balanced() {
let mut bucket = TokenBucket::new(1_000_000, 32_000);
let t0 = Instant::now();
// Send 32 KB every second for 5 seconds — exactly at refill rate.
// The bucket should never empty because each second it refills
// exactly what was consumed.
for i in 0..5 {
let t = t0 + Duration::from_secs(i);
assert!(
bucket.try_consume(32_000, t).is_ok(),
"32 KB/s sustained should stay within bucket limit"
);
}
}
#[test]
fn conformance_tier_e_integration() {
// Use Opus64k (high bitrate ceiling + high payload bound) so Tiers
// A/B/D never fire on the small bursts used here. Only Tier E.
let mut meter = ConformanceMeter::with_token_bucket(TokenBucket::new(1_000, 500));
let header = make_header(CodecId::Opus64k);
let now = Instant::now();
// Two 500-byte (wire) packets = 1_000 bytes — exactly the bucket cap.
assert!(
meter
.observe(&header, 500 - MediaHeader::WIRE_SIZE, now)
.is_ok()
);
assert!(
meter
.observe(&header, 500 - MediaHeader::WIRE_SIZE, now)
.is_ok()
);
// Third packet exceeds the 1_000-byte cap.
let result = meter.observe(&header, 10, now);
assert_eq!(result, Err(Violation::RateCapExceeded));
}
}

View File

@@ -25,16 +25,13 @@ pub struct Event {
pub src: Option<String>, pub src: Option<String>,
/// Packet sequence number. /// Packet sequence number.
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub seq: Option<u16>, pub seq: Option<u32>,
/// Codec identifier. /// Codec identifier.
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub codec: Option<String>, pub codec: Option<String>,
/// FEC block ID. /// FEC block ID (low byte) and symbol index (high byte).
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub fec_block: Option<u8>, pub fec_block: Option<u16>,
/// FEC symbol index.
#[serde(skip_serializing_if = "Option::is_none")]
pub fec_sym: Option<u8>,
/// Is FEC repair packet. /// Is FEC repair packet.
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub repair: Option<bool>, pub repair: Option<bool>,
@@ -60,7 +57,9 @@ pub struct Event {
impl Event { impl Event {
fn now() -> String { fn now() -> String {
chrono::Utc::now().format("%Y-%m-%dT%H:%M:%S%.6fZ").to_string() chrono::Utc::now()
.format("%Y-%m-%dT%H:%M:%S%.6fZ")
.to_string()
} }
/// Create a minimal event with just type and timestamp. /// Create a minimal event with just type and timestamp.
@@ -73,7 +72,6 @@ impl Event {
seq: None, seq: None,
codec: None, codec: None,
fec_block: None, fec_block: None,
fec_sym: None,
repair: None, repair: None,
len: None, len: None,
to_count: None, to_count: None,
@@ -85,33 +83,59 @@ impl Event {
} }
/// Set room. /// Set room.
pub fn room(mut self, room: &str) -> Self { self.room = Some(room.to_string()); self } pub fn room(mut self, room: &str) -> Self {
self.room = Some(room.to_string());
self
}
/// Set source. /// Set source.
pub fn src(mut self, src: &str) -> Self { self.src = Some(src.to_string()); self } pub fn src(mut self, src: &str) -> Self {
self.src = Some(src.to_string());
self
}
/// Set packet header fields from a MediaPacket. /// Set packet header fields from a MediaPacket.
pub fn packet(mut self, pkt: &wzp_proto::MediaPacket) -> Self { pub fn packet(mut self, pkt: &wzp_proto::MediaPacket) -> Self {
self.seq = Some(pkt.header.seq); self.seq = Some(pkt.header.seq);
self.codec = Some(format!("{:?}", pkt.header.codec_id)); self.codec = Some(format!("{:?}", pkt.header.codec_id));
self.fec_block = Some(pkt.header.fec_block); self.fec_block = Some(pkt.header.fec_block);
self.fec_sym = Some(pkt.header.fec_symbol); self.repair = Some(pkt.header.is_repair());
self.repair = Some(pkt.header.is_repair);
self.len = Some(pkt.payload.len()); self.len = Some(pkt.payload.len());
self self
} }
/// Set seq only (when full packet not available). /// Set seq only (when full packet not available).
pub fn seq(mut self, seq: u16) -> Self { self.seq = Some(seq); self } pub fn seq(mut self, seq: u32) -> Self {
self.seq = Some(seq);
self
}
/// Set payload length. /// Set payload length.
pub fn len(mut self, len: usize) -> Self { self.len = Some(len); self } pub fn len(mut self, len: usize) -> Self {
self.len = Some(len);
self
}
/// Set recipient count. /// Set recipient count.
pub fn to_count(mut self, n: usize) -> Self { self.to_count = Some(n); self } pub fn to_count(mut self, n: usize) -> Self {
self.to_count = Some(n);
self
}
/// Set peer label. /// Set peer label.
pub fn peer(mut self, peer: &str) -> Self { self.peer = Some(peer.to_string()); self } pub fn peer(mut self, peer: &str) -> Self {
self.peer = Some(peer.to_string());
self
}
/// Set drop reason. /// Set drop reason.
pub fn reason(mut self, reason: &str) -> Self { self.reason = Some(reason.to_string()); self } pub fn reason(mut self, reason: &str) -> Self {
self.reason = Some(reason.to_string());
self
}
/// Set presence action. /// Set presence action.
pub fn action(mut self, action: &str) -> Self { self.action = Some(action.to_string()); self } pub fn action(mut self, action: &str) -> Self {
self.action = Some(action.to_string());
self
}
/// Set participant count. /// Set participant count.
pub fn participants(mut self, n: usize) -> Self { self.participants = Some(n); self } pub fn participants(mut self, n: usize) -> Self {
self.participants = Some(n);
self
}
} }
/// Handle for emitting events. Cheap to clone. /// Handle for emitting events. Cheap to clone.
@@ -181,8 +205,12 @@ async fn writer_task(path: PathBuf, mut rx: mpsc::UnboundedReceiver<Event>) {
while let Some(event) = rx.recv().await { while let Some(event) = rx.recv().await {
match serde_json::to_string(&event) { match serde_json::to_string(&event) {
Ok(json) => { Ok(json) => {
if writer.write_all(json.as_bytes()).await.is_err() { break; } if writer.write_all(json.as_bytes()).await.is_err() {
if writer.write_all(b"\n").await.is_err() { break; } break;
}
if writer.write_all(b"\n").await.is_err() {
break;
}
count += 1; count += 1;
// Flush every 100 events // Flush every 100 events
if count % 100 == 0 { if count % 100 == 0 {

View File

@@ -11,11 +11,11 @@ use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use bytes::Bytes; use bytes::Bytes;
use sha2::{Sha256, Digest}; use sha2::{Digest, Sha256};
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tracing::{error, info, warn}; use tracing::{error, info, warn};
use wzp_proto::{MediaTransport, SignalMessage}; use wzp_proto::{MediaTransport, SignalMessage, default_signal_version};
use wzp_transport::QuinnTransport; use wzp_transport::QuinnTransport;
use crate::config::{PeerConfig, TrustedConfig}; use crate::config::{PeerConfig, TrustedConfig};
@@ -56,13 +56,14 @@ impl Deduplicator {
} }
/// Returns true if this packet is a duplicate (already seen within TTL). /// Returns true if this packet is a duplicate (already seen within TTL).
fn is_dup(&mut self, room_hash: &[u8; 8], seq: u16, extra: u64) -> bool { fn is_dup(&mut self, room_hash: &[u8; 8], seq: u32, extra: u64) -> bool {
let key = u64::from_be_bytes(*room_hash) ^ (seq as u64) ^ extra; let key = u64::from_be_bytes(*room_hash) ^ (seq as u64) ^ extra;
let now = Instant::now(); let now = Instant::now();
// Periodic cleanup (every ~256 packets) // Periodic cleanup (every ~256 packets)
if self.entries.len() > 256 { if self.entries.len() > 256 {
self.entries.retain(|_, ts| now.duration_since(*ts) < self.ttl); self.entries
.retain(|_, ts| now.duration_since(*ts) < self.ttl);
} }
if let Some(ts) = self.entries.get(&key) { if let Some(ts) = self.entries.get(&key) {
@@ -215,7 +216,10 @@ impl FederationManager {
pub async fn broadcast_signal(&self, msg: &wzp_proto::SignalMessage) -> usize { pub async fn broadcast_signal(&self, msg: &wzp_proto::SignalMessage) -> usize {
let peers: Vec<(String, String, Arc<QuinnTransport>)> = { let peers: Vec<(String, String, Arc<QuinnTransport>)> = {
let links = self.peer_links.lock().await; let links = self.peer_links.lock().await;
links.iter().map(|(fp, l)| (fp.clone(), l.label.clone(), l.transport.clone())).collect() links
.iter()
.map(|(fp, l)| (fp.clone(), l.label.clone(), l.transport.clone()))
.collect()
}; // lock released }; // lock released
let mut count = 0; let mut count = 0;
for (fp, label, transport) in &peers { for (fp, label, transport) in &peers {
@@ -300,9 +304,10 @@ impl FederationManager {
return Some(room.to_string()); return Some(room.to_string());
} }
// Hashed match (desktop clients hash room names for SNI privacy) // Hashed match (desktop clients hash room names for SNI privacy)
self.global_rooms.iter().find(|name| { self.global_rooms
wzp_crypto::hash_room_name(name) == room .iter()
}).map(|s| s.to_string()) .find(|name| wzp_crypto::hash_room_name(name) == room)
.map(|s| s.to_string())
} }
/// Get the canonical federation room hash for a room. /// Get the canonical federation room hash for a room.
@@ -371,7 +376,10 @@ impl FederationManager {
/// Get all remote participants for a room from all peer links. /// Get all remote participants for a room from all peer links.
/// Deduplicates by fingerprint (same participant may appear via multiple links). /// Deduplicates by fingerprint (same participant may appear via multiple links).
pub async fn get_remote_participants(&self, room: &str) -> Vec<wzp_proto::packet::RoomParticipant> { pub async fn get_remote_participants(
&self,
room: &str,
) -> Vec<wzp_proto::packet::RoomParticipant> {
let canonical = self.resolve_global_room(room); let canonical = self.resolve_global_room(room);
let links = self.peer_links.lock().await; let links = self.peer_links.lock().await;
let mut result = Vec::new(); let mut result = Vec::new();
@@ -407,11 +415,21 @@ impl FederationManager {
/// the other room-tagged helpers and for future per-room-name logging /// the other room-tagged helpers and for future per-room-name logging
/// or rate limiting; the body currently forwards on `room_hash` alone /// or rate limiting; the body currently forwards on `room_hash` alone
/// because that's what the wire format carries. /// because that's what the wire format carries.
pub async fn forward_to_peers(&self, _room_name: &str, room_hash: &[u8; 8], media_data: &Bytes) { pub async fn forward_to_peers(
&self,
_room_name: &str,
room_hash: &[u8; 8],
media_data: &Bytes,
) {
let peers: Vec<(String, Arc<QuinnTransport>)> = { let peers: Vec<(String, Arc<QuinnTransport>)> = {
let links = self.peer_links.lock().await; let links = self.peer_links.lock().await;
if links.is_empty() { return; } if links.is_empty() {
links.values().map(|l| (l.label.clone(), l.transport.clone())).collect() return;
}
links
.values()
.map(|l| (l.label.clone(), l.transport.clone()))
.collect()
}; // lock released }; // lock released
for (label, transport) in &peers { for (label, transport) in &peers {
@@ -420,8 +438,10 @@ impl FederationManager {
tagged.extend_from_slice(media_data); tagged.extend_from_slice(media_data);
match transport.send_raw_datagram(&tagged) { match transport.send_raw_datagram(&tagged) {
Ok(()) => { Ok(()) => {
self.metrics.federation_packets_forwarded self.metrics
.with_label_values(&[label, "out"]).inc(); .federation_packets_forwarded
.with_label_values(&[label, "out"])
.inc();
} }
Err(e) => warn!(peer = %label, "federation send error: {e}"), Err(e) => warn!(peer = %label, "federation send error: {e}"),
} }
@@ -431,20 +451,25 @@ impl FederationManager {
// ── Trust verification (kept from previous implementation) ── // ── Trust verification (kept from previous implementation) ──
pub fn find_peer_by_fingerprint(&self, fp: &str) -> Option<&PeerConfig> { pub fn find_peer_by_fingerprint(&self, fp: &str) -> Option<&PeerConfig> {
self.peers.iter().find(|p| normalize_fp(&p.fingerprint) == normalize_fp(fp)) self.peers
.iter()
.find(|p| normalize_fp(&p.fingerprint) == normalize_fp(fp))
} }
pub fn find_peer_by_addr(&self, addr: SocketAddr) -> Option<&PeerConfig> { pub fn find_peer_by_addr(&self, addr: SocketAddr) -> Option<&PeerConfig> {
let addr_ip = addr.ip(); let addr_ip = addr.ip();
self.peers.iter().find(|p| { self.peers.iter().find(|p| {
p.url.parse::<SocketAddr>() p.url
.parse::<SocketAddr>()
.map(|sa| sa.ip() == addr_ip) .map(|sa| sa.ip() == addr_ip)
.unwrap_or(false) .unwrap_or(false)
}) })
} }
pub fn find_trusted_by_fingerprint(&self, fp: &str) -> Option<&TrustedConfig> { pub fn find_trusted_by_fingerprint(&self, fp: &str) -> Option<&TrustedConfig> {
self.trusted.iter().find(|t| normalize_fp(&t.fingerprint) == normalize_fp(fp)) self.trusted
.iter()
.find(|t| normalize_fp(&t.fingerprint) == normalize_fp(fp))
} }
pub fn check_inbound_trust(&self, addr: SocketAddr, hello_fp: &str) -> Option<String> { pub fn check_inbound_trust(&self, addr: SocketAddr, hello_fp: &str) -> Option<String> {
@@ -452,7 +477,12 @@ impl FederationManager {
return Some(peer.label.clone().unwrap_or_else(|| peer.url.clone())); return Some(peer.label.clone().unwrap_or_else(|| peer.url.clone()));
} }
if let Some(trusted) = self.find_trusted_by_fingerprint(hello_fp) { if let Some(trusted) = self.find_trusted_by_fingerprint(hello_fp) {
return Some(trusted.label.clone().unwrap_or_else(|| hello_fp[..16].to_string())); return Some(
trusted
.label
.clone()
.unwrap_or_else(|| hello_fp[..16].to_string()),
);
} }
None None
} }
@@ -471,7 +501,8 @@ pub async fn run_federation_media_egress(
if count == 1 || count % 250 == 0 { if count == 1 || count % 250 == 0 {
info!(room = %out.room_name, count, "federation egress: forwarding media"); info!(room = %out.room_name, count, "federation egress: forwarding media");
} }
fm.forward_to_peers(&out.room_name, &out.room_hash, &out.data).await; fm.forward_to_peers(&out.room_name, &out.room_hash, &out.data)
.await;
} }
info!(total = count, "federation egress task ended"); info!(total = count, "federation egress task ended");
} }
@@ -489,7 +520,11 @@ async fn run_room_event_dispatcher(
if fm.is_global_room(&room) { if fm.is_global_room(&room) {
let participants = fm.room_mgr.local_participant_list(&room); let participants = fm.room_mgr.local_participant_list(&room);
info!(room = %room, count = participants.len(), "global room now active, announcing to peers"); info!(room = %room, count = participants.len(), "global room now active, announcing to peers");
let msg = SignalMessage::GlobalRoomActive { room, participants }; let msg = SignalMessage::GlobalRoomActive {
version: default_signal_version(),
room,
participants,
};
let transports: Vec<Arc<QuinnTransport>> = { let transports: Vec<Arc<QuinnTransport>> = {
let links = fm.peer_links.lock().await; let links = fm.peer_links.lock().await;
links.values().map(|l| l.transport.clone()).collect() links.values().map(|l| l.transport.clone()).collect()
@@ -502,7 +537,10 @@ async fn run_room_event_dispatcher(
Ok(RoomEvent::LocalLeave { room }) => { Ok(RoomEvent::LocalLeave { room }) => {
if fm.is_global_room(&room) { if fm.is_global_room(&room) {
info!(room = %room, "global room now inactive, announcing to peers"); info!(room = %room, "global room now inactive, announcing to peers");
let msg = SignalMessage::GlobalRoomInactive { room }; let msg = SignalMessage::GlobalRoomInactive {
version: default_signal_version(),
room,
};
let transports: Vec<Arc<QuinnTransport>> = { let transports: Vec<Arc<QuinnTransport>> = {
let links = fm.peer_links.lock().await; let links = fm.peer_links.lock().await;
links.values().map(|l| l.transport.clone()).collect() links.values().map(|l| l.transport.clone()).collect()
@@ -536,7 +574,9 @@ async fn run_stale_presence_sweeper(fm: Arc<FederationManager>) {
let links = fm.peer_links.lock().await; let links = fm.peer_links.lock().await;
let mut stale = Vec::new(); let mut stale = Vec::new();
for (fp, link) in links.iter() { for (fp, link) in links.iter() {
if link.last_seen.elapsed() > stale_threshold && !link.remote_participants.is_empty() { if link.last_seen.elapsed() > stale_threshold
&& !link.remote_participants.is_empty()
{
for room in link.remote_participants.keys() { for room in link.remote_participants.keys() {
stale.push((fp.clone(), room.clone())); stale.push((fp.clone(), room.clone()));
} }
@@ -576,6 +616,7 @@ async fn run_stale_presence_sweeper(fm: Arc<FederationManager>) {
let mut seen = HashSet::new(); let mut seen = HashSet::new();
all_participants.retain(|p| seen.insert(p.fingerprint.clone())); all_participants.retain(|p| seen.insert(p.fingerprint.clone()));
let update = SignalMessage::RoomUpdate { let update = SignalMessage::RoomUpdate {
version: default_signal_version(),
count: all_participants.len() as u32, count: all_participants.len() as u32,
participants: all_participants, participants: all_participants,
}; };
@@ -615,7 +656,10 @@ async fn run_peer_loop(fm: Arc<FederationManager>, peer: PeerConfig) {
} }
/// Connect to a peer relay and send hello. /// Connect to a peer relay and send hello.
async fn connect_to_peer(fm: &FederationManager, peer: &PeerConfig) -> Result<Arc<QuinnTransport>, anyhow::Error> { async fn connect_to_peer(
fm: &FederationManager,
peer: &PeerConfig,
) -> Result<Arc<QuinnTransport>, anyhow::Error> {
let addr: SocketAddr = peer.url.parse()?; let addr: SocketAddr = peer.url.parse()?;
let client_cfg = wzp_transport::client_config(); let client_cfg = wzp_transport::client_config();
let conn = wzp_transport::connect(&fm.endpoint, addr, "_federation", client_cfg).await?; let conn = wzp_transport::connect(&fm.endpoint, addr, "_federation", client_cfg).await?;
@@ -623,9 +667,12 @@ async fn connect_to_peer(fm: &FederationManager, peer: &PeerConfig) -> Result<Ar
// Send hello with our TLS fingerprint // Send hello with our TLS fingerprint
let hello = SignalMessage::FederationHello { let hello = SignalMessage::FederationHello {
version: default_signal_version(),
tls_fingerprint: fm.local_tls_fp.clone(), tls_fingerprint: fm.local_tls_fp.clone(),
}; };
transport.send_signal(&hello).await transport
.send_signal(&hello)
.await
.map_err(|e| anyhow::anyhow!("federation hello send failed: {e}"))?; .map_err(|e| anyhow::anyhow!("federation hello send failed: {e}"))?;
info!(peer_url = %peer.url, label = ?peer.label, "federation: connected (hello sent)"); info!(peer_url = %peer.url, label = ?peer.label, "federation: connected (hello sent)");
@@ -642,16 +689,22 @@ async fn run_federation_link(
peer_label: String, peer_label: String,
) -> Result<(), anyhow::Error> { ) -> Result<(), anyhow::Error> {
// Register peer link + metrics // Register peer link + metrics
fm.metrics.federation_peer_status.with_label_values(&[&peer_label]).set(1); fm.metrics
.federation_peer_status
.with_label_values(&[&peer_label])
.set(1);
{ {
let mut links = fm.peer_links.lock().await; let mut links = fm.peer_links.lock().await;
links.insert(peer_fp.clone(), PeerLink { links.insert(
peer_fp.clone(),
PeerLink {
transport: transport.clone(), transport: transport.clone(),
label: peer_label.clone(), label: peer_label.clone(),
active_rooms: HashSet::new(), active_rooms: HashSet::new(),
remote_participants: HashMap::new(), remote_participants: HashMap::new(),
last_seen: Instant::now(), last_seen: Instant::now(),
}); },
);
} }
// Announce our currently active global rooms to this new peer // Announce our currently active global rooms to this new peer
@@ -665,7 +718,11 @@ async fn run_federation_link(
if fm.is_global_room(room_name) { if fm.is_global_room(room_name) {
let participants = fm.room_mgr.local_participant_list(room_name); let participants = fm.room_mgr.local_participant_list(room_name);
info!(peer = %peer_label, room = %room_name, participants = participants.len(), "announcing local global room to new peer"); info!(peer = %peer_label, room = %room_name, participants = participants.len(), "announcing local global room to new peer");
msgs.push(SignalMessage::GlobalRoomActive { room: room_name.clone(), participants }); msgs.push(SignalMessage::GlobalRoomActive {
version: default_signal_version(),
room: room_name.clone(),
participants,
});
} }
} }
@@ -677,6 +734,7 @@ async fn run_federation_link(
if fm.is_global_room(room) { if fm.is_global_room(room) {
info!(peer = %peer_label, room = %room, via = %link.label, "propagating remote room to new peer"); info!(peer = %peer_label, room = %room, via = %link.label, "propagating remote room to new peer");
msgs.push(SignalMessage::GlobalRoomActive { msgs.push(SignalMessage::GlobalRoomActive {
version: default_signal_version(),
room: room.clone(), room: room.clone(),
participants: participants.clone(), participants: participants.clone(),
}); });
@@ -761,7 +819,10 @@ async fn run_federation_link(
} }
// Cleanup: remove peer link + metrics // Cleanup: remove peer link + metrics
fm.metrics.federation_peer_status.with_label_values(&[&peer_label]).set(0); fm.metrics
.federation_peer_status
.with_label_values(&[&peer_label])
.set(0);
{ {
let mut links = fm.peer_links.lock().await; let mut links = fm.peer_links.lock().await;
links.remove(&peer_fp); links.remove(&peer_fp);
@@ -787,7 +848,9 @@ async fn handle_signal(
} }
match msg { match msg {
SignalMessage::GlobalRoomActive { room, participants } => { SignalMessage::GlobalRoomActive {
room, participants, ..
} => {
if fm.is_global_room(&room) { if fm.is_global_room(&room) {
info!(peer = %peer_label, room = %room, remote_participants = participants.len(), "peer has global room active"); info!(peer = %peer_label, room = %room, remote_participants = participants.len(), "peer has global room active");
let mut links = fm.peer_links.lock().await; let mut links = fm.peer_links.lock().await;
@@ -799,34 +862,44 @@ async fn handle_signal(
fm.metrics.federation_active_rooms.set(total as i64); fm.metrics.federation_active_rooms.set(total as i64);
if let Some(link) = links.get_mut(peer_fp) { if let Some(link) = links.get_mut(peer_fp) {
// Tag remote participants with their relay label // Tag remote participants with their relay label
let tagged: Vec<_> = participants.iter().map(|p| { let tagged: Vec<_> = participants
.iter()
.map(|p| {
let mut tagged = p.clone(); let mut tagged = p.clone();
if tagged.relay_label.is_none() { if tagged.relay_label.is_none() {
tagged.relay_label = Some(link.label.clone()); tagged.relay_label = Some(link.label.clone());
} }
tagged tagged
}).collect(); })
.collect();
link.remote_participants.insert(room.clone(), tagged); link.remote_participants.insert(room.clone(), tagged);
} }
// Propagate to other peers (with relay labels preserved) // Propagate to other peers (with relay labels preserved)
let tagged_for_propagation = if let Some(link) = links.get(peer_fp) { let tagged_for_propagation = if let Some(link) = links.get(peer_fp) {
let label = link.label.clone(); let label = link.label.clone();
participants.iter().map(|p| { participants
.iter()
.map(|p| {
let mut t = p.clone(); let mut t = p.clone();
if t.relay_label.is_none() { if t.relay_label.is_none() {
t.relay_label = Some(label.clone()); t.relay_label = Some(label.clone());
} }
t t
}).collect::<Vec<_>>() })
.collect::<Vec<_>>()
} else { } else {
participants.clone() participants.clone()
}; };
for (fp, link) in links.iter() { for (fp, link) in links.iter() {
if fp != peer_fp { if fp != peer_fp {
let _ = link.transport.send_signal(&SignalMessage::GlobalRoomActive { let _ = link
.transport
.send_signal(&SignalMessage::GlobalRoomActive {
version: default_signal_version(),
room: room.clone(), room: room.clone(),
participants: tagged_for_propagation.clone(), participants: tagged_for_propagation.clone(),
}).await; })
.await;
} }
} }
drop(links); drop(links);
@@ -835,19 +908,25 @@ async fn handle_signal(
// Find the local room name (may be hashed or raw) // Find the local room name (may be hashed or raw)
let active = fm.room_mgr.active_rooms(); let active = fm.room_mgr.active_rooms();
for local_room in &active { for local_room in &active {
if fm.is_global_room(local_room) && fm.resolve_global_room(local_room) == fm.resolve_global_room(&room) { if fm.is_global_room(local_room)
&& fm.resolve_global_room(local_room) == fm.resolve_global_room(&room)
{
// Build merged participant list: local + all remote (deduped) // Build merged participant list: local + all remote (deduped)
let mut all_participants = fm.room_mgr.local_participant_list(local_room); let mut all_participants = fm.room_mgr.local_participant_list(local_room);
{ {
let links = fm.peer_links.lock().await; let links = fm.peer_links.lock().await;
for link in links.values() { for link in links.values() {
if let Some(ref canonical) = fm.resolve_global_room(local_room) { if let Some(ref canonical) = fm.resolve_global_room(local_room) {
if let Some(remote) = link.remote_participants.get(canonical.as_str()) { if let Some(remote) =
link.remote_participants.get(canonical.as_str())
{
all_participants.extend(remote.iter().cloned()); all_participants.extend(remote.iter().cloned());
} }
// Also check raw room name, but only if different from canonical // Also check raw room name, but only if different from canonical
if canonical != local_room { if canonical != local_room {
if let Some(remote) = link.remote_participants.get(local_room) { if let Some(remote) =
link.remote_participants.get(local_room)
{
all_participants.extend(remote.iter().cloned()); all_participants.extend(remote.iter().cloned());
} }
} }
@@ -858,6 +937,7 @@ async fn handle_signal(
let mut seen = HashSet::new(); let mut seen = HashSet::new();
all_participants.retain(|p| seen.insert(p.fingerprint.clone())); all_participants.retain(|p| seen.insert(p.fingerprint.clone()));
let update = SignalMessage::RoomUpdate { let update = SignalMessage::RoomUpdate {
version: default_signal_version(),
count: all_participants.len() as u32, count: all_participants.len() as u32,
participants: all_participants, participants: all_participants,
}; };
@@ -868,7 +948,7 @@ async fn handle_signal(
} }
} }
} }
SignalMessage::GlobalRoomInactive { room } => { SignalMessage::GlobalRoomInactive { room, .. } => {
info!(peer = %peer_label, room = %room, "peer global room now inactive"); info!(peer = %peer_label, room = %room, "peer global room now inactive");
let mut links = fm.peer_links.lock().await; let mut links = fm.peer_links.lock().await;
if let Some(link) = links.get_mut(peer_fp) { if let Some(link) = links.get_mut(peer_fp) {
@@ -890,7 +970,9 @@ async fn handle_signal(
let canonical = fm.resolve_global_room(&room); let canonical = fm.resolve_global_room(&room);
let mut result = Vec::new(); let mut result = Vec::new();
for (fp, link) in links.iter() { for (fp, link) in links.iter() {
if fp == peer_fp { continue; } if fp == peer_fp {
continue;
}
if let Some(ref c) = canonical { if let Some(ref c) = canonical {
if let Some(remote) = link.remote_participants.get(c.as_str()) { if let Some(remote) = link.remote_participants.get(c.as_str()) {
result.extend(remote.iter().cloned()); result.extend(remote.iter().cloned());
@@ -904,11 +986,16 @@ async fn handle_signal(
// Propagate to other peers: send updated GlobalRoomActive with revised list, // Propagate to other peers: send updated GlobalRoomActive with revised list,
// or GlobalRoomInactive if no participants remain anywhere // or GlobalRoomInactive if no participants remain anywhere
let local_active = fm.room_mgr.active_rooms().iter().any(|r| fm.resolve_global_room(r) == fm.resolve_global_room(&room)); let local_active = fm
.room_mgr
.active_rooms()
.iter()
.any(|r| fm.resolve_global_room(r) == fm.resolve_global_room(&room));
let has_remaining = !remaining_remote.is_empty() || local_active; let has_remaining = !remaining_remote.is_empty() || local_active;
// Collect peer transports to send to (avoid holding lock across await) // Collect peer transports to send to (avoid holding lock across await)
let peer_sends: Vec<_> = links.iter() let peer_sends: Vec<_> = links
.iter()
.filter(|(fp, _)| *fp != peer_fp) .filter(|(fp, _)| *fp != peer_fp)
.map(|(_, link)| link.transport.clone()) .map(|(_, link)| link.transport.clone())
.collect(); .collect();
@@ -920,12 +1007,14 @@ async fn handle_signal(
if local_active { if local_active {
for local_room in fm.room_mgr.active_rooms() { for local_room in fm.room_mgr.active_rooms() {
if fm.resolve_global_room(&local_room) == fm.resolve_global_room(&room) { if fm.resolve_global_room(&local_room) == fm.resolve_global_room(&room) {
updated_participants.extend(fm.room_mgr.local_participant_list(&local_room)); updated_participants
.extend(fm.room_mgr.local_participant_list(&local_room));
break; break;
} }
} }
} }
let msg = SignalMessage::GlobalRoomActive { let msg = SignalMessage::GlobalRoomActive {
version: default_signal_version(),
room: room.clone(), room: room.clone(),
participants: updated_participants, participants: updated_participants,
}; };
@@ -934,7 +1023,10 @@ async fn handle_signal(
} }
} else { } else {
// No participants left anywhere — propagate inactive // No participants left anywhere — propagate inactive
let msg = SignalMessage::GlobalRoomInactive { room: room.clone() }; let msg = SignalMessage::GlobalRoomInactive {
version: default_signal_version(),
room: room.clone(),
};
for transport in &peer_sends { for transport in &peer_sends {
let _ = transport.send_signal(&msg).await; let _ = transport.send_signal(&msg).await;
} }
@@ -943,13 +1035,16 @@ async fn handle_signal(
// Broadcast updated RoomUpdate to local clients (remote participant removed) // Broadcast updated RoomUpdate to local clients (remote participant removed)
let active = fm.room_mgr.active_rooms(); let active = fm.room_mgr.active_rooms();
for local_room in &active { for local_room in &active {
if fm.is_global_room(local_room) && fm.resolve_global_room(local_room) == fm.resolve_global_room(&room) { if fm.is_global_room(local_room)
&& fm.resolve_global_room(local_room) == fm.resolve_global_room(&room)
{
let mut all_participants = fm.room_mgr.local_participant_list(local_room); let mut all_participants = fm.room_mgr.local_participant_list(local_room);
all_participants.extend(remaining_remote.iter().cloned()); all_participants.extend(remaining_remote.iter().cloned());
// Deduplicate by fingerprint // Deduplicate by fingerprint
let mut seen = HashSet::new(); let mut seen = HashSet::new();
all_participants.retain(|p| seen.insert(p.fingerprint.clone())); all_participants.retain(|p| seen.insert(p.fingerprint.clone()));
let update = SignalMessage::RoomUpdate { let update = SignalMessage::RoomUpdate {
version: default_signal_version(),
count: all_participants.len() as u32, count: all_participants.len() as u32,
participants: all_participants, participants: all_participants,
}; };
@@ -972,7 +1067,11 @@ async fn handle_signal(
// Loop prevention: drop any forward whose origin matches // Loop prevention: drop any forward whose origin matches
// our own federation TLS fingerprint. With // our own federation TLS fingerprint. With
// broadcast-to-all-peers this prevents A→B→A echo loops. // broadcast-to-all-peers this prevents A→B→A echo loops.
SignalMessage::FederatedSignalForward { inner, origin_relay_fp } => { SignalMessage::FederatedSignalForward {
inner,
origin_relay_fp,
..
} => {
if origin_relay_fp == fm.local_tls_fp { if origin_relay_fp == fm.local_tls_fp {
tracing::debug!( tracing::debug!(
peer = %peer_label, peer = %peer_label,
@@ -1016,12 +1115,10 @@ async fn handle_signal(
} }
/// Handle an incoming federation datagram (room-hash-tagged media). /// Handle an incoming federation datagram (room-hash-tagged media).
async fn handle_datagram( async fn handle_datagram(fm: &Arc<FederationManager>, source_peer_fp: &str, data: Bytes) {
fm: &Arc<FederationManager>, if data.len() < 12 {
source_peer_fp: &str, return;
data: Bytes, } // 8-byte hash + min packet
) {
if data.len() < 12 { return; } // 8-byte hash + min packet
let mut rh = [0u8; 8]; let mut rh = [0u8; 8];
rh.copy_from_slice(&data[..8]); rh.copy_from_slice(&data[..8]);
@@ -1030,7 +1127,8 @@ async fn handle_datagram(
let pkt = match wzp_proto::MediaPacket::from_bytes(media_bytes.clone()) { let pkt = match wzp_proto::MediaPacket::from_bytes(media_bytes.clone()) {
Some(pkt) => pkt, Some(pkt) => pkt,
None => { None => {
fm.event_log.emit(Event::new("federation_ingress_malformed").len(data.len())); fm.event_log
.emit(Event::new("federation_ingress_malformed").len(data.len()));
return; return;
} }
}; };
@@ -1038,13 +1136,22 @@ async fn handle_datagram(
// Event log: federation ingress // Event log: federation ingress
let peer_label = { let peer_label = {
let links = fm.peer_links.lock().await; let links = fm.peer_links.lock().await;
links.get(source_peer_fp).map(|l| l.label.clone()).unwrap_or_default() links
.get(source_peer_fp)
.map(|l| l.label.clone())
.unwrap_or_default()
}; };
fm.event_log.emit(Event::new("federation_ingress").packet(&pkt).peer(&peer_label)); fm.event_log.emit(
Event::new("federation_ingress")
.packet(&pkt)
.peer(&peer_label),
);
// Count inbound federation packet + update last_seen // Count inbound federation packet + update last_seen
fm.metrics.federation_packets_forwarded fm.metrics
.with_label_values(&[source_peer_fp, "in"]).inc(); .federation_packets_forwarded
.with_label_values(&[source_peer_fp, "in"])
.inc();
{ {
let mut links = fm.peer_links.lock().await; let mut links = fm.peer_links.lock().await;
if let Some(link) = links.get_mut(source_peer_fp) { if let Some(link) = links.get_mut(source_peer_fp) {
@@ -1065,7 +1172,11 @@ async fn handle_datagram(
{ {
let mut dedup = fm.dedup.lock().await; let mut dedup = fm.dedup.lock().await;
if dedup.is_dup(&rh, pkt.header.seq, payload_hash) { if dedup.is_dup(&rh, pkt.header.seq, payload_hash) {
fm.event_log.emit(Event::new("dedup_drop").seq(pkt.header.seq).peer(&peer_label)); fm.event_log.emit(
Event::new("dedup_drop")
.seq(pkt.header.seq)
.peer(&peer_label),
);
return; return;
} }
} }
@@ -1074,18 +1185,33 @@ async fn handle_datagram(
let room_name = { let room_name = {
let active = fm.room_mgr.active_rooms(); let active = fm.room_mgr.active_rooms();
// First: check local rooms (has participants) // First: check local rooms (has participants)
active.iter().find(|r| room_hash(r) == rh).cloned() active
.or_else(|| active.iter().find(|r| fm.global_room_hash(r) == rh).cloned()) .iter()
.find(|r| room_hash(r) == rh)
.cloned()
.or_else(|| {
active
.iter()
.find(|r| fm.global_room_hash(r) == rh)
.cloned()
})
// Second: check static global room config (hub relay may have no local participants) // Second: check static global room config (hub relay may have no local participants)
.or_else(|| { .or_else(|| {
fm.global_rooms.iter().find(|name| room_hash(name) == rh).cloned() fm.global_rooms
.iter()
.find(|name| room_hash(name) == rh)
.cloned()
}) })
}; };
let room_name = match room_name { let room_name = match room_name {
Some(r) => r, Some(r) => r,
None => { None => {
fm.event_log.emit(Event::new("room_not_found").seq(pkt.header.seq).peer(&peer_label)); fm.event_log.emit(
Event::new("room_not_found")
.seq(pkt.header.seq)
.peer(&peer_label),
);
// Phase 4.1 diagnostic: log the hash + active rooms // Phase 4.1 diagnostic: log the hash + active rooms
// so we can diagnose cross-relay call-* media routing // so we can diagnose cross-relay call-* media routing
// failures. This fires when a peer relay sends media // failures. This fires when a peer relay sends media
@@ -1107,10 +1233,15 @@ async fn handle_datagram(
// Rate limit per room // Rate limit per room
if FEDERATION_RATE_LIMIT_PPS > 0 { if FEDERATION_RATE_LIMIT_PPS > 0 {
let mut limiters = fm.rate_limiters.lock().await; let mut limiters = fm.rate_limiters.lock().await;
let limiter = limiters.entry(room_name.clone()) let limiter = limiters
.entry(room_name.clone())
.or_insert_with(|| RateLimiter::new(FEDERATION_RATE_LIMIT_PPS)); .or_insert_with(|| RateLimiter::new(FEDERATION_RATE_LIMIT_PPS));
if !limiter.allow() { if !limiter.allow() {
fm.event_log.emit(Event::new("rate_limit_drop").room(&room_name).seq(pkt.header.seq)); fm.event_log.emit(
Event::new("rate_limit_drop")
.room(&room_name)
.seq(pkt.header.seq),
);
return; return;
} }
} }
@@ -1122,14 +1253,26 @@ async fn handle_datagram(
match sender { match sender {
room::ParticipantSender::Quic(t) => { room::ParticipantSender::Quic(t) => {
if let Err(e) = t.send_raw_datagram(&media_bytes) { if let Err(e) = t.send_raw_datagram(&media_bytes) {
fm.event_log.emit(Event::new("local_deliver_error").room(&room_name).seq(pkt.header.seq).reason(&e.to_string())); fm.event_log.emit(
Event::new("local_deliver_error")
.room(&room_name)
.seq(pkt.header.seq)
.reason(&e.to_string()),
);
warn!("federation local delivery error: {e}"); warn!("federation local delivery error: {e}");
} }
} }
room::ParticipantSender::WebSocket(_) => { let _ = sender.send_raw(&pkt.payload).await; } room::ParticipantSender::WebSocket(_) => {
let _ = sender.send_raw(&pkt.payload).await;
} }
} }
fm.event_log.emit(Event::new("local_deliver").room(&room_name).seq(pkt.header.seq).to_count(locals.len())); }
fm.event_log.emit(
Event::new("local_deliver")
.room(&room_name)
.seq(pkt.header.seq)
.to_count(locals.len()),
);
// Multi-hop: forward to ALL other connected peers (not the source) // Multi-hop: forward to ALL other connected peers (not the source)
// Don't filter by active_rooms — the receiving peer decides whether to deliver // Don't filter by active_rooms — the receiving peer decides whether to deliver

View File

@@ -4,7 +4,7 @@
//! recv `CallOffer` → verify → generate ephemeral → derive session → send `CallAnswer`. //! recv `CallOffer` → verify → generate ephemeral → derive session → send `CallAnswer`.
use wzp_crypto::{CryptoSession, KeyExchange, WarzoneKeyExchange}; use wzp_crypto::{CryptoSession, KeyExchange, WarzoneKeyExchange};
use wzp_proto::{MediaTransport, QualityProfile, SignalMessage}; use wzp_proto::{MediaTransport, QualityProfile, SignalMessage, default_signal_version};
/// Accept the relay (callee) side of the cryptographic handshake. /// Accept the relay (callee) side of the cryptographic handshake.
/// ///
@@ -20,30 +20,72 @@ use wzp_proto::{MediaTransport, QualityProfile, SignalMessage};
pub async fn accept_handshake( pub async fn accept_handshake(
transport: &dyn MediaTransport, transport: &dyn MediaTransport,
seed: &[u8; 32], seed: &[u8; 32],
) -> Result<(Box<dyn CryptoSession>, QualityProfile, String, Option<String>), anyhow::Error> { ) -> Result<
(
Box<dyn CryptoSession>,
QualityProfile,
String,
Option<String>,
),
anyhow::Error,
> {
// 1. Receive CallOffer // 1. Receive CallOffer
let offer = transport let offer = transport
.recv_signal() .recv_signal()
.await? .await?
.ok_or_else(|| anyhow::anyhow!("connection closed before receiving CallOffer"))?; .ok_or_else(|| anyhow::anyhow!("connection closed before receiving CallOffer"))?;
let (caller_identity_pub, caller_ephemeral_pub, caller_signature, supported_profiles, caller_alias) = let (
match offer { caller_identity_pub,
caller_ephemeral_pub,
caller_signature,
supported_profiles,
caller_alias,
protocol_version,
caller_video_codecs,
) = match offer {
SignalMessage::CallOffer { SignalMessage::CallOffer {
identity_pub, identity_pub,
ephemeral_pub, ephemeral_pub,
signature, signature,
supported_profiles, supported_profiles,
alias, alias,
} => (identity_pub, ephemeral_pub, signature, supported_profiles, alias), protocol_version,
supported_versions: _,
video_codecs,
..
} => (
identity_pub,
ephemeral_pub,
signature,
supported_profiles,
alias,
protocol_version,
video_codecs,
),
other => { other => {
return Err(anyhow::anyhow!( return Err(anyhow::anyhow!(
"expected CallOffer, got {:?}", "expected CallOffer, got {:?}",
std::mem::discriminant(&other) std::mem::discriminant(&other)
)) ));
} }
}; };
// 1a. Protocol version check — we only speak v2.
if protocol_version != 2 {
let mismatch = SignalMessage::Hangup {
version: default_signal_version(),
reason: wzp_proto::HangupReason::ProtocolVersionMismatch {
server_supported: vec![2],
},
call_id: None,
};
let _ = transport.send_signal(&mismatch).await;
return Err(anyhow::anyhow!(
"protocol version mismatch: client requested {protocol_version}, server supports [2]"
));
}
// 2. Verify caller's signature over (ephemeral_pub || "call-offer") // 2. Verify caller's signature over (ephemeral_pub || "call-offer")
let mut verify_data = Vec::with_capacity(32 + 10); let mut verify_data = Vec::with_capacity(32 + 10);
verify_data.extend_from_slice(&caller_ephemeral_pub); verify_data.extend_from_slice(&caller_ephemeral_pub);
@@ -69,23 +111,28 @@ pub async fn accept_handshake(
// Choose the best supported profile (prefer GOOD > DEGRADED > CATASTROPHIC) // Choose the best supported profile (prefer GOOD > DEGRADED > CATASTROPHIC)
let chosen_profile = choose_profile(&supported_profiles); let chosen_profile = choose_profile(&supported_profiles);
// Pick the first video codec the caller supports (relay forwards all video).
let video_codec = caller_video_codecs.into_iter().next();
// 6. Send CallAnswer // 6. Send CallAnswer
let answer = SignalMessage::CallAnswer { let answer = SignalMessage::CallAnswer {
version: default_signal_version(),
identity_pub, identity_pub,
ephemeral_pub, ephemeral_pub,
signature, signature,
chosen_profile, chosen_profile,
video_codec,
}; };
transport.send_signal(&answer).await?; transport.send_signal(&answer).await?;
// Derive caller fingerprint: SHA-256(Ed25519 pub)[:16], formatted as xxxx:xxxx:... // Derive caller fingerprint: SHA-256(Ed25519 pub)[:16], formatted as xxxx:xxxx:...
// Must match the format used in signal registration and presence. // Must match the format used in signal registration and presence.
let caller_fp = { let caller_fp = {
use sha2::{Sha256, Digest}; use sha2::{Digest, Sha256};
let hash = Sha256::digest(&caller_identity_pub); let hash = Sha256::digest(&caller_identity_pub);
let fp = wzp_crypto::Fingerprint([ let fp = wzp_crypto::Fingerprint([
hash[0], hash[1], hash[2], hash[3], hash[4], hash[5], hash[6], hash[7], hash[0], hash[1], hash[2], hash[3], hash[4], hash[5], hash[6], hash[7], hash[8],
hash[8], hash[9], hash[10], hash[11], hash[12], hash[13], hash[14], hash[15], hash[9], hash[10], hash[11], hash[12], hash[13], hash[14], hash[15],
]); ]);
fp.to_string() fp.to_string()
}; };
@@ -107,6 +154,7 @@ fn choose_profile(_supported: &[QualityProfile]) -> QualityProfile {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use wzp_proto::CodecId;
#[test] #[test]
fn choose_profile_picks_highest_bitrate() { fn choose_profile_picks_highest_bitrate() {
@@ -124,4 +172,35 @@ mod tests {
let chosen = choose_profile(&[]); let chosen = choose_profile(&[]);
assert_eq!(chosen, QualityProfile::GOOD); assert_eq!(chosen, QualityProfile::GOOD);
} }
// ── Video codec negotiation ───────────────────────────────────────
#[test]
fn video_codec_picks_first_offered() {
let codecs = vec![CodecId::Av1Main, CodecId::H264Baseline, CodecId::H265Main];
let chosen: Option<CodecId> = codecs.into_iter().next();
assert_eq!(chosen, Some(CodecId::Av1Main));
}
#[test]
fn video_codec_none_when_no_codecs_offered() {
let codecs: Vec<CodecId> = vec![];
let chosen: Option<CodecId> = codecs.into_iter().next();
assert_eq!(chosen, None);
}
#[test]
fn video_codec_single_codec_is_selected() {
let codecs = vec![CodecId::H265Main];
let chosen: Option<CodecId> = codecs.into_iter().next();
assert_eq!(chosen, Some(CodecId::H265Main));
}
#[test]
fn video_codec_order_is_preserved() {
// The relay must pick the FIRST codec as-offered, not sort or re-rank.
let codecs = vec![CodecId::H264Baseline, CodecId::Av1Main];
let chosen: Option<CodecId> = codecs.into_iter().next();
assert_eq!(chosen, Some(CodecId::H264Baseline));
}
} }

View File

@@ -7,22 +7,27 @@
//! It operates on FEC-protected packets, managing loss recovery and adaptive //! It operates on FEC-protected packets, managing loss recovery and adaptive
//! quality transitions. //! quality transitions.
pub mod audio_scorer;
pub mod auth; pub mod auth;
pub mod call_registry; pub mod call_registry;
pub mod config; pub mod config;
pub mod conformance;
pub mod event_log; pub mod event_log;
pub mod federation; pub mod federation;
pub mod signal_hub;
pub mod handshake; pub mod handshake;
pub mod metrics; pub mod metrics;
pub mod pipeline; pub mod pipeline;
pub mod presence; pub mod presence;
pub mod probe; pub mod probe;
pub mod relay_link; pub mod relay_link;
pub mod response_policy;
pub mod room; pub mod room;
pub mod route; pub mod route;
pub mod session_mgr; pub mod session_mgr;
pub mod signal_hub;
pub mod trunk; pub mod trunk;
pub mod verdict;
pub mod video_scorer;
pub mod ws; pub mod ws;
pub use config::RelayConfig; pub use config::RelayConfig;

View File

@@ -8,15 +8,15 @@
//! The web bridge connects with room name as SNI. //! The web bridge connects with room name as SNI.
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration; use std::time::Duration;
use clap::Parser; use clap::Parser;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tracing::{debug, error, info, warn}; use tracing::{debug, error, info, warn};
use wzp_proto::{MediaTransport, SignalMessage}; use wzp_proto::{MediaTransport, SignalMessage, default_signal_version};
use wzp_relay::config::RelayConfig; use wzp_relay::config::RelayConfig;
use wzp_relay::metrics::RelayMetrics; use wzp_relay::metrics::RelayMetrics;
use wzp_relay::pipeline::{PipelineConfig, RelayPipeline}; use wzp_relay::pipeline::{PipelineConfig, RelayPipeline};
@@ -116,7 +116,9 @@ fn parse_args() -> CliResult {
} }
// Track if we need to create the config after identity is known // Track if we need to create the config after identity is known
let config_needs_create = args.config_file.as_ref() let config_needs_create = args
.config_file
.as_ref()
.map(|p| !std::path::Path::new(p).exists()) .map(|p| !std::path::Path::new(p).exists())
.unwrap_or(false); .unwrap_or(false);
@@ -125,8 +127,7 @@ fn parse_args() -> CliResult {
// Will be re-created with personalized info after identity is loaded // Will be re-created with personalized info after identity is loaded
RelayConfig::default() RelayConfig::default()
} else { } else {
wzp_relay::config::load_config(path) wzp_relay::config::load_config(path).unwrap_or_else(|e| {
.unwrap_or_else(|e| {
eprintln!("failed to load config from {path}: {e}"); eprintln!("failed to load config from {path}: {e}");
std::process::exit(1); std::process::exit(1);
}) })
@@ -164,7 +165,9 @@ fn parse_args() -> CliResult {
config.static_dir = Some(dir); config.static_dir = Some(dir);
} }
for name in args.global_room { for name in args.global_room {
config.global_rooms.push(wzp_relay::config::GlobalRoomConfig { name }); config
.global_rooms
.push(wzp_relay::config::GlobalRoomConfig { name });
} }
if let Some(tap) = args.debug_tap { if let Some(tap) = args.debug_tap {
config.debug_tap = Some(tap); config.debug_tap = Some(tap);
@@ -199,7 +202,9 @@ async fn run_upstream(
let mut pipe = pipeline.lock().await; let mut pipe = pipeline.lock().await;
let decoded = pipe.ingest(pkt); let decoded = pipe.ingest(pkt);
let mut out = Vec::new(); let mut out = Vec::new();
for p in decoded { out.extend(pipe.prepare_outbound(p)); } for p in decoded {
out.extend(pipe.prepare_outbound(p));
}
out out
}; };
for p in &outbound { for p in &outbound {
@@ -208,10 +213,18 @@ async fn run_upstream(
return; return;
} }
} }
stats.upstream_packets.fetch_add(outbound.len() as u64, Ordering::Relaxed); stats
.upstream_packets
.fetch_add(outbound.len() as u64, Ordering::Relaxed);
}
Ok(None) => {
info!("client disconnected (upstream)");
break;
}
Err(e) => {
error!("upstream recv: {e}");
break;
} }
Ok(None) => { info!("client disconnected (upstream)"); break; }
Err(e) => { error!("upstream recv: {e}"); break; }
} }
} }
} }
@@ -229,7 +242,9 @@ async fn run_downstream(
let mut pipe = pipeline.lock().await; let mut pipe = pipeline.lock().await;
let decoded = pipe.ingest(pkt); let decoded = pipe.ingest(pkt);
let mut out = Vec::new(); let mut out = Vec::new();
for p in decoded { out.extend(pipe.prepare_outbound(p)); } for p in decoded {
out.extend(pipe.prepare_outbound(p));
}
out out
}; };
for p in &outbound { for p in &outbound {
@@ -238,10 +253,18 @@ async fn run_downstream(
return; return;
} }
} }
stats.downstream_packets.fetch_add(outbound.len() as u64, Ordering::Relaxed); stats
.downstream_packets
.fetch_add(outbound.len() as u64, Ordering::Relaxed);
}
Ok(None) => {
info!("remote disconnected (downstream)");
break;
}
Err(e) => {
error!("downstream recv: {e}");
break;
} }
Ok(None) => { info!("remote disconnected (downstream)"); break; }
Err(e) => { error!("downstream recv: {e}"); break; }
} }
} }
} }
@@ -266,7 +289,12 @@ const BUILD_GIT_HASH: &str = env!("WZP_BUILD_HASH");
#[tokio::main] #[tokio::main]
async fn main() -> anyhow::Result<()> { async fn main() -> anyhow::Result<()> {
let CliResult { config, identity_path, config_file, config_needs_create } = parse_args(); let CliResult {
config,
identity_path,
config_file,
config_needs_create,
} = parse_args();
tracing_subscriber::fmt().init(); tracing_subscriber::fmt().init();
info!(version = BUILD_GIT_HASH, "wzp-relay build"); info!(version = BUILD_GIT_HASH, "wzp-relay build");
rustls::crypto::ring::default_provider() rustls::crypto::ring::default_provider()
@@ -303,7 +331,10 @@ async fn main() -> anyhow::Result<()> {
info!("loaded relay identity from {}", id_path.display()); info!("loaded relay identity from {}", id_path.display());
s s
} else { } else {
warn!("corrupt identity file {}, generating new", id_path.display()); warn!(
"corrupt identity file {}, generating new",
id_path.display()
);
let s = wzp_crypto::Seed::generate(); let s = wzp_crypto::Seed::generate();
let hex: String = s.0.iter().map(|b| format!("{b:02x}")).collect(); let hex: String = s.0.iter().map(|b| format!("{b:02x}")).collect();
let _ = std::fs::write(&id_path, &hex); let _ = std::fs::write(&id_path, &hex);
@@ -386,7 +417,7 @@ async fn main() -> anyhow::Result<()> {
} else { } else {
// Probe via a dummy "connected" UDP socket. Never actually sends. // Probe via a dummy "connected" UDP socket. Never actually sends.
match std::net::UdpSocket::bind("0.0.0.0:0") match std::net::UdpSocket::bind("0.0.0.0:0")
.and_then(|s| { s.connect("8.8.8.8:80").map(|_| s) }) .and_then(|s| s.connect("8.8.8.8:80").map(|_| s))
.and_then(|s| s.local_addr()) .and_then(|s| s.local_addr())
{ {
Ok(a) if !a.ip().is_loopback() => a.ip(), Ok(a) if !a.ip().is_loopback() => a.ip(),
@@ -398,8 +429,9 @@ async fn main() -> anyhow::Result<()> {
info!(%advertised_addr_str, "relay advertised address for CallSetup"); info!(%advertised_addr_str, "relay advertised address for CallSetup");
// Forward mode // Forward mode
let remote_transport: Option<Arc<wzp_transport::QuinnTransport>> = let remote_transport: Option<Arc<wzp_transport::QuinnTransport>> = if let Some(remote_addr) =
if let Some(remote_addr) = config.remote_relay { config.remote_relay
{
info!(%remote_addr, "forward mode → remote relay"); info!(%remote_addr, "forward mode → remote relay");
let client_cfg = wzp_transport::client_config(); let client_cfg = wzp_transport::client_config();
let conn = wzp_transport::connect(&endpoint, remote_addr, "localhost", client_cfg).await?; let conn = wzp_transport::connect(&endpoint, remote_addr, "localhost", client_cfg).await?;
@@ -414,15 +446,15 @@ async fn main() -> anyhow::Result<()> {
// Event log for protocol analysis // Event log for protocol analysis
let event_log = wzp_relay::event_log::start_event_log( let event_log = wzp_relay::event_log::start_event_log(
config.event_log.as_ref().map(std::path::PathBuf::from) config.event_log.as_ref().map(std::path::PathBuf::from),
); );
// Federation manager // Federation manager
let global_room_set: std::collections::HashSet<String> = config.global_rooms.iter() let global_room_set: std::collections::HashSet<String> =
.map(|g| g.name.clone()) config.global_rooms.iter().map(|g| g.name.clone()).collect();
.collect();
let federation_mgr = if !config.peers.is_empty() || !config.trusted.is_empty() || !global_room_set.is_empty() { let federation_mgr =
if !config.peers.is_empty() || !config.trusted.is_empty() || !global_room_set.is_empty() {
let fm = Arc::new(wzp_relay::federation::FederationManager::new( let fm = Arc::new(wzp_relay::federation::FederationManager::new(
config.peers.clone(), config.peers.clone(),
config.trusted.clone(), config.trusted.clone(),
@@ -608,6 +640,7 @@ async fn main() -> anyhow::Result<()> {
.send_to( .send_to(
&caller_fp, &caller_fp,
&SignalMessage::Hangup { &SignalMessage::Hangup {
version: default_signal_version(),
reason: wzp_proto::HangupReason::Normal, reason: wzp_proto::HangupReason::Normal,
call_id: None, call_id: None,
}, },
@@ -624,14 +657,15 @@ async fn main() -> anyhow::Result<()> {
// active, then read back everything needed to // active, then read back everything needed to
// cross-wire into the local CallSetup. // cross-wire into the local CallSetup.
let room_name = format!("call-{call_id}"); let room_name = format!("call-{call_id}");
let (callee_addr_for_setup, callee_local_for_setup, callee_mapped_for_setup) = { let (
callee_addr_for_setup,
callee_local_for_setup,
callee_mapped_for_setup,
) = {
let mut reg = call_registry_d.lock().await; let mut reg = call_registry_d.lock().await;
reg.set_active(call_id, accept_mode, room_name.clone()); reg.set_active(call_id, accept_mode, room_name.clone());
reg.set_peer_relay_fp(call_id, Some(origin_relay_fp.clone())); reg.set_peer_relay_fp(call_id, Some(origin_relay_fp.clone()));
reg.set_callee_reflexive_addr( reg.set_callee_reflexive_addr(call_id, callee_reflexive_addr.clone());
call_id,
callee_reflexive_addr.clone(),
);
reg.set_callee_local_addrs(call_id, callee_local_addrs.clone()); reg.set_callee_local_addrs(call_id, callee_local_addrs.clone());
reg.set_callee_mapped_addr(call_id, callee_mapped_addr.clone()); reg.set_callee_mapped_addr(call_id, callee_mapped_addr.clone());
let c = reg.get(call_id); let c = reg.get(call_id);
@@ -652,6 +686,7 @@ async fn main() -> anyhow::Result<()> {
// Emit the LOCAL CallSetup to our local caller. // Emit the LOCAL CallSetup to our local caller.
let setup = SignalMessage::CallSetup { let setup = SignalMessage::CallSetup {
version: default_signal_version(),
call_id: call_id.clone(), call_id: call_id.clone(),
room: room_name.clone(), room: room_name.clone(),
relay_addr: advertised_addr_d.clone(), relay_addr: advertised_addr_d.clone(),
@@ -670,7 +705,7 @@ async fn main() -> anyhow::Result<()> {
); );
} }
SignalMessage::CallRinging { ref call_id } => { SignalMessage::CallRinging { ref call_id, .. } => {
// Forward to local caller for "ringing..." UX. // Forward to local caller for "ringing..." UX.
let caller_fp = { let caller_fp = {
let reg = call_registry_d.lock().await; let reg = call_registry_d.lock().await;
@@ -762,7 +797,9 @@ async fn main() -> anyhow::Result<()> {
let relay_seed_bytes = relay_seed.0; let relay_seed_bytes = relay_seed.0;
let metrics = metrics.clone(); let metrics = metrics.clone();
let trunking_enabled = config.trunking_enabled; let trunking_enabled = config.trunking_enabled;
let debug_tap = config.debug_tap.as_ref().map(|filter| room::DebugTap { room_filter: filter.clone() }); let debug_tap = config.debug_tap.as_ref().map(|filter| room::DebugTap {
room_filter: filter.clone(),
});
let presence = presence.clone(); let presence = presence.clone();
let route_resolver = route_resolver.clone(); let route_resolver = route_resolver.clone();
let federation_mgr = federation_mgr.clone(); let federation_mgr = federation_mgr.clone();
@@ -771,7 +808,9 @@ async fn main() -> anyhow::Result<()> {
let advertised_addr_str = advertised_addr_str.clone(); let advertised_addr_str = advertised_addr_str.clone();
// Phase 8: relay region + peer addresses for RegisterPresenceAck // Phase 8: relay region + peer addresses for RegisterPresenceAck
let relay_region = config.region.clone(); let relay_region = config.region.clone();
let relay_peers_for_ack: Vec<String> = config.peers.iter() let relay_peers_for_ack: Vec<String> = config
.peers
.iter()
.filter_map(|p| { .filter_map(|p| {
let label = p.label.as_deref().unwrap_or("peer"); let label = p.label.as_deref().unwrap_or("peer");
Some(format!("{label}|{}", p.url)) Some(format!("{label}|{}", p.url))
@@ -800,9 +839,7 @@ async fn main() -> anyhow::Result<()> {
let room_name = connection let room_name = connection
.handshake_data() .handshake_data()
.and_then(|hd| { .and_then(|hd| hd.downcast::<quinn::crypto::rustls::HandshakeData>().ok())
hd.downcast::<quinn::crypto::rustls::HandshakeData>().ok()
})
.and_then(|hd| hd.server_name.clone()) .and_then(|hd| hd.server_name.clone())
.unwrap_or_else(|| "default".to_string()); .unwrap_or_else(|| "default".to_string());
@@ -831,18 +868,28 @@ async fn main() -> anyhow::Result<()> {
info!(%addr, "probe connection detected, entering Ping/Pong + presence responder"); info!(%addr, "probe connection detected, entering Ping/Pong + presence responder");
loop { loop {
match transport.recv_signal().await { match transport.recv_signal().await {
Ok(Some(wzp_proto::SignalMessage::Ping { timestamp_ms })) => { Ok(Some(wzp_proto::SignalMessage::Ping { timestamp_ms, .. })) => {
if let Err(e) = transport.send_signal( if let Err(e) = transport
&wzp_proto::SignalMessage::Pong { timestamp_ms }, .send_signal(&wzp_proto::SignalMessage::Pong {
).await { version: default_signal_version(),
timestamp_ms,
})
.await
{
error!(%addr, "probe pong send error: {e}"); error!(%addr, "probe pong send error: {e}");
break; break;
} }
} }
Ok(Some(wzp_proto::SignalMessage::PresenceUpdate { fingerprints, relay_addr })) => { Ok(Some(wzp_proto::SignalMessage::PresenceUpdate {
fingerprints,
relay_addr,
..
})) => {
// A peer relay is telling us which fingerprints it has // A peer relay is telling us which fingerprints it has
let peer_addr: std::net::SocketAddr = relay_addr.parse().unwrap_or(addr); let peer_addr: std::net::SocketAddr =
let fps: std::collections::HashSet<String> = fingerprints.into_iter().collect(); relay_addr.parse().unwrap_or(addr);
let fps: std::collections::HashSet<String> =
fingerprints.into_iter().collect();
{ {
let mut reg = presence.lock().await; let mut reg = presence.lock().await;
reg.update_peer(peer_addr, fps); reg.update_peer(peer_addr, fps);
@@ -853,6 +900,7 @@ async fn main() -> anyhow::Result<()> {
reg.local_fingerprints().into_iter().collect() reg.local_fingerprints().into_iter().collect()
}; };
let reply = wzp_proto::SignalMessage::PresenceUpdate { let reply = wzp_proto::SignalMessage::PresenceUpdate {
version: default_signal_version(),
fingerprints: local_fps, fingerprints: local_fps,
relay_addr: addr.to_string(), relay_addr: addr.to_string(),
}; };
@@ -861,7 +909,9 @@ async fn main() -> anyhow::Result<()> {
break; break;
} }
} }
Ok(Some(wzp_proto::SignalMessage::RouteQuery { fingerprint, ttl })) => { Ok(Some(wzp_proto::SignalMessage::RouteQuery {
fingerprint, ttl, ..
})) => {
// Look up the fingerprint in our local registry // Look up the fingerprint in our local registry
let reg = presence.lock().await; let reg = presence.lock().await;
let route = route_resolver.resolve(&reg, &fingerprint); let route = route_resolver.resolve(&reg, &fingerprint);
@@ -871,9 +921,13 @@ async fn main() -> anyhow::Result<()> {
wzp_relay::route::Route::Local => { wzp_relay::route::Route::Local => {
(true, vec![route_resolver.local_addr().to_string()]) (true, vec![route_resolver.local_addr().to_string()])
} }
wzp_relay::route::Route::DirectPeer(peer_addr) => { wzp_relay::route::Route::DirectPeer(peer_addr) => (
(true, vec![route_resolver.local_addr().to_string(), peer_addr.to_string()]) true,
} vec![
route_resolver.local_addr().to_string(),
peer_addr.to_string(),
],
),
_ => { _ => {
// Not found locally; if ttl > 0 we could forward // Not found locally; if ttl > 0 we could forward
// to other peers (future multi-hop). For now, reply not found. // to other peers (future multi-hop). For now, reply not found.
@@ -885,6 +939,7 @@ async fn main() -> anyhow::Result<()> {
}; };
let reply = wzp_proto::SignalMessage::RouteResponse { let reply = wzp_proto::SignalMessage::RouteResponse {
version: default_signal_version(),
fingerprint, fingerprint,
found, found,
relay_chain, relay_chain,
@@ -918,8 +973,13 @@ async fn main() -> anyhow::Result<()> {
let hello_fp = match tokio::time::timeout( let hello_fp = match tokio::time::timeout(
std::time::Duration::from_secs(5), std::time::Duration::from_secs(5),
transport.recv_signal(), transport.recv_signal(),
).await { )
Ok(Ok(Some(wzp_proto::SignalMessage::FederationHello { tls_fingerprint }))) => tls_fingerprint, .await
{
Ok(Ok(Some(wzp_proto::SignalMessage::FederationHello {
tls_fingerprint,
..
}))) => tls_fingerprint,
_ => { _ => {
warn!(%addr, "federation: no hello received, closing"); warn!(%addr, "federation: no hello received, closing");
return; return;
@@ -955,7 +1015,7 @@ async fn main() -> anyhow::Result<()> {
// Optional auth // Optional auth
let auth_fp: Option<String> = if let Some(ref url) = auth_url { let auth_fp: Option<String> = if let Some(ref url) = auth_url {
match transport.recv_signal().await { match transport.recv_signal().await {
Ok(Some(SignalMessage::AuthToken { token })) => { Ok(Some(SignalMessage::AuthToken { token, .. })) => {
match wzp_relay::auth::validate_token(url, &token).await { match wzp_relay::auth::validate_token(url, &token).await {
Ok(client) => Some(client.fingerprint), Ok(client) => Some(client.fingerprint),
Err(e) => { Err(e) => {
@@ -964,7 +1024,10 @@ async fn main() -> anyhow::Result<()> {
} }
} }
} }
_ => { warn!(%addr, "signal: expected AuthToken"); return; } _ => {
warn!(%addr, "signal: expected AuthToken");
return;
}
} }
} else { } else {
None None
@@ -974,15 +1037,23 @@ async fn main() -> anyhow::Result<()> {
let (client_fp, client_alias) = match tokio::time::timeout( let (client_fp, client_alias) = match tokio::time::timeout(
std::time::Duration::from_secs(10), std::time::Duration::from_secs(10),
transport.recv_signal(), transport.recv_signal(),
).await { )
Ok(Ok(Some(SignalMessage::RegisterPresence { identity_pub, signature: _, alias }))) => { .await
{
Ok(Ok(Some(SignalMessage::RegisterPresence {
identity_pub,
signature: _,
alias,
..
}))) => {
// Compute fingerprint: SHA-256(Ed25519 pub key)[:16], same as Fingerprint type // Compute fingerprint: SHA-256(Ed25519 pub key)[:16], same as Fingerprint type
let fp = { let fp = {
use sha2::{Sha256, Digest}; use sha2::{Digest, Sha256};
let hash = Sha256::digest(&identity_pub); let hash = Sha256::digest(&identity_pub);
let fingerprint = wzp_crypto::Fingerprint([ let fingerprint = wzp_crypto::Fingerprint([
hash[0], hash[1], hash[2], hash[3], hash[4], hash[5], hash[6], hash[7], hash[0], hash[1], hash[2], hash[3], hash[4], hash[5], hash[6],
hash[8], hash[9], hash[10], hash[11], hash[12], hash[13], hash[14], hash[15], hash[7], hash[8], hash[9], hash[10], hash[11], hash[12], hash[13],
hash[14], hash[15],
]); ]);
fingerprint.to_string() fingerprint.to_string()
}; };
@@ -1006,13 +1077,16 @@ async fn main() -> anyhow::Result<()> {
} }
// Send ack // Send ack
let _ = transport.send_signal(&SignalMessage::RegisterPresenceAck { let _ = transport
.send_signal(&SignalMessage::RegisterPresenceAck {
version: default_signal_version(),
success: true, success: true,
error: None, error: None,
relay_build: Some(BUILD_GIT_HASH.to_string()), relay_build: Some(BUILD_GIT_HASH.to_string()),
relay_region: relay_region.clone(), relay_region: relay_region.clone(),
available_relays: relay_peers_for_ack.clone(), available_relays: relay_peers_for_ack.clone(),
}).await; })
.await;
info!(%addr, fingerprint = %client_fp, alias = ?client_alias, "signal client registered"); info!(%addr, fingerprint = %client_fp, alias = ?client_alias, "signal client registered");
@@ -1065,6 +1139,7 @@ async fn main() -> anyhow::Result<()> {
// federation has a matching entry. // federation has a matching entry.
let forwarded = if let Some(ref fm) = federation_mgr { let forwarded = if let Some(ref fm) = federation_mgr {
let forward = SignalMessage::FederatedSignalForward { let forward = SignalMessage::FederatedSignalForward {
version: default_signal_version(),
inner: Box::new(msg.clone()), inner: Box::new(msg.clone()),
origin_relay_fp: tls_fp.clone(), origin_relay_fp: tls_fp.clone(),
}; };
@@ -1086,10 +1161,13 @@ async fn main() -> anyhow::Result<()> {
if !forwarded { if !forwarded {
info!(%addr, target = %target_fp, "call target not online (no federation route)"); info!(%addr, target = %target_fp, "call target not online (no federation route)");
let _ = transport.send_signal(&SignalMessage::Hangup { let _ = transport
.send_signal(&SignalMessage::Hangup {
version: default_signal_version(),
reason: wzp_proto::HangupReason::Normal, reason: wzp_proto::HangupReason::Normal,
call_id: None, call_id: None,
}).await; })
.await;
continue; continue;
} }
@@ -1128,9 +1206,12 @@ async fn main() -> anyhow::Result<()> {
// Send ringing to caller immediately // Send ringing to caller immediately
// so the UI shows feedback while the // so the UI shows feedback while the
// federated delivery is in flight. // federated delivery is in flight.
let _ = transport.send_signal(&SignalMessage::CallRinging { let _ = transport
.send_signal(&SignalMessage::CallRinging {
version: default_signal_version(),
call_id: call_id.clone(), call_id: call_id.clone(),
}).await; })
.await;
continue; continue;
} }
@@ -1141,10 +1222,23 @@ async fn main() -> anyhow::Result<()> {
// injected later into the callee's CallSetup. // injected later into the callee's CallSetup.
{ {
let mut reg = call_registry.lock().await; let mut reg = call_registry.lock().await;
reg.create_call(call_id.clone(), client_fp.clone(), target_fp.clone()); reg.create_call(
reg.set_caller_reflexive_addr(&call_id, caller_addr_for_registry); call_id.clone(),
reg.set_caller_local_addrs(&call_id, caller_local_for_registry); client_fp.clone(),
reg.set_caller_mapped_addr(&call_id, caller_mapped_for_registry); target_fp.clone(),
);
reg.set_caller_reflexive_addr(
&call_id,
caller_addr_for_registry,
);
reg.set_caller_local_addrs(
&call_id,
caller_local_for_registry,
);
reg.set_caller_mapped_addr(
&call_id,
caller_mapped_for_registry,
);
} }
// Forward offer to callee // Forward offer to callee
@@ -1156,9 +1250,12 @@ async fn main() -> anyhow::Result<()> {
// Send ringing to caller // Send ringing to caller
drop(hub); drop(hub);
let _ = transport.send_signal(&SignalMessage::CallRinging { let _ = transport
.send_signal(&SignalMessage::CallRinging {
version: default_signal_version(),
call_id: call_id.clone(), call_id: call_id.clone(),
}).await; })
.await;
} }
SignalMessage::DirectCallAnswer { SignalMessage::DirectCallAnswer {
@@ -1186,7 +1283,10 @@ async fn main() -> anyhow::Result<()> {
let reg = call_registry.lock().await; let reg = call_registry.lock().await;
match reg.get(&call_id) { match reg.get(&call_id) {
Some(c) => ( Some(c) => (
Some(reg.peer_fingerprint(&call_id, &client_fp).map(|s| s.to_string())), Some(
reg.peer_fingerprint(&call_id, &client_fp)
.map(|s| s.to_string()),
),
c.peer_relay_fp.clone(), c.peer_relay_fp.clone(),
), ),
None => (None, None), None => (None, None),
@@ -1210,23 +1310,35 @@ async fn main() -> anyhow::Result<()> {
if let Some(ref origin_fp) = peer_relay_fp { if let Some(ref origin_fp) = peer_relay_fp {
if let Some(ref fm) = federation_mgr { if let Some(ref fm) = federation_mgr {
let hangup = SignalMessage::Hangup { let hangup = SignalMessage::Hangup {
version: default_signal_version(),
reason: wzp_proto::HangupReason::Normal, reason: wzp_proto::HangupReason::Normal,
call_id: Some(call_id.clone()), call_id: Some(call_id.clone()),
}; };
let forward = SignalMessage::FederatedSignalForward { let forward =
SignalMessage::FederatedSignalForward {
version: default_signal_version(),
inner: Box::new(hangup), inner: Box::new(hangup),
origin_relay_fp: tls_fp.clone(), origin_relay_fp: tls_fp.clone(),
}; };
if let Err(e) = fm.send_signal_to_peer(origin_fp, &forward).await { if let Err(e) = fm
.send_signal_to_peer(origin_fp, &forward)
.await
{
warn!(%call_id, %origin_fp, error = %e, "cross-relay reject forward failed"); warn!(%call_id, %origin_fp, error = %e, "cross-relay reject forward failed");
} }
} }
} else { } else {
let hub = signal_hub.lock().await; let hub = signal_hub.lock().await;
let _ = hub.send_to(&peer_fp, &SignalMessage::Hangup { let _ = hub
.send_to(
&peer_fp,
&SignalMessage::Hangup {
version: default_signal_version(),
reason: wzp_proto::HangupReason::Normal, reason: wzp_proto::HangupReason::Normal,
call_id: Some(call_id.clone()), call_id: Some(call_id.clone()),
}).await; },
)
.await;
} }
} else { } else {
// Accept — create private room + stash the // Accept — create private room + stash the
@@ -1236,18 +1348,36 @@ async fn main() -> anyhow::Result<()> {
// BOTH parties' addrs so we can cross-wire // BOTH parties' addrs so we can cross-wire
// peer_direct_addr on the CallSetups below. // peer_direct_addr on the CallSetups below.
let room = format!("call-{call_id}"); let room = format!("call-{call_id}");
let (caller_addr, callee_addr, caller_local, callee_local, caller_mapped, callee_mapped) = { let (
caller_addr,
callee_addr,
caller_local,
callee_local,
caller_mapped,
callee_mapped,
) = {
let mut reg = call_registry.lock().await; let mut reg = call_registry.lock().await;
reg.set_active(&call_id, mode, room.clone()); reg.set_active(&call_id, mode, room.clone());
reg.set_callee_reflexive_addr(&call_id, callee_addr_for_registry); reg.set_callee_reflexive_addr(
reg.set_callee_local_addrs(&call_id, callee_local_for_registry.clone()); &call_id,
reg.set_callee_mapped_addr(&call_id, callee_mapped_for_registry); callee_addr_for_registry,
);
reg.set_callee_local_addrs(
&call_id,
callee_local_for_registry.clone(),
);
reg.set_callee_mapped_addr(
&call_id,
callee_mapped_for_registry,
);
let call = reg.get(&call_id); let call = reg.get(&call_id);
( (
call.and_then(|c| c.caller_reflexive_addr.clone()), call.and_then(|c| c.caller_reflexive_addr.clone()),
call.and_then(|c| c.callee_reflexive_addr.clone()), call.and_then(|c| c.callee_reflexive_addr.clone()),
call.map(|c| c.caller_local_addrs.clone()).unwrap_or_default(), call.map(|c| c.caller_local_addrs.clone())
call.map(|c| c.callee_local_addrs.clone()).unwrap_or_default(), .unwrap_or_default(),
call.map(|c| c.callee_local_addrs.clone())
.unwrap_or_default(),
call.and_then(|c| c.caller_mapped_addr.clone()), call.and_then(|c| c.caller_mapped_addr.clone()),
call.and_then(|c| c.callee_mapped_addr.clone()), call.and_then(|c| c.callee_mapped_addr.clone()),
) )
@@ -1278,11 +1408,16 @@ async fn main() -> anyhow::Result<()> {
// CallSetup (to our callee) with // CallSetup (to our callee) with
// peer_direct_addr = caller_addr. // peer_direct_addr = caller_addr.
if let Some(ref fm) = federation_mgr { if let Some(ref fm) = federation_mgr {
let forward = SignalMessage::FederatedSignalForward { let forward =
SignalMessage::FederatedSignalForward {
version: default_signal_version(),
inner: Box::new(msg.clone()), inner: Box::new(msg.clone()),
origin_relay_fp: tls_fp.clone(), origin_relay_fp: tls_fp.clone(),
}; };
if let Err(e) = fm.send_signal_to_peer(origin_fp, &forward).await { if let Err(e) = fm
.send_signal_to_peer(origin_fp, &forward)
.await
{
warn!( warn!(
%call_id, %call_id,
%origin_fp, %origin_fp,
@@ -1293,6 +1428,7 @@ async fn main() -> anyhow::Result<()> {
} }
let setup_for_callee = SignalMessage::CallSetup { let setup_for_callee = SignalMessage::CallSetup {
version: default_signal_version(),
call_id: call_id.clone(), call_id: call_id.clone(),
room: room.clone(), room: room.clone(),
relay_addr: relay_addr_for_setup, relay_addr: relay_addr_for_setup,
@@ -1301,7 +1437,8 @@ async fn main() -> anyhow::Result<()> {
peer_mapped_addr: caller_mapped.clone(), peer_mapped_addr: caller_mapped.clone(),
}; };
let hub = signal_hub.lock().await; let hub = signal_hub.lock().await;
let _ = hub.send_to(&client_fp, &setup_for_callee).await; let _ =
hub.send_to(&client_fp, &setup_for_callee).await;
} else { } else {
// Local call (existing Phase 3 path). // Local call (existing Phase 3 path).
// Forward answer to caller // Forward answer to caller
@@ -1314,6 +1451,7 @@ async fn main() -> anyhow::Result<()> {
// cross-wired candidates (Phase 5.5 ICE // cross-wired candidates (Phase 5.5 ICE
// + Phase 8 port-mapped addrs). // + Phase 8 port-mapped addrs).
let setup_for_caller = SignalMessage::CallSetup { let setup_for_caller = SignalMessage::CallSetup {
version: default_signal_version(),
call_id: call_id.clone(), call_id: call_id.clone(),
room: room.clone(), room: room.clone(),
relay_addr: relay_addr_for_setup.clone(), relay_addr: relay_addr_for_setup.clone(),
@@ -1322,6 +1460,7 @@ async fn main() -> anyhow::Result<()> {
peer_mapped_addr: callee_mapped, peer_mapped_addr: callee_mapped,
}; };
let setup_for_callee = SignalMessage::CallSetup { let setup_for_callee = SignalMessage::CallSetup {
version: default_signal_version(),
call_id: call_id.clone(), call_id: call_id.clone(),
room: room.clone(), room: room.clone(),
relay_addr: relay_addr_for_setup, relay_addr: relay_addr_for_setup,
@@ -1331,7 +1470,8 @@ async fn main() -> anyhow::Result<()> {
}; };
let hub = signal_hub.lock().await; let hub = signal_hub.lock().await;
let _ = hub.send_to(&peer_fp, &setup_for_caller).await; let _ = hub.send_to(&peer_fp, &setup_for_caller).await;
let _ = hub.send_to(&client_fp, &setup_for_callee).await; let _ =
hub.send_to(&client_fp, &setup_for_callee).await;
} }
} }
} }
@@ -1346,21 +1486,31 @@ async fn main() -> anyhow::Result<()> {
if let Some(cid) = call_id { if let Some(cid) = call_id {
// Targeted hangup: only the named call // Targeted hangup: only the named call
reg.get(cid) reg.get(cid)
.map(|c| vec![(c.call_id.clone(), if c.caller_fingerprint == client_fp { .map(|c| {
vec![(
c.call_id.clone(),
if c.caller_fingerprint == client_fp {
c.callee_fingerprint.clone() c.callee_fingerprint.clone()
} else { } else {
c.caller_fingerprint.clone() c.caller_fingerprint.clone()
})]) },
)]
})
.unwrap_or_default() .unwrap_or_default()
} else { } else {
// Legacy: end all calls for this user // Legacy: end all calls for this user
reg.calls_for_fingerprint(&client_fp) reg.calls_for_fingerprint(&client_fp)
.iter() .iter()
.map(|c| (c.call_id.clone(), if c.caller_fingerprint == client_fp { .map(|c| {
(
c.call_id.clone(),
if c.caller_fingerprint == client_fp {
c.callee_fingerprint.clone() c.callee_fingerprint.clone()
} else { } else {
c.caller_fingerprint.clone() c.caller_fingerprint.clone()
})) },
)
})
.collect::<Vec<_>>() .collect::<Vec<_>>()
} }
}; };
@@ -1396,11 +1546,16 @@ async fn main() -> anyhow::Result<()> {
if let Some(ref origin_fp) = peer_relay_fp { if let Some(ref origin_fp) = peer_relay_fp {
// Cross-relay: wrap and forward // Cross-relay: wrap and forward
if let Some(ref fm) = federation_mgr { if let Some(ref fm) = federation_mgr {
let forward = SignalMessage::FederatedSignalForward { let forward =
SignalMessage::FederatedSignalForward {
version: default_signal_version(),
inner: Box::new(msg.clone()), inner: Box::new(msg.clone()),
origin_relay_fp: tls_fp.clone(), origin_relay_fp: tls_fp.clone(),
}; };
if let Err(e) = fm.send_signal_to_peer(origin_fp, &forward).await { if let Err(e) = fm
.send_signal_to_peer(origin_fp, &forward)
.await
{
warn!( warn!(
%call_id, %call_id,
%origin_fp, %origin_fp,
@@ -1436,11 +1591,16 @@ async fn main() -> anyhow::Result<()> {
if let Some(fp) = peer_fp { if let Some(fp) = peer_fp {
if let Some(ref origin_fp) = peer_relay_fp { if let Some(ref origin_fp) = peer_relay_fp {
if let Some(ref fm) = federation_mgr { if let Some(ref fm) = federation_mgr {
let forward = SignalMessage::FederatedSignalForward { let forward =
SignalMessage::FederatedSignalForward {
version: default_signal_version(),
inner: Box::new(msg.clone()), inner: Box::new(msg.clone()),
origin_relay_fp: tls_fp.clone(), origin_relay_fp: tls_fp.clone(),
}; };
if let Err(e) = fm.send_signal_to_peer(origin_fp, &forward).await { if let Err(e) = fm
.send_signal_to_peer(origin_fp, &forward)
.await
{
warn!( warn!(
%call_id, %call_id,
%origin_fp, %origin_fp,
@@ -1458,12 +1618,12 @@ async fn main() -> anyhow::Result<()> {
// Hard NAT: forward HardNatProbe + HardNatBirthdayStart // Hard NAT: forward HardNatProbe + HardNatBirthdayStart
// to call peer (same pattern as CandidateUpdate). // to call peer (same pattern as CandidateUpdate).
SignalMessage::HardNatBirthdayStart { ref call_id, .. } | SignalMessage::HardNatBirthdayStart { ref call_id, .. }
SignalMessage::HardNatProbe { ref call_id, .. } | | SignalMessage::HardNatProbe { ref call_id, .. }
SignalMessage::UpgradeProposal { ref call_id, .. } | | SignalMessage::UpgradeProposal { ref call_id, .. }
SignalMessage::UpgradeResponse { ref call_id, .. } | | SignalMessage::UpgradeResponse { ref call_id, .. }
SignalMessage::UpgradeConfirm { ref call_id, .. } | | SignalMessage::UpgradeConfirm { ref call_id, .. }
SignalMessage::QualityCapability { ref call_id, .. } => { | SignalMessage::QualityCapability { ref call_id, .. } => {
let (peer_fp, peer_relay_fp) = { let (peer_fp, peer_relay_fp) = {
let reg = call_registry.lock().await; let reg = call_registry.lock().await;
match reg.get(call_id) { match reg.get(call_id) {
@@ -1479,11 +1639,15 @@ async fn main() -> anyhow::Result<()> {
if let Some(fp) = peer_fp { if let Some(fp) = peer_fp {
if let Some(ref origin_fp) = peer_relay_fp { if let Some(ref origin_fp) = peer_relay_fp {
if let Some(ref fm) = federation_mgr { if let Some(ref fm) = federation_mgr {
let forward = SignalMessage::FederatedSignalForward { let forward =
SignalMessage::FederatedSignalForward {
version: default_signal_version(),
inner: Box::new(msg.clone()), inner: Box::new(msg.clone()),
origin_relay_fp: tls_fp.clone(), origin_relay_fp: tls_fp.clone(),
}; };
let _ = fm.send_signal_to_peer(origin_fp, &forward).await; let _ = fm
.send_signal_to_peer(origin_fp, &forward)
.await;
} }
} else { } else {
let hub = signal_hub.lock().await; let hub = signal_hub.lock().await;
@@ -1492,8 +1656,13 @@ async fn main() -> anyhow::Result<()> {
} }
} }
SignalMessage::Ping { timestamp_ms } => { SignalMessage::Ping { timestamp_ms, .. } => {
let _ = transport.send_signal(&SignalMessage::Pong { timestamp_ms }).await; let _ = transport
.send_signal(&SignalMessage::Pong {
version: default_signal_version(),
timestamp_ms,
})
.await;
} }
// QUIC-native NAT reflection ("STUN for QUIC"). // QUIC-native NAT reflection ("STUN for QUIC").
@@ -1510,11 +1679,13 @@ async fn main() -> anyhow::Result<()> {
// reaches this match arm. // reaches this match arm.
SignalMessage::Reflect => { SignalMessage::Reflect => {
let observed_addr = addr.to_string(); let observed_addr = addr.to_string();
if let Err(e) = transport.send_signal( if let Err(e) = transport
&SignalMessage::ReflectResponse { .send_signal(&SignalMessage::ReflectResponse {
version: default_signal_version(),
observed_addr: observed_addr.clone(), observed_addr: observed_addr.clone(),
}, })
).await { .await
{
warn!(%addr, error = %e, "reflect: failed to send response"); warn!(%addr, error = %e, "reflect: failed to send response");
} else { } else {
debug!(%addr, %observed_addr, "reflect: responded"); debug!(%addr, %observed_addr, "reflect: responded");
@@ -1552,19 +1723,30 @@ async fn main() -> anyhow::Result<()> {
let reg = call_registry.lock().await; let reg = call_registry.lock().await;
reg.calls_for_fingerprint(&client_fp) reg.calls_for_fingerprint(&client_fp)
.iter() .iter()
.map(|c| (c.call_id.clone(), if c.caller_fingerprint == client_fp { .map(|c| {
(
c.call_id.clone(),
if c.caller_fingerprint == client_fp {
c.callee_fingerprint.clone() c.callee_fingerprint.clone()
} else { } else {
c.caller_fingerprint.clone() c.caller_fingerprint.clone()
})) },
)
})
.collect::<Vec<_>>() .collect::<Vec<_>>()
}; };
for (call_id, peer_fp) in &active_calls { for (call_id, peer_fp) in &active_calls {
let hub = signal_hub.lock().await; let hub = signal_hub.lock().await;
let _ = hub.send_to(peer_fp, &SignalMessage::Hangup { let _ = hub
.send_to(
peer_fp,
&SignalMessage::Hangup {
version: default_signal_version(),
reason: wzp_proto::HangupReason::Normal, reason: wzp_proto::HangupReason::Normal,
call_id: Some(call_id.clone()), call_id: Some(call_id.clone()),
}).await; },
)
.await;
drop(hub); drop(hub);
let mut reg = call_registry.lock().await; let mut reg = call_registry.lock().await;
reg.end_call(call_id); reg.end_call(call_id);
@@ -1591,7 +1773,7 @@ async fn main() -> anyhow::Result<()> {
let authenticated_fp: Option<String> = if let Some(ref url) = auth_url { let authenticated_fp: Option<String> = if let Some(ref url) = auth_url {
info!(%addr, "waiting for auth token..."); info!(%addr, "waiting for auth token...");
match transport.recv_signal().await { match transport.recv_signal().await {
Ok(Some(wzp_proto::SignalMessage::AuthToken { token })) => { Ok(Some(wzp_proto::SignalMessage::AuthToken { token, .. })) => {
match wzp_relay::auth::validate_token(url, &token).await { match wzp_relay::auth::validate_token(url, &token).await {
Ok(client) => { Ok(client) => {
metrics.auth_attempts.with_label_values(&["ok"]).inc(); metrics.auth_attempts.with_label_values(&["ok"]).inc();
@@ -1632,10 +1814,8 @@ async fn main() -> anyhow::Result<()> {
// Crypto handshake: verify client identity + negotiate quality profile // Crypto handshake: verify client identity + negotiate quality profile
let handshake_start = std::time::Instant::now(); let handshake_start = std::time::Instant::now();
let (_crypto_session, _chosen_profile, caller_fp, caller_alias) = match wzp_relay::handshake::accept_handshake( let (_crypto_session, _chosen_profile, caller_fp, caller_alias) =
&*transport, match wzp_relay::handshake::accept_handshake(&*transport, &relay_seed_bytes).await {
&relay_seed_bytes,
).await {
Ok(result) => { Ok(result) => {
let elapsed = handshake_start.elapsed().as_secs_f64(); let elapsed = handshake_start.elapsed().as_secs_f64();
metrics.handshake_duration.observe(elapsed); metrics.handshake_duration.observe(elapsed);
@@ -1704,8 +1884,18 @@ async fn main() -> anyhow::Result<()> {
} }
}); });
let up = tokio::spawn(run_upstream(transport.clone(), remote.clone(), up_pipe, stats.clone())); let up = tokio::spawn(run_upstream(
let dn = tokio::spawn(run_downstream(transport.clone(), remote.clone(), dn_pipe, stats)); transport.clone(),
remote.clone(),
up_pipe,
stats.clone(),
));
let dn = tokio::spawn(run_downstream(
transport.clone(),
remote.clone(),
dn_pipe,
stats,
));
tokio::select! { _ = up => {} _ = dn => {} } tokio::select! { _ = up => {} _ = dn => {} }
stats_handle.abort(); stats_handle.abort();
@@ -1746,33 +1936,61 @@ async fn main() -> anyhow::Result<()> {
Some(&participant_fp), Some(&participant_fp),
caller_alias.as_deref(), caller_alias.as_deref(),
) { ) {
Ok((id, update, senders)) => { Ok((id, update, senders, cached_keyframes)) => {
metrics.active_rooms.set(room_mgr.list().len() as i64); metrics.active_rooms.set(room_mgr.list().len() as i64);
// Replay cached keyframes to the new participant before live
// traffic starts. This eliminates black-screen-on-join when
// the cache is warm.
for kf in cached_keyframes {
for pkt in kf {
if let Err(e) = transport.send_media(&pkt).await {
warn!(%addr, participant = id, "keyframe replay send error: {e}");
break;
}
}
}
// Merge federated participants into RoomUpdate if this is a global room // Merge federated participants into RoomUpdate if this is a global room
let merged_update = if let Some(ref fm) = federation_mgr { let merged_update = if let Some(ref fm) = federation_mgr {
if fm.is_global_room(&room_name) { if fm.is_global_room(&room_name) {
if let SignalMessage::RoomUpdate { count: _, participants: mut local_parts } = update { if let SignalMessage::RoomUpdate {
count: _,
participants: mut local_parts,
..
} = update
{
let remote = fm.get_remote_participants(&room_name).await; let remote = fm.get_remote_participants(&room_name).await;
local_parts.extend(remote); local_parts.extend(remote);
// Deduplicate by fingerprint // Deduplicate by fingerprint
let mut seen = std::collections::HashSet::new(); let mut seen = std::collections::HashSet::new();
local_parts.retain(|p| seen.insert(p.fingerprint.clone())); local_parts.retain(|p| seen.insert(p.fingerprint.clone()));
SignalMessage::RoomUpdate { SignalMessage::RoomUpdate {
version: default_signal_version(),
count: local_parts.len() as u32, count: local_parts.len() as u32,
participants: local_parts, participants: local_parts,
} }
} else { update } } else {
} else { update } update
} else { update }; }
} else {
update
}
} else {
update
};
if let Some(ref tap) = debug_tap { if let Some(ref tap) = debug_tap {
if tap.matches(&room_name) { if tap.matches(&room_name) {
tap.log_signal(&room_name, &merged_update); tap.log_signal(&room_name, &merged_update);
tap.log_event(&room_name, "join", &format!( tap.log_event(
&room_name,
"join",
&format!(
"participant={id} addr={addr} alias={}", "participant={id} addr={addr} alias={}",
caller_alias.as_deref().unwrap_or("?") caller_alias.as_deref().unwrap_or("?")
)); ),
);
} }
} }
room::broadcast_signal(&senders, &merged_update).await; room::broadcast_signal(&senders, &merged_update).await;
@@ -1789,10 +2007,8 @@ async fn main() -> anyhow::Result<()> {
} }
}; };
let session_id_str: String = session_id let session_id_str: String =
.iter() session_id.iter().map(|b| format!("{b:02x}")).collect();
.map(|b| format!("{b:02x}"))
.collect();
// Set up federation media channel if this is a global room // Set up federation media channel if this is a global room
let (federation_tx, federation_room_hash) = if let Some(ref fm) = federation_mgr { let (federation_tx, federation_room_hash) = if let Some(ref fm) = federation_mgr {
let is_global = fm.is_global_room(&room_name); let is_global = fm.is_global_room(&room_name);
@@ -1812,18 +2028,29 @@ async fn main() -> anyhow::Result<()> {
(None, None) (None, None)
}; };
room::run_participant( let media_handle = tokio::spawn(room::run_participant(
room_mgr.clone(), room_mgr.clone(),
room_name, room_name.clone(),
participant_id, participant_id,
transport.clone(), transport.clone(),
metrics.clone(), metrics.clone(),
&session_id_str, session_id_str.clone(),
trunking_enabled, trunking_enabled,
debug_tap, debug_tap,
federation_tx, federation_tx,
federation_room_hash, federation_room_hash,
).await; authenticated_fp.is_some(),
));
let signal_handle = tokio::spawn(room::run_participant_signals(
room_mgr.clone(),
room_name.clone(),
participant_id,
transport.clone(),
));
tokio::select! {
_ = media_handle => {},
_ = signal_handle => {},
}
// Participant disconnected — clean up presence + per-session metrics // Participant disconnected — clean up presence + per-session metrics
if let Some(ref fp) = authenticated_fp { if let Some(ref fp) = authenticated_fp {

View File

@@ -1,11 +1,14 @@
//! Prometheus metrics for the WZP relay daemon. //! Prometheus metrics for the WZP relay daemon.
use prometheus::{ use prometheus::{
Encoder, GaugeVec, Histogram, HistogramOpts, IntCounter, IntCounterVec, IntGauge, IntGaugeVec, Encoder, GaugeVec, Histogram, HistogramOpts, HistogramVec, IntCounter, IntCounterVec, IntGauge,
Opts, Registry, TextEncoder, IntGaugeVec, Opts, Registry, TextEncoder,
}; };
use wzp_proto::packet::QualityReport;
use std::sync::Arc; use std::sync::Arc;
use wzp_proto::MediaHeader;
use wzp_proto::packet::QualityReport;
use crate::conformance::Violation;
/// All relay-level Prometheus metrics. /// All relay-level Prometheus metrics.
#[derive(Clone)] #[derive(Clone)]
@@ -32,6 +35,9 @@ pub struct RelayMetrics {
// Phase 4: loss-recovery breakdown per session. // Phase 4: loss-recovery breakdown per session.
pub session_dred_reconstructions: IntCounterVec, pub session_dred_reconstructions: IntCounterVec,
pub session_classical_plc: IntCounterVec, pub session_classical_plc: IntCounterVec,
pub conformance_violations: IntCounterVec,
pub conformance_bytes: HistogramVec,
pub conformance_iat_ms: HistogramVec,
registry: Registry, registry: Registry,
} }
@@ -40,21 +46,23 @@ impl RelayMetrics {
pub fn new() -> Self { pub fn new() -> Self {
let registry = Registry::new(); let registry = Registry::new();
let active_sessions = IntGauge::with_opts( let active_sessions = IntGauge::with_opts(Opts::new(
Opts::new("wzp_relay_active_sessions", "Current active sessions"), "wzp_relay_active_sessions",
) "Current active sessions",
))
.expect("metric"); .expect("metric");
let active_rooms = IntGauge::with_opts( let active_rooms =
Opts::new("wzp_relay_active_rooms", "Current active rooms"), IntGauge::with_opts(Opts::new("wzp_relay_active_rooms", "Current active rooms"))
)
.expect("metric"); .expect("metric");
let packets_forwarded = IntCounter::with_opts( let packets_forwarded = IntCounter::with_opts(Opts::new(
Opts::new("wzp_relay_packets_forwarded_total", "Total packets forwarded"), "wzp_relay_packets_forwarded_total",
) "Total packets forwarded",
))
.expect("metric"); .expect("metric");
let bytes_forwarded = IntCounter::with_opts( let bytes_forwarded = IntCounter::with_opts(Opts::new(
Opts::new("wzp_relay_bytes_forwarded_total", "Total bytes forwarded"), "wzp_relay_bytes_forwarded_total",
) "Total bytes forwarded",
))
.expect("metric"); .expect("metric");
let auth_attempts = IntCounterVec::new( let auth_attempts = IntCounterVec::new(
Opts::new("wzp_relay_auth_attempts_total", "Auth validation attempts"), Opts::new("wzp_relay_auth_attempts_total", "Auth validation attempts"),
@@ -66,31 +74,51 @@ impl RelayMetrics {
"wzp_relay_handshake_duration_seconds", "wzp_relay_handshake_duration_seconds",
"Crypto handshake time", "Crypto handshake time",
) )
.buckets(vec![0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5]), .buckets(vec![
0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5,
]),
) )
.expect("metric"); .expect("metric");
let federation_peer_status = IntGaugeVec::new( let federation_peer_status = IntGaugeVec::new(
Opts::new("wzp_federation_peer_status", "Peer connection status (0=disconnected, 1=connected)"), Opts::new(
"wzp_federation_peer_status",
"Peer connection status (0=disconnected, 1=connected)",
),
&["peer"], &["peer"],
).expect("metric"); )
.expect("metric");
let federation_peer_rtt_ms = GaugeVec::new( let federation_peer_rtt_ms = GaugeVec::new(
Opts::new("wzp_federation_peer_rtt_ms", "QUIC RTT to federated peer in milliseconds"), Opts::new(
"wzp_federation_peer_rtt_ms",
"QUIC RTT to federated peer in milliseconds",
),
&["peer"], &["peer"],
).expect("metric"); )
.expect("metric");
let federation_packets_forwarded = IntCounterVec::new( let federation_packets_forwarded = IntCounterVec::new(
Opts::new("wzp_federation_packets_forwarded_total", "Packets forwarded to/from federated peers"), Opts::new(
"wzp_federation_packets_forwarded_total",
"Packets forwarded to/from federated peers",
),
&["peer", "direction"], &["peer", "direction"],
).expect("metric"); )
let federation_packets_deduped = IntCounter::with_opts( .expect("metric");
Opts::new("wzp_federation_packets_deduped_total", "Duplicate federation packets dropped"), let federation_packets_deduped = IntCounter::with_opts(Opts::new(
).expect("metric"); "wzp_federation_packets_deduped_total",
let federation_packets_rate_limited = IntCounter::with_opts( "Duplicate federation packets dropped",
Opts::new("wzp_federation_packets_rate_limited_total", "Federation packets dropped by rate limiter"), ))
).expect("metric"); .expect("metric");
let federation_active_rooms = IntGauge::with_opts( let federation_packets_rate_limited = IntCounter::with_opts(Opts::new(
Opts::new("wzp_federation_active_rooms", "Number of federated rooms currently active"), "wzp_federation_packets_rate_limited_total",
).expect("metric"); "Federation packets dropped by rate limiter",
))
.expect("metric");
let federation_active_rooms = IntGauge::with_opts(Opts::new(
"wzp_federation_active_rooms",
"Number of federated rooms currently active",
))
.expect("metric");
let session_buffer_depth = IntGaugeVec::new( let session_buffer_depth = IntGaugeVec::new(
Opts::new( Opts::new(
@@ -109,10 +137,7 @@ impl RelayMetrics {
) )
.expect("metric"); .expect("metric");
let session_rtt_ms = GaugeVec::new( let session_rtt_ms = GaugeVec::new(
Opts::new( Opts::new("wzp_relay_session_rtt_ms", "Round-trip time per session"),
"wzp_relay_session_rtt_ms",
"Round-trip time per session",
),
&["session_id"], &["session_id"],
) )
.expect("metric"); .expect("metric");
@@ -149,26 +174,104 @@ impl RelayMetrics {
&["session_id"], &["session_id"],
) )
.expect("metric"); .expect("metric");
let conformance_violations = IntCounterVec::new(
Opts::new(
"wzp_relay_conformance_violations_total",
"Conformance violations by tier, codec, media type and verdict",
),
&["tier", "codec_id", "media_type", "verdict"],
)
.expect("metric");
let conformance_bytes = HistogramVec::new(
HistogramOpts::new(
"wzp_relay_conformance_bytes_per_session",
"Packet size distribution observed by the conformance meter",
)
.buckets(vec![
16.0, 32.0, 64.0, 128.0, 256.0, 512.0, 1024.0, 2048.0, 4096.0, 8192.0, 16384.0,
32768.0, 65536.0,
]),
&["media_type"],
)
.expect("metric");
let conformance_iat_ms = HistogramVec::new(
HistogramOpts::new(
"wzp_relay_conformance_iat_ms",
"Inter-arrival time distribution in milliseconds",
)
.buckets(vec![
1.0, 5.0, 10.0, 20.0, 30.0, 40.0, 60.0, 80.0, 100.0, 150.0, 200.0, 300.0, 500.0,
]),
&["media_type"],
)
.expect("metric");
registry.register(Box::new(active_sessions.clone())).expect("register"); registry
registry.register(Box::new(active_rooms.clone())).expect("register"); .register(Box::new(active_sessions.clone()))
registry.register(Box::new(packets_forwarded.clone())).expect("register"); .expect("register");
registry.register(Box::new(bytes_forwarded.clone())).expect("register"); registry
registry.register(Box::new(auth_attempts.clone())).expect("register"); .register(Box::new(active_rooms.clone()))
registry.register(Box::new(handshake_duration.clone())).expect("register"); .expect("register");
registry.register(Box::new(federation_peer_status.clone())).expect("register"); registry
registry.register(Box::new(federation_peer_rtt_ms.clone())).expect("register"); .register(Box::new(packets_forwarded.clone()))
registry.register(Box::new(federation_packets_forwarded.clone())).expect("register"); .expect("register");
registry.register(Box::new(federation_packets_deduped.clone())).expect("register"); registry
registry.register(Box::new(federation_packets_rate_limited.clone())).expect("register"); .register(Box::new(bytes_forwarded.clone()))
registry.register(Box::new(federation_active_rooms.clone())).expect("register"); .expect("register");
registry.register(Box::new(session_buffer_depth.clone())).expect("register"); registry
registry.register(Box::new(session_loss_pct.clone())).expect("register"); .register(Box::new(auth_attempts.clone()))
registry.register(Box::new(session_rtt_ms.clone())).expect("register"); .expect("register");
registry.register(Box::new(session_underruns.clone())).expect("register"); registry
registry.register(Box::new(session_overruns.clone())).expect("register"); .register(Box::new(handshake_duration.clone()))
registry.register(Box::new(session_dred_reconstructions.clone())).expect("register"); .expect("register");
registry.register(Box::new(session_classical_plc.clone())).expect("register"); registry
.register(Box::new(federation_peer_status.clone()))
.expect("register");
registry
.register(Box::new(federation_peer_rtt_ms.clone()))
.expect("register");
registry
.register(Box::new(federation_packets_forwarded.clone()))
.expect("register");
registry
.register(Box::new(federation_packets_deduped.clone()))
.expect("register");
registry
.register(Box::new(federation_packets_rate_limited.clone()))
.expect("register");
registry
.register(Box::new(federation_active_rooms.clone()))
.expect("register");
registry
.register(Box::new(session_buffer_depth.clone()))
.expect("register");
registry
.register(Box::new(session_loss_pct.clone()))
.expect("register");
registry
.register(Box::new(session_rtt_ms.clone()))
.expect("register");
registry
.register(Box::new(session_underruns.clone()))
.expect("register");
registry
.register(Box::new(session_overruns.clone()))
.expect("register");
registry
.register(Box::new(session_dred_reconstructions.clone()))
.expect("register");
registry
.register(Box::new(session_classical_plc.clone()))
.expect("register");
registry
.register(Box::new(conformance_violations.clone()))
.expect("register");
registry
.register(Box::new(conformance_bytes.clone()))
.expect("register");
registry
.register(Box::new(conformance_iat_ms.clone()))
.expect("register");
Self { Self {
active_sessions, active_sessions,
@@ -190,6 +293,9 @@ impl RelayMetrics {
session_overruns, session_overruns,
session_dred_reconstructions, session_dred_reconstructions,
session_classical_plc, session_classical_plc,
conformance_violations,
conformance_bytes,
conformance_iat_ms,
registry, registry,
} }
} }
@@ -230,10 +336,7 @@ impl RelayMetrics {
.with_label_values(&[session_id]) .with_label_values(&[session_id])
.inc_by(underruns - cur_underruns as u64); .inc_by(underruns - cur_underruns as u64);
} }
let cur_overruns = self let cur_overruns = self.session_overruns.with_label_values(&[session_id]).get();
.session_overruns
.with_label_values(&[session_id])
.get();
if overruns > cur_overruns as u64 { if overruns > cur_overruns as u64 {
self.session_overruns self.session_overruns
.with_label_values(&[session_id]) .with_label_values(&[session_id])
@@ -274,6 +377,45 @@ impl RelayMetrics {
} }
} }
/// Record conformance-related metrics for a single received packet.
///
/// * `header` — the media header (provides codec_id and media_type).
/// * `payload_len` — payload length in bytes.
/// * `iat_ms` — inter-arrival time since the previous packet.
/// * `violation` — `Some(Violation)` if the packet triggered a conformance
/// limit; `None` for clean packets.
pub fn record_conformance(
&self,
header: &MediaHeader,
payload_len: usize,
iat_ms: u64,
violation: Option<Violation>,
) {
let media_type = format!("{:?}", header.media_type);
let bytes = (MediaHeader::WIRE_SIZE + payload_len) as f64;
self.conformance_bytes
.with_label_values(&[&media_type])
.observe(bytes);
self.conformance_iat_ms
.with_label_values(&[&media_type])
.observe(iat_ms as f64);
if let Some(v) = violation {
let tier = match v {
Violation::BitrateExceeded => "A",
Violation::PacketRateExceeded => "B",
Violation::TimestampDrift => "C",
Violation::PayloadSizeExceeded => "D",
Violation::RateCapExceeded => "E",
};
let codec_id = format!("{:?}", header.codec_id);
let verdict = format!("{:?}", v);
self.conformance_violations
.with_label_values(&[tier, &codec_id, &media_type, &verdict])
.inc();
}
}
/// Remove all per-session label values for a disconnected session. /// Remove all per-session label values for a disconnected session.
pub fn remove_session_metrics(&self, session_id: &str) { pub fn remove_session_metrics(&self, session_id: &str) {
let _ = self.session_buffer_depth.remove_label_values(&[session_id]); let _ = self.session_buffer_depth.remove_label_values(&[session_id]);
@@ -284,7 +426,9 @@ impl RelayMetrics {
let _ = self let _ = self
.session_dred_reconstructions .session_dred_reconstructions
.remove_label_values(&[session_id]); .remove_label_values(&[session_id]);
let _ = self.session_classical_plc.remove_label_values(&[session_id]); let _ = self
.session_classical_plc
.remove_label_values(&[session_id]);
} }
/// Get a reference to the underlying Prometheus registry. /// Get a reference to the underlying Prometheus registry.
@@ -298,7 +442,9 @@ impl RelayMetrics {
let encoder = TextEncoder::new(); let encoder = TextEncoder::new();
let metric_families = self.registry.gather(); let metric_families = self.registry.gather();
let mut buffer = Vec::new(); let mut buffer = Vec::new();
encoder.encode(&metric_families, &mut buffer).expect("encode"); encoder
.encode(&metric_families, &mut buffer)
.expect("encode");
String::from_utf8(buffer).expect("utf8") String::from_utf8(buffer).expect("utf8")
} }
} }
@@ -310,7 +456,7 @@ pub async fn serve_metrics(
presence: Option<Arc<tokio::sync::Mutex<crate::presence::PresenceRegistry>>>, presence: Option<Arc<tokio::sync::Mutex<crate::presence::PresenceRegistry>>>,
route_resolver: Option<Arc<crate::route::RouteResolver>>, route_resolver: Option<Arc<crate::route::RouteResolver>>,
) { ) {
use axum::{extract::Path, routing::get, Router}; use axum::{Router, extract::Path, routing::get};
let metrics_clone = metrics.clone(); let metrics_clone = metrics.clone();
let presence_all = presence.clone(); let presence_all = presence.clone();

View File

@@ -11,11 +11,11 @@
use tracing::{debug, info}; use tracing::{debug, info};
use wzp_fec::{RaptorQFecDecoder, RaptorQFecEncoder}; use wzp_fec::{RaptorQFecDecoder, RaptorQFecEncoder};
use wzp_proto::QualityProfile;
use wzp_proto::jitter::{JitterBuffer, PlayoutResult}; use wzp_proto::jitter::{JitterBuffer, PlayoutResult};
use wzp_proto::packet::{MediaHeader, MediaPacket}; use wzp_proto::packet::{MediaHeader, MediaPacket};
use wzp_proto::quality::AdaptiveQualityController; use wzp_proto::quality::AdaptiveQualityController;
use wzp_proto::traits::{FecDecoder, FecEncoder, QualityController}; use wzp_proto::traits::{FecDecoder, FecEncoder, QualityController};
use wzp_proto::QualityProfile;
/// Configuration for a relay pipeline instance. /// Configuration for a relay pipeline instance.
pub struct PipelineConfig { pub struct PipelineConfig {
@@ -51,7 +51,7 @@ pub struct RelayPipeline {
/// Current quality profile. /// Current quality profile.
profile: QualityProfile, profile: QualityProfile,
/// Outbound sequence counter. /// Outbound sequence counter.
out_seq: u16, out_seq: u32,
/// Packets processed count. /// Packets processed count.
stats: PipelineStats, stats: PipelineStats,
} }
@@ -111,8 +111,8 @@ impl RelayPipeline {
let header = &packet.header; let header = &packet.header;
let _ = self.fec_decoder.add_symbol( let _ = self.fec_decoder.add_symbol(
header.fec_block, header.fec_block,
header.fec_symbol, header.fec_block >> 8,
header.is_repair, header.is_repair(),
&packet.payload, &packet.payload,
); );
@@ -128,22 +128,21 @@ impl RelayPipeline {
for (i, frame) in frames.into_iter().enumerate() { for (i, frame) in frames.into_iter().enumerate() {
let reconstructed = MediaPacket { let reconstructed = MediaPacket {
header: MediaHeader { header: MediaHeader {
version: 0, version: 2,
is_repair: false, flags: 0,
media_type: wzp_proto::MediaType::Audio,
codec_id: header.codec_id, codec_id: header.codec_id,
has_quality_report: false, stream_id: 0,
fec_ratio_encoded: header.fec_ratio_encoded, fec_ratio: header.fec_ratio,
// Reconstruct seq from block + symbol index // Reconstruct seq from block + symbol index
seq: (header.fec_block as u16) seq: (header.fec_block as u32)
.wrapping_mul(self.profile.frames_per_block as u16) .wrapping_mul(self.profile.frames_per_block as u32)
.wrapping_add(i as u16), .wrapping_add(i as u32),
timestamp: header timestamp: header.timestamp.wrapping_add(
.timestamp (i as u32) * (header.codec_id.frame_duration_ms() as u32),
.wrapping_add((i as u32) * (header.codec_id.frame_duration_ms() as u32)), ),
fec_block: header.fec_block, fec_block: u16::from((header.fec_block & 0xFF) as u8)
fec_symbol: i as u8, | (u16::from(i as u8) << 8),
reserved: 0,
csrc_count: 0,
}, },
payload: bytes::Bytes::from(frame), payload: bytes::Bytes::from(frame),
quality_report: None, quality_report: None,
@@ -191,19 +190,16 @@ impl RelayPipeline {
for (sym_idx, repair_data) in repairs { for (sym_idx, repair_data) in repairs {
let repair_packet = MediaPacket { let repair_packet = MediaPacket {
header: MediaHeader { header: MediaHeader {
version: 0, version: 2,
is_repair: true, flags: MediaHeader::FLAG_REPAIR,
media_type: wzp_proto::MediaType::Audio,
codec_id: packet.header.codec_id, codec_id: packet.header.codec_id,
has_quality_report: false, stream_id: 0,
fec_ratio_encoded: MediaHeader::encode_fec_ratio( fec_ratio: MediaHeader::encode_fec_ratio(self.profile.fec_ratio),
self.profile.fec_ratio,
),
seq: self.out_seq, seq: self.out_seq,
timestamp: packet.header.timestamp, timestamp: packet.header.timestamp,
fec_block: self.fec_encoder.current_block_id(), fec_block: u16::from(self.fec_encoder.current_block_id())
fec_symbol: sym_idx, | (u16::from(sym_idx) << 8),
reserved: 0,
csrc_count: 0,
}, },
payload: bytes::Bytes::from(repair_data), payload: bytes::Bytes::from(repair_data),
quality_report: None, quality_report: None,
@@ -232,23 +228,21 @@ impl RelayPipeline {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use wzp_proto::CodecId;
use bytes::Bytes; use bytes::Bytes;
use wzp_proto::CodecId;
fn make_media_packet(seq: u16, block: u8, symbol: u8) -> MediaPacket { fn make_media_packet(seq: u32, block: u8, symbol: u8) -> MediaPacket {
MediaPacket { MediaPacket {
header: MediaHeader { header: MediaHeader {
version: 0, version: 2,
is_repair: false, flags: 0,
media_type: wzp_proto::MediaType::Audio,
codec_id: CodecId::Opus24k, codec_id: CodecId::Opus24k,
has_quality_report: false, stream_id: 0,
fec_ratio_encoded: 0, fec_ratio: 0,
seq, seq,
timestamp: seq as u32 * 20, timestamp: seq * 20,
fec_block: block, fec_block: u16::from(block) | (u16::from(symbol) << 8),
fec_symbol: symbol,
reserved: 0,
csrc_count: 0,
}, },
payload: Bytes::from(vec![seq as u8; 60]), payload: Bytes::from(vec![seq as u8; 60]),
quality_report: None, quality_report: None,
@@ -283,7 +277,7 @@ mod tests {
// Feed 5 packets (one full block) // Feed 5 packets (one full block)
let mut total_out = 0; let mut total_out = 0;
for i in 0..5u16 { for i in 0..5u32 {
let pkt = make_media_packet(i, 0, i as u8); let pkt = make_media_packet(i, 0, i as u8);
let out = pipeline.prepare_outbound(pkt); let out = pipeline.prepare_outbound(pkt);
total_out += out.len(); total_out += out.len();

View File

@@ -63,6 +63,12 @@ pub struct PresenceRegistry {
peers: HashMap<SocketAddr, PeerRelay>, peers: HashMap<SocketAddr, PeerRelay>,
} }
impl Default for PresenceRegistry {
fn default() -> Self {
Self::new()
}
}
impl PresenceRegistry { impl PresenceRegistry {
/// Create an empty registry. /// Create an empty registry.
pub fn new() -> Self { pub fn new() -> Self {
@@ -74,13 +80,21 @@ impl PresenceRegistry {
} }
/// Register a fingerprint as locally connected (called after auth + handshake). /// Register a fingerprint as locally connected (called after auth + handshake).
pub fn register_local(&mut self, fingerprint: &str, alias: Option<String>, room: Option<String>) { pub fn register_local(
self.local.insert(fingerprint.to_string(), LocalPresence { &mut self,
fingerprint: &str,
alias: Option<String>,
room: Option<String>,
) {
self.local.insert(
fingerprint.to_string(),
LocalPresence {
fingerprint: fingerprint.to_string(), fingerprint: fingerprint.to_string(),
alias, alias,
connected_at: Instant::now(), connected_at: Instant::now(),
room, room,
}); },
);
} }
/// Unregister a locally connected fingerprint (called on disconnect). /// Unregister a locally connected fingerprint (called on disconnect).
@@ -98,11 +112,14 @@ impl PresenceRegistry {
// Insert new remote entries // Insert new remote entries
for fp in &fingerprints { for fp in &fingerprints {
self.remote.insert(fp.clone(), RemotePresence { self.remote.insert(
fp.clone(),
RemotePresence {
fingerprint: fp.clone(), fingerprint: fp.clone(),
relay_addr: addr, relay_addr: addr,
last_seen: now, last_seen: now,
}); },
);
} }
// Update the peer record // Update the peer record
@@ -156,7 +173,8 @@ impl PresenceRegistry {
self.remote.retain(|_, rp| rp.last_seen > cutoff); self.remote.retain(|_, rp| rp.last_seen > cutoff);
// Expire peer relay records and their fingerprint sets // Expire peer relay records and their fingerprint sets
let stale_peers: Vec<SocketAddr> = self.peers let stale_peers: Vec<SocketAddr> = self
.peers
.iter() .iter()
.filter(|(_, p)| p.last_update <= cutoff) .filter(|(_, p)| p.last_update <= cutoff)
.map(|(addr, _)| *addr) .map(|(addr, _)| *addr)
@@ -280,13 +298,15 @@ mod tests {
let all = reg.all_known(); let all = reg.all_known();
assert_eq!(all.len(), 2); assert_eq!(all.len(), 2);
let local_entries: Vec<_> = all.iter() let local_entries: Vec<_> = all
.iter()
.filter(|(_, loc)| *loc == PresenceLocation::Local) .filter(|(_, loc)| *loc == PresenceLocation::Local)
.collect(); .collect();
assert_eq!(local_entries.len(), 1); assert_eq!(local_entries.len(), 1);
assert_eq!(local_entries[0].0, "local1"); assert_eq!(local_entries[0].0, "local1");
let remote_entries: Vec<_> = all.iter() let remote_entries: Vec<_> = all
.iter()
.filter(|(_, loc)| matches!(loc, PresenceLocation::Remote(_))) .filter(|(_, loc)| matches!(loc, PresenceLocation::Remote(_)))
.collect(); .collect();
assert_eq!(remote_entries.len(), 1); assert_eq!(remote_entries.len(), 1);

View File

@@ -13,7 +13,7 @@ use prometheus::{Gauge, IntGauge, Opts, Registry};
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tracing::{error, info, warn}; use tracing::{error, info, warn};
use wzp_proto::{MediaTransport, SignalMessage}; use wzp_proto::{MediaTransport, SignalMessage, default_signal_version};
/// Configuration for a single probe target. /// Configuration for a single probe target.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@@ -43,8 +43,7 @@ impl ProbeMetrics {
/// Register probe metrics with the given `target` label value. /// Register probe metrics with the given `target` label value.
pub fn register(target: &str, registry: &Registry) -> Self { pub fn register(target: &str, registry: &Registry) -> Self {
let rtt_ms = Gauge::with_opts( let rtt_ms = Gauge::with_opts(
Opts::new("wzp_probe_rtt_ms", "RTT to peer relay in ms") Opts::new("wzp_probe_rtt_ms", "RTT to peer relay in ms").const_label("target", target),
.const_label("target", target),
) )
.expect("probe metric"); .expect("probe metric");
@@ -66,9 +65,15 @@ impl ProbeMetrics {
) )
.expect("probe metric"); .expect("probe metric");
registry.register(Box::new(rtt_ms.clone())).expect("register"); registry
registry.register(Box::new(loss_pct.clone())).expect("register"); .register(Box::new(rtt_ms.clone()))
registry.register(Box::new(jitter_ms.clone())).expect("register"); .expect("register");
registry
.register(Box::new(loss_pct.clone()))
.expect("register");
registry
.register(Box::new(jitter_ms.clone()))
.expect("register");
registry.register(Box::new(up.clone())).expect("register"); registry.register(Box::new(up.clone())).expect("register");
Self { Self {
@@ -168,7 +173,11 @@ impl ProbeRunner {
) -> Self { ) -> Self {
let target_str = config.target.to_string(); let target_str = config.target.to_string();
let metrics = ProbeMetrics::register(&target_str, registry); let metrics = ProbeMetrics::register(&target_str, registry);
Self { config, metrics, presence } Self {
config,
metrics,
presence,
}
} }
/// Run the probe forever. This function never returns under normal operation. /// Run the probe forever. This function never returns under normal operation.
@@ -198,13 +207,8 @@ impl ProbeRunner {
let bind_addr: SocketAddr = "0.0.0.0:0".parse().unwrap(); let bind_addr: SocketAddr = "0.0.0.0:0".parse().unwrap();
let endpoint = wzp_transport::create_endpoint(bind_addr, None)?; let endpoint = wzp_transport::create_endpoint(bind_addr, None)?;
let client_cfg = wzp_transport::client_config(); let client_cfg = wzp_transport::client_config();
let conn = wzp_transport::connect( let conn =
&endpoint, wzp_transport::connect(&endpoint, self.config.target, "_probe", client_cfg).await?;
self.config.target,
"_probe",
client_cfg,
)
.await?;
let transport = Arc::new(wzp_transport::QuinnTransport::new(conn)); let transport = Arc::new(wzp_transport::QuinnTransport::new(conn));
self.metrics.up.set(1); self.metrics.up.set(1);
@@ -225,7 +229,7 @@ impl ProbeRunner {
let recv_handle = tokio::spawn(async move { let recv_handle = tokio::spawn(async move {
loop { loop {
match recv_transport.recv_signal().await { match recv_transport.recv_signal().await {
Ok(Some(SignalMessage::Pong { timestamp_ms })) => { Ok(Some(SignalMessage::Pong { timestamp_ms, .. })) => {
let now_ms = SystemTime::now() let now_ms = SystemTime::now()
.duration_since(UNIX_EPOCH) .duration_since(UNIX_EPOCH)
.unwrap() .unwrap()
@@ -237,11 +241,16 @@ impl ProbeRunner {
loss_gauge.set(w.loss_pct()); loss_gauge.set(w.loss_pct());
jitter_gauge.set(w.jitter_ms()); jitter_gauge.set(w.jitter_ms());
} }
Ok(Some(SignalMessage::PresenceUpdate { fingerprints, relay_addr })) => { Ok(Some(SignalMessage::PresenceUpdate {
fingerprints,
relay_addr,
..
})) => {
if let Some(ref reg) = recv_presence { if let Some(ref reg) = recv_presence {
// Parse the relay_addr; fall back to the connection target // Parse the relay_addr; fall back to the connection target
let addr = relay_addr.parse().unwrap_or(recv_target); let addr = relay_addr.parse().unwrap_or(recv_target);
let fps: std::collections::HashSet<String> = fingerprints.into_iter().collect(); let fps: std::collections::HashSet<String> =
fingerprints.into_iter().collect();
let mut r = reg.lock().await; let mut r = reg.lock().await;
r.update_peer(addr, fps); r.update_peer(addr, fps);
} }
@@ -285,7 +294,10 @@ impl ProbeRunner {
} }
if let Err(e) = transport if let Err(e) = transport
.send_signal(&SignalMessage::Ping { timestamp_ms }) .send_signal(&SignalMessage::Ping {
version: default_signal_version(),
timestamp_ms,
})
.await .await
{ {
error!(target = %self.config.target, "probe ping send error: {e}"); error!(target = %self.config.target, "probe ping send error: {e}");
@@ -302,6 +314,7 @@ impl ProbeRunner {
r.local_fingerprints().into_iter().collect() r.local_fingerprints().into_iter().collect()
}; };
let msg = SignalMessage::PresenceUpdate { let msg = SignalMessage::PresenceUpdate {
version: default_signal_version(),
fingerprints: fps, fingerprints: fps,
relay_addr: self.config.target.to_string(), relay_addr: self.config.target.to_string(),
}; };
@@ -374,10 +387,7 @@ pub fn mesh_summary(registry: &Registry) -> String {
let name = family.get_name(); let name = family.get_name();
for metric in family.get_metric() { for metric in family.get_metric() {
// Find the "target" label // Find the "target" label
let target_label = metric let target_label = metric.get_label().iter().find(|l| l.get_name() == "target");
.get_label()
.iter()
.find(|l| l.get_name() == "target");
let target = match target_label { let target = match target_label {
Some(l) => l.get_value().to_string(), Some(l) => l.get_value().to_string(),
None => continue, None => continue,
@@ -420,13 +430,11 @@ pub fn mesh_summary(registry: &Registry) -> String {
/// Handle an incoming Ping signal by replying with a Pong carrying the same timestamp. /// Handle an incoming Ping signal by replying with a Pong carrying the same timestamp.
/// Returns true if the message was a Ping and was handled, false otherwise. /// Returns true if the message was a Ping and was handled, false otherwise.
pub async fn handle_ping( pub async fn handle_ping(transport: &wzp_transport::QuinnTransport, msg: &SignalMessage) -> bool {
transport: &wzp_transport::QuinnTransport, if let SignalMessage::Ping { timestamp_ms, .. } = msg {
msg: &SignalMessage,
) -> bool {
if let SignalMessage::Ping { timestamp_ms } = msg {
if let Err(e) = transport if let Err(e) = transport
.send_signal(&SignalMessage::Pong { .send_signal(&SignalMessage::Pong {
version: default_signal_version(),
timestamp_ms: *timestamp_ms, timestamp_ms: *timestamp_ms,
}) })
.await .await
@@ -456,9 +464,18 @@ mod tests {
encoder.encode(&families, &mut buf).unwrap(); encoder.encode(&families, &mut buf).unwrap();
let output = String::from_utf8(buf).unwrap(); let output = String::from_utf8(buf).unwrap();
assert!(output.contains("wzp_probe_rtt_ms"), "missing wzp_probe_rtt_ms"); assert!(
assert!(output.contains("wzp_probe_loss_pct"), "missing wzp_probe_loss_pct"); output.contains("wzp_probe_rtt_ms"),
assert!(output.contains("wzp_probe_jitter_ms"), "missing wzp_probe_jitter_ms"); "missing wzp_probe_rtt_ms"
);
assert!(
output.contains("wzp_probe_loss_pct"),
"missing wzp_probe_loss_pct"
);
assert!(
output.contains("wzp_probe_jitter_ms"),
"missing wzp_probe_jitter_ms"
);
assert!(output.contains("wzp_probe_up"), "missing wzp_probe_up"); assert!(output.contains("wzp_probe_up"), "missing wzp_probe_up");
assert!( assert!(
output.contains("target=\"127.0.0.1:4433\""), output.contains("target=\"127.0.0.1:4433\""),

View File

@@ -40,10 +40,7 @@ impl RelayLink {
/// should skip normal client auth/handshake for relay-SNI connections. /// should skip normal client auth/handshake for relay-SNI connections.
pub async fn connect(target: SocketAddr) -> Result<Self, anyhow::Error> { pub async fn connect(target: SocketAddr) -> Result<Self, anyhow::Error> {
// Create a client-only endpoint on an OS-assigned port. // Create a client-only endpoint on an OS-assigned port.
let endpoint = wzp_transport::create_endpoint( let endpoint = wzp_transport::create_endpoint("0.0.0.0:0".parse().unwrap(), None)?;
"0.0.0.0:0".parse().unwrap(),
None,
)?;
let client_cfg = wzp_transport::client_config(); let client_cfg = wzp_transport::client_config();
let conn = wzp_transport::connect(&endpoint, target, "_relay", client_cfg).await?; let conn = wzp_transport::connect(&endpoint, target, "_relay", client_cfg).await?;
@@ -336,10 +333,11 @@ mod tests {
#[test] #[test]
fn session_forward_signal_roundtrip() { fn session_forward_signal_roundtrip() {
use wzp_proto::SignalMessage; use wzp_proto::{SignalMessage, default_signal_version};
// SessionForward roundtrip // SessionForward roundtrip
let msg = SignalMessage::SessionForward { let msg = SignalMessage::SessionForward {
version: default_signal_version(),
session_id: "abcd1234".to_string(), session_id: "abcd1234".to_string(),
target_fingerprint: "deadbeef".to_string(), target_fingerprint: "deadbeef".to_string(),
source_relay: "10.0.0.1:4433".to_string(), source_relay: "10.0.0.1:4433".to_string(),
@@ -351,6 +349,7 @@ mod tests {
session_id, session_id,
target_fingerprint, target_fingerprint,
source_relay, source_relay,
..
} => { } => {
assert_eq!(session_id, "abcd1234"); assert_eq!(session_id, "abcd1234");
assert_eq!(target_fingerprint, "deadbeef"); assert_eq!(target_fingerprint, "deadbeef");
@@ -361,6 +360,7 @@ mod tests {
// SessionForwardAck roundtrip // SessionForwardAck roundtrip
let ack = SignalMessage::SessionForwardAck { let ack = SignalMessage::SessionForwardAck {
version: default_signal_version(),
session_id: "abcd1234".to_string(), session_id: "abcd1234".to_string(),
room_name: "relay-room-42".to_string(), room_name: "relay-room-42".to_string(),
}; };
@@ -370,6 +370,7 @@ mod tests {
SignalMessage::SessionForwardAck { SignalMessage::SessionForwardAck {
session_id, session_id,
room_name, room_name,
..
} => { } => {
assert_eq!(session_id, "abcd1234"); assert_eq!(session_id, "abcd1234");
assert_eq!(room_name, "relay-room-42"); assert_eq!(room_name, "relay-room-42");
@@ -457,17 +458,15 @@ mod tests {
let pkt = MediaPacket { let pkt = MediaPacket {
header: wzp_proto::packet::MediaHeader { header: wzp_proto::packet::MediaHeader {
version: 0, version: 2,
is_repair: false, flags: 0,
media_type: wzp_proto::MediaType::Audio,
codec_id: wzp_proto::CodecId::Opus16k, codec_id: wzp_proto::CodecId::Opus16k,
has_quality_report: false, stream_id: 0,
fec_ratio_encoded: 0, fec_ratio: 0,
seq: 1, seq: 1,
timestamp: 100, timestamp: 100,
fec_block: 0, fec_block: 0,
fec_symbol: 0,
reserved: 0,
csrc_count: 0,
}, },
payload: bytes::Bytes::from_static(b"test"), payload: bytes::Bytes::from_static(b"test"),
quality_report: None, quality_report: None,

View File

@@ -0,0 +1,207 @@
//! Tier G response policy — maps conformance verdicts to enforcement actions.
//!
//! Actions:
//! - `Legitimate` → no action
//! - `Suspect` → tighten Tier E quota, emit metric
//! - `Abusive` → typed Hangup + 1 h fingerprint cool-down
//! - `RepeatAbusive` → relay-local block 24 h
use std::collections::HashMap;
use std::time::{Duration, Instant};
use wzp_proto::packet::{HangupReason, ViolationCode};
use crate::verdict::Verdict;
/// Enforcement action recommended by the response policy.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Action {
/// Pass through unchanged.
Allow,
/// Throttle to tighter quota (Tier E).
Throttle,
/// Close the session with a typed Hangup signal.
Close { reason: HangupReason },
/// Block the fingerprint from joining any room for 24 h.
Block,
}
/// Tracks fingerprint-level abuse history and applies escalation.
pub struct ResponsePolicy {
/// `(fingerprint, violation_code)` → last abusive instant.
cooldowns: HashMap<(String, ViolationCode), Instant>,
/// Block duration for repeat abuse.
block_duration: Duration,
}
impl ResponsePolicy {
pub fn new() -> Self {
Self {
cooldowns: HashMap::new(),
block_duration: Duration::from_secs(86400), // 24 h
}
}
/// Evaluate a verdict and produce the corresponding [`Action`].
///
/// `fingerprint` is the participant's identity string (or IP as fallback).
/// `code` is the specific violation type that triggered the verdict.
pub fn evaluate(&mut self, fingerprint: &str, code: ViolationCode, verdict: Verdict) -> Action {
match verdict {
Verdict::Legitimate => Action::Allow,
Verdict::Suspect => Action::Throttle,
Verdict::Abusive => {
let key = (fingerprint.to_string(), code);
let now = Instant::now();
// Check if this fingerprint was already abusive recently.
let is_repeat = self
.cooldowns
.get(&key)
.map(|last| now.duration_since(*last) < self.block_duration)
.unwrap_or(false);
if is_repeat {
Action::Block
} else {
self.cooldowns.insert(key, now);
Action::Close {
reason: HangupReason::PolicyViolation {
code,
reason: format!("Tier G enforcement: {code:?}"),
},
}
}
}
}
}
/// Returns true if the fingerprint is currently blocked (repeat abuse).
pub fn is_blocked(&self, fingerprint: &str) -> bool {
let now = Instant::now();
self.cooldowns.iter().any(|((fp, _), last)| {
fp == fingerprint && now.duration_since(*last) < self.block_duration
})
}
/// Clean up expired cooldown entries.
pub fn prune(&mut self) {
let now = Instant::now();
self.cooldowns
.retain(|_, last| now.duration_since(*last) < self.block_duration);
}
/// Number of tracked cooldown entries.
pub fn len(&self) -> usize {
self.cooldowns.len()
}
pub fn is_empty(&self) -> bool {
self.cooldowns.is_empty()
}
}
impl Default for ResponsePolicy {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn legitimate_allowed() {
let mut policy = ResponsePolicy::new();
assert_eq!(
policy.evaluate("alice", ViolationCode::Bitrate, Verdict::Legitimate),
Action::Allow
);
}
#[test]
fn suspect_throttled() {
let mut policy = ResponsePolicy::new();
assert_eq!(
policy.evaluate("alice", ViolationCode::Entropy, Verdict::Suspect),
Action::Throttle
);
}
#[test]
fn abusive_gets_close() {
let mut policy = ResponsePolicy::new();
let action = policy.evaluate("alice", ViolationCode::Bitrate, Verdict::Abusive);
assert!(
matches!(action, Action::Close { .. }),
"first-time abuse should close session"
);
}
#[test]
fn repeat_abusive_gets_block() {
let mut policy = ResponsePolicy::new();
// First abuse
let _ = policy.evaluate("alice", ViolationCode::Bitrate, Verdict::Abusive);
// Second abuse within window → block
let action = policy.evaluate("alice", ViolationCode::Bitrate, Verdict::Abusive);
assert_eq!(action, Action::Block, "repeat abuse should block");
}
#[test]
fn different_violation_codes_are_independent() {
let mut policy = ResponsePolicy::new();
// Abuse on bitrate
let _ = policy.evaluate("alice", ViolationCode::Bitrate, Verdict::Abusive);
// Abuse on entropy is treated as first-time for that code
let action = policy.evaluate("alice", ViolationCode::Entropy, Verdict::Abusive);
assert!(
matches!(action, Action::Close { .. }),
"different violation code should not trigger repeat"
);
}
#[test]
fn is_blocked_true_after_repeat() {
let mut policy = ResponsePolicy::new();
let _ = policy.evaluate("alice", ViolationCode::Bitrate, Verdict::Abusive);
let _ = policy.evaluate("alice", ViolationCode::Bitrate, Verdict::Abusive);
assert!(policy.is_blocked("alice"));
}
#[test]
fn is_blocked_false_for_legitimate() {
let policy = ResponsePolicy::new();
assert!(!policy.is_blocked("alice"));
}
#[test]
fn prune_removes_expired() {
let mut policy = ResponsePolicy::new();
let _ = policy.evaluate("alice", ViolationCode::Bitrate, Verdict::Abusive);
assert_eq!(policy.len(), 1);
// Manually expire by moving cooldown back
policy.cooldowns.insert(
("alice".to_string(), ViolationCode::Bitrate),
Instant::now() - Duration::from_secs(90000),
);
policy.prune();
assert!(policy.is_empty());
}
#[test]
fn close_reason_contains_code() {
let mut policy = ResponsePolicy::new();
let action = policy.evaluate("alice", ViolationCode::Entropy, Verdict::Abusive);
match action {
Action::Close { reason } => match reason {
HangupReason::PolicyViolation { code, .. } => {
assert_eq!(code, ViolationCode::Entropy);
}
other => panic!("expected PolicyViolation, got {other:?}"),
},
other => panic!("expected Close, got {other:?}"),
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -97,14 +97,13 @@ impl RouteResolver {
} }
/// Build a JSON-serializable route response for the HTTP API. /// Build a JSON-serializable route response for the HTTP API.
pub fn route_json( pub fn route_json(&self, fingerprint: &str, route: &Route) -> serde_json::Value {
&self,
fingerprint: &str,
route: &Route,
) -> serde_json::Value {
let (route_type, relay_chain) = match route { let (route_type, relay_chain) = match route {
Route::Local => ("local", vec![self.local_addr.to_string()]), Route::Local => ("local", vec![self.local_addr.to_string()]),
Route::DirectPeer(addr) => ("direct_peer", vec![self.local_addr.to_string(), addr.to_string()]), Route::DirectPeer(addr) => (
"direct_peer",
vec![self.local_addr.to_string(), addr.to_string()],
),
Route::Chain(chain) => { Route::Chain(chain) => {
let mut addrs = vec![self.local_addr.to_string()]; let mut addrs = vec![self.local_addr.to_string()];
addrs.extend(chain.iter().map(|a| a.to_string())); addrs.extend(chain.iter().map(|a| a.to_string()));
@@ -184,7 +183,10 @@ mod tests {
reg.update_peer(peer, fps); reg.update_peer(peer, fps);
// Local lookup works via multi-hop // Local lookup works via multi-hop
assert_eq!(resolver.resolve_multi_hop(&reg, "local_fp", 3), Route::Local); assert_eq!(
resolver.resolve_multi_hop(&reg, "local_fp", 3),
Route::Local
);
// Remote lookup works via multi-hop // Remote lookup works via multi-hop
assert_eq!( assert_eq!(
resolver.resolve_multi_hop(&reg, "remote_fp", 3), resolver.resolve_multi_hop(&reg, "remote_fp", 3),
@@ -199,9 +201,10 @@ mod tests {
#[test] #[test]
fn route_query_signal_roundtrip() { fn route_query_signal_roundtrip() {
use wzp_proto::SignalMessage; use wzp_proto::{SignalMessage, default_signal_version};
let query = SignalMessage::RouteQuery { let query = SignalMessage::RouteQuery {
version: default_signal_version(),
fingerprint: "aabbccdd".to_string(), fingerprint: "aabbccdd".to_string(),
ttl: 3, ttl: 3,
}; };
@@ -209,11 +212,12 @@ mod tests {
let decoded: SignalMessage = serde_json::from_str(&json).unwrap(); let decoded: SignalMessage = serde_json::from_str(&json).unwrap();
assert!(matches!( assert!(matches!(
decoded, decoded,
SignalMessage::RouteQuery { ref fingerprint, ttl } SignalMessage::RouteQuery { ref fingerprint, ttl, ..}
if fingerprint == "aabbccdd" && ttl == 3 if fingerprint == "aabbccdd" && ttl == 3
)); ));
let response = SignalMessage::RouteResponse { let response = SignalMessage::RouteResponse {
version: default_signal_version(),
fingerprint: "aabbccdd".to_string(), fingerprint: "aabbccdd".to_string(),
found: true, found: true,
relay_chain: vec!["10.0.0.1:4433".to_string(), "10.0.0.2:4433".to_string()], relay_chain: vec!["10.0.0.1:4433".to_string(), "10.0.0.2:4433".to_string()],
@@ -222,7 +226,7 @@ mod tests {
let decoded: SignalMessage = serde_json::from_str(&json).unwrap(); let decoded: SignalMessage = serde_json::from_str(&json).unwrap();
assert!(matches!( assert!(matches!(
decoded, decoded,
SignalMessage::RouteResponse { ref fingerprint, found, ref relay_chain } SignalMessage::RouteResponse { ref fingerprint, found, ref relay_chain, ..}
if fingerprint == "aabbccdd" && found && relay_chain.len() == 2 if fingerprint == "aabbccdd" && found && relay_chain.len() == 2
)); ));
} }

View File

@@ -143,18 +143,18 @@ impl SessionManager {
fingerprint: Option<String>, fingerprint: Option<String>,
) -> Result<SessionId, String> { ) -> Result<SessionId, String> {
if self.total_count() >= self.max_sessions { if self.total_count() >= self.max_sessions {
return Err(format!( return Err(format!("max sessions ({}) exceeded", self.max_sessions));
"max sessions ({}) exceeded",
self.max_sessions
));
} }
let id = rand_session_id(); let id = rand_session_id();
self.tracked.insert(id, SessionInfo { self.tracked.insert(
id,
SessionInfo {
room_name: room.to_string(), room_name: room.to_string(),
fingerprint, fingerprint,
connected_at: Instant::now(), connected_at: Instant::now(),
state: SessionState::Active, state: SessionState::Active,
}); },
);
Ok(id) Ok(id)
} }
@@ -165,7 +165,10 @@ impl SessionManager {
/// Number of currently tracked (room-mode) sessions. /// Number of currently tracked (room-mode) sessions.
pub fn active_count(&self) -> usize { pub fn active_count(&self) -> usize {
self.tracked.values().filter(|s| s.state == SessionState::Active).count() self.tracked
.values()
.filter(|s| s.state == SessionState::Active)
.count()
} }
/// Return all session IDs that belong to a given room. /// Return all session IDs that belong to a given room.
@@ -278,7 +281,9 @@ mod tests {
#[test] #[test]
fn session_info_returns_correct_data() { fn session_info_returns_correct_data() {
let mut mgr = SessionManager::new(10); let mut mgr = SessionManager::new(10);
let id = mgr.create_session("room-x", Some("alice-fp".into())).unwrap(); let id = mgr
.create_session("room-x", Some("alice-fp".into()))
.unwrap();
let info = mgr.session_info(id).expect("session should exist"); let info = mgr.session_info(id).expect("session should exist");
assert_eq!(info.room_name, "room-x"); assert_eq!(info.room_name, "room-x");
@@ -297,6 +302,9 @@ mod tests {
mgr.create_session("room", None).unwrap(); mgr.create_session("room", None).unwrap();
// Both layers should now reject // Both layers should now reject
assert!(mgr.create_session("room", None).is_err()); assert!(mgr.create_session("room", None).is_err());
assert!(mgr.create_pipeline_session([2u8; 16], PipelineConfig::default()).is_none()); assert!(
mgr.create_pipeline_session([2u8; 16], PipelineConfig::default())
.is_none()
);
} }
} }

View File

@@ -8,7 +8,7 @@ use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
use tracing::info; use tracing::info;
use wzp_proto::{MediaTransport, SignalMessage}; use wzp_proto::{MediaTransport, SignalMessage, default_signal_version};
use wzp_transport::QuinnTransport; use wzp_transport::QuinnTransport;
/// A client connected via `_signal` for direct calling. /// A client connected via `_signal` for direct calling.
@@ -34,12 +34,15 @@ impl SignalHub {
/// Register a new signaling client. /// Register a new signaling client.
pub fn register(&mut self, fp: String, transport: Arc<QuinnTransport>, alias: Option<String>) { pub fn register(&mut self, fp: String, transport: Arc<QuinnTransport>, alias: Option<String>) {
info!(fingerprint = %fp, alias = ?alias, "signal client registered"); info!(fingerprint = %fp, alias = ?alias, "signal client registered");
self.clients.insert(fp.clone(), SignalClient { self.clients.insert(
fp.clone(),
SignalClient {
fingerprint: fp, fingerprint: fp,
alias, alias,
transport, transport,
connected_at: Instant::now(), connected_at: Instant::now(),
}); },
);
} }
/// Unregister a signaling client. Returns the client if found. /// Unregister a signaling client. Returns the client if found.
@@ -64,10 +67,11 @@ impl SignalHub {
/// Send a signal message to a client by fingerprint. /// Send a signal message to a client by fingerprint.
pub async fn send_to(&self, fp: &str, msg: &SignalMessage) -> Result<(), String> { pub async fn send_to(&self, fp: &str, msg: &SignalMessage) -> Result<(), String> {
match self.clients.get(fp) { match self.clients.get(fp) {
Some(client) => { Some(client) => client
client.transport.send_signal(msg).await .transport
.map_err(|e| format!("send to {fp}: {e}")) .send_signal(msg)
} .await
.map_err(|e| format!("send to {fp}: {e}")),
None => Err(format!("{fp} not online")), None => Err(format!("{fp} not online")),
} }
} }
@@ -97,7 +101,10 @@ impl SignalHub {
alias: c.alias.clone(), alias: c.alias.clone(),
}) })
.collect(); .collect();
SignalMessage::PresenceList { users } SignalMessage::PresenceList {
version: default_signal_version(),
users,
}
} }
/// Broadcast a message to ALL connected signal clients. /// Broadcast a message to ALL connected signal clients.

Some files were not shown because too many files have changed in this diff Show More