Compare commits
195 Commits
opus-DRED-
...
video-usab
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
12020b019c | ||
|
|
3ea25a0656 | ||
|
|
112472609e | ||
|
|
9a7745978b | ||
|
|
f85efb9576 | ||
|
|
31b2caa54d | ||
|
|
079e21e174 | ||
|
|
e676641538 | ||
|
|
9713efc404 | ||
|
|
8415804a1a | ||
|
|
f65b399a21 | ||
|
|
3437a6bd11 | ||
|
|
15eb00ed5e | ||
|
|
0c2297a2b7 | ||
|
|
a08a37b5eb | ||
|
|
f6ace54556 | ||
|
|
47baa1a765 | ||
|
|
ee654cd1ef | ||
|
|
d2046060b5 | ||
|
|
0b7bf1b385 | ||
|
|
e8f139588a | ||
|
|
0115b11de7 | ||
|
|
fa812a17d9 | ||
|
|
8d6b168f1b | ||
|
|
ca164ada5c | ||
|
|
2d58bae9ba | ||
|
|
e1ca6ca6e6 | ||
|
|
06d28a9280 | ||
|
|
d57ebe3d2c | ||
|
|
7eca79846f | ||
|
|
25b3278d31 | ||
|
|
cbc3a8d37e | ||
|
|
1329abbeba | ||
|
|
e8cab25eda | ||
|
|
c41ced53e1 | ||
|
|
7fd66be6c8 | ||
|
|
8002acaf09 | ||
|
|
06253fdeeb | ||
|
|
01f55caa96 | ||
|
|
0f93a2b745 | ||
|
|
2b93bd4b45 | ||
|
|
bc021517c0 | ||
|
|
739bdaf3ab | ||
|
|
bc1668ed96 | ||
|
|
77b036439b | ||
|
|
0ebc73ab13 | ||
|
|
394987a349 | ||
|
|
2aa6582585 | ||
|
|
ca987d547c | ||
|
|
5a13f12334 | ||
|
|
b0a3b1f18e | ||
|
|
32c07d1b61 | ||
|
|
5d05b021aa | ||
|
|
4ac62d99e0 | ||
|
|
4ebb2dac2d | ||
|
|
52a6f5e048 | ||
|
|
15af58a95d | ||
|
|
ed8a7ae5aa | ||
|
|
12b0d9738f | ||
|
|
f78794f4b6 | ||
|
|
f3e3ee5ed0 | ||
|
|
f28f39d814 | ||
|
|
1e729e4b1d | ||
|
|
086d0a4845 | ||
|
|
9334aa5ccd | ||
|
|
553c8a4ce1 | ||
|
|
8d8dddbd35 | ||
|
|
f16d650721 | ||
|
|
31f2fdef1e | ||
|
|
fc9908cd4c | ||
|
|
517d0ebfe0 | ||
|
|
cf4940417e | ||
|
|
ffded2a913 | ||
|
|
283edd38eb | ||
|
|
fdfaed5390 | ||
|
|
dbbab0decf | ||
|
|
5fda5ecc52 | ||
|
|
2bbb664df4 | ||
|
|
2f1a9f74d5 | ||
|
|
b197651557 | ||
|
|
9c41d1acdd | ||
|
|
e34c40dc0f | ||
|
|
c48cb6fbcb | ||
|
|
2e0bdc5904 | ||
|
|
276ecc660e | ||
|
|
001d94f9ae | ||
|
|
36b0421d68 | ||
|
|
828fbea2ea | ||
|
|
cc5aef2534 | ||
|
|
397f9d2141 | ||
|
|
410c2a4335 | ||
|
|
81042ac190 | ||
|
|
e177e63843 | ||
|
|
1f7d130de9 | ||
|
|
3356ba94c6 | ||
|
|
bb153a331d | ||
|
|
490d2d31c6 | ||
|
|
db69f7e9d1 | ||
|
|
f1b86e0fed | ||
|
|
8454835c18 | ||
|
|
017c371611 | ||
|
|
3220bd6151 | ||
|
|
e73f8a7150 | ||
|
|
1b4f7b0772 | ||
|
|
f3398adb95 | ||
|
|
54c1a35186 | ||
|
|
3de56cf1f9 | ||
|
|
fe1f9484bd | ||
|
|
0ef1f574ff | ||
|
|
b1c5837495 | ||
|
|
6f81487778 | ||
|
|
5cdb50160a | ||
|
|
30d26fc7f6 | ||
|
|
c93d302656 | ||
|
|
9680b6ff34 | ||
|
|
6b15b8f97c | ||
|
|
6385b93391 | ||
|
|
6eb94f079d | ||
|
|
5580b794a4 | ||
|
|
7c9ede9227 | ||
|
|
e8866c6632 | ||
|
|
8c6e88ea68 | ||
|
|
ffb92237be | ||
|
|
6af0539a72 | ||
|
|
217567383d | ||
|
|
98ed981805 | ||
|
|
01a3133544 | ||
|
|
25471c694f | ||
|
|
a058a83c91 | ||
|
|
9b8013ba7f | ||
|
|
defd8eab07 | ||
|
|
cc23e829b2 | ||
|
|
18c204c1ff | ||
|
|
1120c7b579 | ||
|
|
7e7391fdbb | ||
|
|
aa0362f318 | ||
|
|
bb23976076 | ||
|
|
18e5e75f33 | ||
|
|
488efcb614 | ||
|
|
8c360186df | ||
|
|
f06f9073ae | ||
|
|
6c49d7436f | ||
|
|
1de280fe04 | ||
|
|
bc6d327ebb | ||
|
|
c478224d67 | ||
|
|
16dcc75514 | ||
|
|
db5751985e | ||
|
|
c0dd6c06ff | ||
|
|
6805caae0e | ||
|
|
5a03da72d3 | ||
|
|
e3e63a40a0 | ||
|
|
7b4bce69d5 | ||
|
|
ec1bdf3cd5 | ||
|
|
ee14862376 | ||
|
|
f83361895e | ||
|
|
0857d190ed | ||
|
|
5d431c0721 | ||
|
|
8fcf1be341 | ||
|
|
9377a9009c | ||
|
|
4471797edf | ||
|
|
425c67a08a | ||
|
|
88ca3e099a | ||
|
|
1e82811cc1 | ||
|
|
81b5522942 | ||
|
|
d539a6dfb9 | ||
|
|
ba12aae439 | ||
|
|
fdb78e08bd | ||
|
|
3a51db998a | ||
|
|
a52b011fb5 | ||
|
|
2514151a89 | ||
|
|
f265fd772d | ||
|
|
9ae9441de4 | ||
|
|
d9e7e72978 | ||
|
|
8ff0c548a7 | ||
|
|
f17420aa98 | ||
|
|
d424515542 | ||
|
|
ea5fc17c34 | ||
|
|
1a7dd935ee | ||
|
|
a7c2261b70 | ||
|
|
eca0bb7531 | ||
|
|
6f43415285 | ||
|
|
d36feb2b59 | ||
|
|
baf82d935b | ||
|
|
6eb10327c1 | ||
|
|
50339542fa | ||
|
|
c67fa18f14 | ||
|
|
6c5c4cb671 | ||
|
|
8816f13df8 | ||
|
|
3804b0bf46 | ||
|
|
234f3c4bfe | ||
|
|
e97f278390 | ||
|
|
f6a77da948 | ||
|
|
82015a78af | ||
|
|
cb13af8abd | ||
|
|
0b8276b9c7 |
5
.gitignore
vendored
5
.gitignore
vendored
@@ -12,6 +12,11 @@ npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
dev-debug.log
|
||||
|
||||
# Debug frame dump artifacts
|
||||
android-frame-dumps/
|
||||
wzp-frame-dumps.tar
|
||||
|
||||
# Dependency directories
|
||||
node_modules/
|
||||
# Environment variables
|
||||
|
||||
14
.gitleaks.toml
Normal file
14
.gitleaks.toml
Normal file
@@ -0,0 +1,14 @@
|
||||
[extend]
|
||||
useDefault = true
|
||||
|
||||
[[allowlists]]
|
||||
description = "Pre-existing historical findings already on fj/main and github/main. The two PASTE_AUTH tokens in scripts/build.sh and scripts/build-linux-notify.sh are real — rotate if those endpoints still authenticate; this allowlist only silences the pre-push hook, it does not remove the exposure."
|
||||
commits = [
|
||||
# wzp-crypto module doc: false positive on "SHA-256(Ed25519 pub)[:16]"
|
||||
"51e893590c1b9fa49e9f6ae5c96c26deb58f353b",
|
||||
# build.sh PASTE_AUTH (paste.tbs.amn.gg)
|
||||
"bd6733b2e5d76b5259020f1c30a5223a9773b6aa",
|
||||
# build-linux-notify Authorization header (paste.dk.manko.yoga)
|
||||
"6d776097c83bc6fbe3f3565e080513d8af93b550",
|
||||
"7751439e2bca9eacf2c30929c8124a4eb6136df2",
|
||||
]
|
||||
1523
Cargo.lock
generated
1523
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -11,6 +11,7 @@ members = [
|
||||
"crates/wzp-web",
|
||||
"crates/wzp-android",
|
||||
"crates/wzp-native",
|
||||
"crates/wzp-video",
|
||||
"desktop/src-tauri",
|
||||
]
|
||||
|
||||
|
||||
1
android.sh
Normal file
1
android.sh
Normal file
@@ -0,0 +1 @@
|
||||
./scripts/android-build-async.sh --init
|
||||
@@ -28,6 +28,7 @@ libc = "0.2"
|
||||
jni = { version = "0.21", default-features = false }
|
||||
rand = { workspace = true }
|
||||
rustls = { version = "0.23", default-features = false, features = ["ring"] }
|
||||
[target.'cfg(target_os = "android")'.dependencies]
|
||||
tracing-android = "0.2"
|
||||
|
||||
[build-dependencies]
|
||||
|
||||
@@ -65,9 +65,8 @@ fn main() {
|
||||
} else {
|
||||
"aarch64-linux-android"
|
||||
};
|
||||
let lib_dir = format!(
|
||||
"{ndk}/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/lib/{arch}"
|
||||
);
|
||||
let lib_dir =
|
||||
format!("{ndk}/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/lib/{arch}");
|
||||
println!("cargo:rustc-link-search=native={lib_dir}");
|
||||
|
||||
// Copy libc++_shared.so to the jniLibs directory
|
||||
@@ -82,9 +81,7 @@ fn main() {
|
||||
};
|
||||
// Try to copy to the Gradle jniLibs directory
|
||||
let manifest = std::env::var("CARGO_MANIFEST_DIR").unwrap_or_default();
|
||||
let jni_dir = format!(
|
||||
"{manifest}/../../android/app/src/main/jniLibs/{jni_abi}"
|
||||
);
|
||||
let jni_dir = format!("{manifest}/../../android/app/src/main/jniLibs/{jni_abi}");
|
||||
if let Ok(_) = std::fs::create_dir_all(&jni_dir) {
|
||||
let _ = std::fs::copy(&shared_so, format!("{jni_dir}/libc++_shared.so"));
|
||||
println!("cargo:warning=Copied libc++_shared.so to {jni_dir}");
|
||||
@@ -127,7 +124,12 @@ fn fetch_oboe() -> Option<PathBuf> {
|
||||
let out_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap());
|
||||
let oboe_dir = out_dir.join("oboe");
|
||||
|
||||
if oboe_dir.join("include").join("oboe").join("Oboe.h").exists() {
|
||||
if oboe_dir
|
||||
.join("include")
|
||||
.join("oboe")
|
||||
.join("Oboe.h")
|
||||
.exists()
|
||||
{
|
||||
return Some(oboe_dir);
|
||||
}
|
||||
|
||||
@@ -143,7 +145,12 @@ fn fetch_oboe() -> Option<PathBuf> {
|
||||
|
||||
match status {
|
||||
Ok(s) if s.success() => {
|
||||
if oboe_dir.join("include").join("oboe").join("Oboe.h").exists() {
|
||||
if oboe_dir
|
||||
.join("include")
|
||||
.join("oboe")
|
||||
.join("Oboe.h")
|
||||
.exists()
|
||||
{
|
||||
Some(oboe_dir)
|
||||
} else {
|
||||
None
|
||||
|
||||
@@ -326,7 +326,10 @@ pub fn pin_to_big_core() {
|
||||
&set,
|
||||
);
|
||||
if ret != 0 {
|
||||
warn!("sched_setaffinity failed: {}", std::io::Error::last_os_error());
|
||||
warn!(
|
||||
"sched_setaffinity failed: {}",
|
||||
std::io::Error::last_os_error()
|
||||
);
|
||||
} else {
|
||||
info!(start, num_cpus, "pinned to big cores");
|
||||
}
|
||||
|
||||
@@ -77,7 +77,8 @@ impl AudioRing {
|
||||
}
|
||||
}
|
||||
|
||||
self.write_pos.store(w.wrapping_add(count), Ordering::Release);
|
||||
self.write_pos
|
||||
.store(w.wrapping_add(count), Ordering::Release);
|
||||
count
|
||||
}
|
||||
|
||||
@@ -112,7 +113,8 @@ impl AudioRing {
|
||||
out[i] = unsafe { *self.buf.as_ptr().add((r + i) & RING_MASK) };
|
||||
}
|
||||
|
||||
self.read_pos.store(r.wrapping_add(count), Ordering::Release);
|
||||
self.read_pos
|
||||
.store(r.wrapping_add(count), Ordering::Release);
|
||||
count
|
||||
}
|
||||
|
||||
|
||||
@@ -22,7 +22,8 @@ use wzp_crypto::{KeyExchange, WarzoneKeyExchange};
|
||||
use wzp_fec::{RaptorQFecDecoder, RaptorQFecEncoder};
|
||||
use wzp_proto::{
|
||||
AdaptiveQualityController, AudioDecoder, AudioEncoder, CodecId, FecDecoder, FecEncoder,
|
||||
MediaHeader, MediaPacket, MediaTransport, QualityController, QualityProfile, SignalMessage,
|
||||
MediaHeader, MediaPacket, MediaTransport, MediaType, QualityController, QualityProfile,
|
||||
SignalMessage, default_signal_version,
|
||||
};
|
||||
|
||||
use crate::audio_ring::AudioRing;
|
||||
@@ -46,7 +47,11 @@ const PROFILES: [QualityProfile; 6] = [
|
||||
];
|
||||
|
||||
fn profile_to_index(p: &QualityProfile) -> u8 {
|
||||
PROFILES.iter().position(|pp| pp.codec == p.codec).map(|i| i as u8).unwrap_or(3)
|
||||
PROFILES
|
||||
.iter()
|
||||
.position(|pp| pp.codec == p.codec)
|
||||
.map(|i| i as u8)
|
||||
.unwrap_or(3)
|
||||
}
|
||||
|
||||
fn index_to_profile(idx: u8) -> Option<QualityProfile> {
|
||||
@@ -149,9 +154,10 @@ impl WzpEngine {
|
||||
.enable_all()
|
||||
.build()?;
|
||||
|
||||
let relay_addr: SocketAddr = config.relay_addr.parse().map_err(|e| {
|
||||
anyhow::anyhow!("invalid relay address '{}': {e}", config.relay_addr)
|
||||
})?;
|
||||
let relay_addr: SocketAddr = config
|
||||
.relay_addr
|
||||
.parse()
|
||||
.map_err(|e| anyhow::anyhow!("invalid relay address '{}': {e}", config.relay_addr))?;
|
||||
|
||||
let room = config.room.clone();
|
||||
let identity_seed = config.identity_seed;
|
||||
@@ -165,7 +171,16 @@ impl WzpEngine {
|
||||
|
||||
let state_clone = state.clone();
|
||||
runtime.block_on(async move {
|
||||
if let Err(e) = run_call(relay_addr, &room, &identity_seed, profile, auto_profile, alias.as_deref(), state_clone).await
|
||||
if let Err(e) = run_call(
|
||||
relay_addr,
|
||||
&room,
|
||||
&identity_seed,
|
||||
profile,
|
||||
auto_profile,
|
||||
alias.as_deref(),
|
||||
state_clone,
|
||||
)
|
||||
.await
|
||||
{
|
||||
error!("call failed: {e}");
|
||||
}
|
||||
@@ -233,16 +248,21 @@ impl WzpEngine {
|
||||
let server_fp = conn
|
||||
.peer_identity()
|
||||
.and_then(|id| id.downcast::<Vec<rustls::pki_types::CertificateDer>>().ok())
|
||||
.and_then(|certs| certs.first().map(|c| {
|
||||
use std::hash::{Hash, Hasher};
|
||||
let mut h = std::collections::hash_map::DefaultHasher::new();
|
||||
c.as_ref().hash(&mut h);
|
||||
format!("{:016x}", h.finish())
|
||||
}))
|
||||
.and_then(|certs| {
|
||||
certs.first().map(|c| {
|
||||
use std::hash::{Hash, Hasher};
|
||||
let mut h = std::collections::hash_map::DefaultHasher::new();
|
||||
c.as_ref().hash(&mut h);
|
||||
format!("{:016x}", h.finish())
|
||||
})
|
||||
})
|
||||
.unwrap_or_default();
|
||||
conn.close(0u32.into(), b"ping");
|
||||
|
||||
Ok::<_, anyhow::Error>(format!(r#"{{"rtt_ms":{},"server_fingerprint":"{}"}}"#, rtt_ms, server_fp))
|
||||
Ok::<_, anyhow::Error>(format!(
|
||||
r#"{{"rtt_ms":{},"server_fingerprint":"{}"}}"#,
|
||||
rtt_ms, server_fp
|
||||
))
|
||||
});
|
||||
|
||||
// Shutdown runtime cleanly with timeout
|
||||
@@ -301,11 +321,12 @@ impl WzpEngine {
|
||||
|
||||
// Auth if token provided
|
||||
if let Some(ref tok) = token {
|
||||
let _ = transport.send_signal(&SignalMessage::AuthToken { token: tok.clone() }).await;
|
||||
let _ = transport.send_signal(&SignalMessage::AuthToken { version: default_signal_version(), token: tok.clone() }).await;
|
||||
}
|
||||
|
||||
// Register presence
|
||||
let _ = transport.send_signal(&SignalMessage::RegisterPresence {
|
||||
version: default_signal_version(),
|
||||
identity_pub,
|
||||
signature: vec![],
|
||||
alias: alias.clone(),
|
||||
@@ -330,7 +351,7 @@ impl WzpEngine {
|
||||
break;
|
||||
}
|
||||
match transport.recv_signal().await {
|
||||
Ok(Some(SignalMessage::CallRinging { call_id })) => {
|
||||
Ok(Some(SignalMessage::CallRinging { call_id, ..})) => {
|
||||
info!(call_id = %call_id, "signal: ringing");
|
||||
let mut stats = signal_state.stats.lock().unwrap();
|
||||
stats.state = crate::stats::CallState::Ringing;
|
||||
@@ -392,7 +413,11 @@ impl WzpEngine {
|
||||
}
|
||||
|
||||
/// Answer an incoming direct call.
|
||||
pub fn answer_call(&self, call_id: &str, mode: wzp_proto::CallAcceptMode) -> Result<(), anyhow::Error> {
|
||||
pub fn answer_call(
|
||||
&self,
|
||||
call_id: &str,
|
||||
mode: wzp_proto::CallAcceptMode,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let _ = self.state.command_tx.send(EngineCommand::AnswerCall {
|
||||
call_id: call_id.to_string(),
|
||||
accept_mode: mode,
|
||||
@@ -412,7 +437,9 @@ impl WzpEngine {
|
||||
/// Stores the type atomically; the recv task polls it on each packet.
|
||||
pub fn on_network_changed(&self, network_type: u8, bandwidth_kbps: u32) {
|
||||
info!(network_type, bandwidth_kbps, "on_network_changed");
|
||||
self.state.pending_network_type.store(network_type, Ordering::Release);
|
||||
self.state
|
||||
.pending_network_type
|
||||
.store(network_type, Ordering::Release);
|
||||
}
|
||||
|
||||
pub fn get_stats(&self) -> CallStats {
|
||||
@@ -496,6 +523,7 @@ async fn run_call(
|
||||
let signature = kx.sign(&sign_data);
|
||||
|
||||
let offer = SignalMessage::CallOffer {
|
||||
version: default_signal_version(),
|
||||
identity_pub,
|
||||
ephemeral_pub,
|
||||
signature,
|
||||
@@ -508,6 +536,9 @@ async fn run_call(
|
||||
QualityProfile::CATASTROPHIC,
|
||||
],
|
||||
alias: alias.map(|s| s.to_string()),
|
||||
protocol_version: 2,
|
||||
supported_versions: vec![2],
|
||||
video_codecs: vec![CodecId::H264Baseline],
|
||||
};
|
||||
transport.send_signal(&offer).await?;
|
||||
info!("CallOffer sent, waiting for CallAnswer...");
|
||||
@@ -518,12 +549,16 @@ async fn run_call(
|
||||
.ok_or_else(|| anyhow::anyhow!("connection closed before CallAnswer"))?;
|
||||
|
||||
let (relay_ephemeral_pub, chosen_profile) = match answer {
|
||||
SignalMessage::CallAnswer { ephemeral_pub, chosen_profile, .. } => (ephemeral_pub, chosen_profile),
|
||||
SignalMessage::CallAnswer {
|
||||
ephemeral_pub,
|
||||
chosen_profile,
|
||||
..
|
||||
} => (ephemeral_pub, chosen_profile),
|
||||
other => {
|
||||
return Err(anyhow::anyhow!(
|
||||
"expected CallAnswer, got {:?}",
|
||||
std::mem::discriminant(&other)
|
||||
))
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -574,7 +609,7 @@ async fn run_call(
|
||||
stats.auto_mode = auto_profile;
|
||||
}
|
||||
|
||||
let seq = AtomicU16::new(0);
|
||||
let seq = AtomicU32::new(0);
|
||||
let ts = AtomicU32::new(0);
|
||||
let transport_recv = transport.clone();
|
||||
|
||||
@@ -700,17 +735,15 @@ async fn run_call(
|
||||
|
||||
let source_pkt = MediaPacket {
|
||||
header: MediaHeader {
|
||||
version: 0,
|
||||
is_repair: false,
|
||||
version: MediaHeader::VERSION,
|
||||
flags: 0,
|
||||
media_type: MediaType::Audio,
|
||||
codec_id: current_profile.codec,
|
||||
has_quality_report: false,
|
||||
fec_ratio_encoded: hdr_fec_ratio,
|
||||
stream_id: 0,
|
||||
fec_ratio: hdr_fec_ratio,
|
||||
seq: s,
|
||||
timestamp: t,
|
||||
fec_block: hdr_fec_block,
|
||||
fec_symbol: hdr_fec_symbol,
|
||||
reserved: 0,
|
||||
csrc_count: 0,
|
||||
fec_block: ((hdr_fec_symbol as u16) << 8) | (hdr_fec_block as u16),
|
||||
},
|
||||
payload: Bytes::copy_from_slice(encoded),
|
||||
quality_report: None,
|
||||
@@ -725,9 +758,7 @@ async fn run_call(
|
||||
if send_errors <= 3 || last_send_error_log.elapsed().as_secs() >= 1 {
|
||||
warn!(
|
||||
seq = s,
|
||||
send_errors,
|
||||
frames_dropped,
|
||||
"send_media error (dropping packet): {e}"
|
||||
send_errors, frames_dropped, "send_media error (dropping packet): {e}"
|
||||
);
|
||||
last_send_error_log = Instant::now();
|
||||
}
|
||||
@@ -756,19 +787,17 @@ async fn run_call(
|
||||
let rs = seq.fetch_add(1, Ordering::Relaxed);
|
||||
let repair_pkt = MediaPacket {
|
||||
header: MediaHeader {
|
||||
version: 0,
|
||||
is_repair: true,
|
||||
version: MediaHeader::VERSION,
|
||||
flags: MediaHeader::FLAG_REPAIR,
|
||||
media_type: MediaType::Audio,
|
||||
codec_id: current_profile.codec,
|
||||
has_quality_report: false,
|
||||
fec_ratio_encoded: MediaHeader::encode_fec_ratio(
|
||||
stream_id: 0,
|
||||
fec_ratio: MediaHeader::encode_fec_ratio(
|
||||
current_profile.fec_ratio,
|
||||
),
|
||||
seq: rs,
|
||||
timestamp: t,
|
||||
fec_block: block_id,
|
||||
fec_symbol: sym_idx,
|
||||
reserved: 0,
|
||||
csrc_count: 0,
|
||||
fec_block: (sym_idx << 8) | (block_id as u16),
|
||||
},
|
||||
payload: Bytes::from(repair_data),
|
||||
quality_report: None,
|
||||
@@ -820,7 +849,11 @@ async fn run_call(
|
||||
avg_total_us = avg(t_agc_us + t_opus_us + t_fec_us + t_send_us),
|
||||
"send stats"
|
||||
);
|
||||
t_agc_us = 0; t_opus_us = 0; t_fec_us = 0; t_send_us = 0; t_frames = 0;
|
||||
t_agc_us = 0;
|
||||
t_opus_us = 0;
|
||||
t_fec_us = 0;
|
||||
t_send_us = 0;
|
||||
t_frames = 0;
|
||||
last_stats_log = Instant::now();
|
||||
}
|
||||
}
|
||||
@@ -849,14 +882,11 @@ async fn run_call(
|
||||
// when a packet arrives with seq > expected_seq, the frames in
|
||||
// between are missing and we attempt to reconstruct them via
|
||||
// DRED before decoding the newly-arrived packet.
|
||||
let mut dred_decoder =
|
||||
DredDecoderHandle::new().expect("opus_dred_decoder_create failed");
|
||||
let mut dred_parse_scratch =
|
||||
DredState::new().expect("opus_dred_alloc failed (scratch)");
|
||||
let mut last_good_dred =
|
||||
DredState::new().expect("opus_dred_alloc failed (good state)");
|
||||
let mut last_good_dred_seq: Option<u16> = None;
|
||||
let mut expected_seq: Option<u16> = None;
|
||||
let mut dred_decoder = DredDecoderHandle::new().expect("opus_dred_decoder_create failed");
|
||||
let mut dred_parse_scratch = DredState::new().expect("opus_dred_alloc failed (scratch)");
|
||||
let mut last_good_dred = DredState::new().expect("opus_dred_alloc failed (good state)");
|
||||
let mut last_good_dred_seq: Option<u32> = None;
|
||||
let mut expected_seq: Option<u32> = None;
|
||||
let mut dred_reconstructions: u64 = 0;
|
||||
let mut classical_plc_invocations: u64 = 0;
|
||||
|
||||
@@ -877,14 +907,16 @@ async fn run_call(
|
||||
warn!(
|
||||
recv_gap_ms,
|
||||
seq = pkt.header.seq,
|
||||
is_repair = pkt.header.is_repair,
|
||||
is_repair = pkt.header.is_repair(),
|
||||
"large recv gap — possible network stall"
|
||||
);
|
||||
}
|
||||
|
||||
// Check for network transport change from ConnectivityManager
|
||||
{
|
||||
let net = state.pending_network_type.swap(PROFILE_NO_CHANGE, Ordering::Acquire);
|
||||
let net = state
|
||||
.pending_network_type
|
||||
.swap(PROFILE_NO_CHANGE, Ordering::Acquire);
|
||||
if net != PROFILE_NO_CHANGE {
|
||||
use wzp_proto::NetworkContext;
|
||||
let ctx = match net {
|
||||
@@ -916,9 +948,9 @@ async fn run_call(
|
||||
}
|
||||
}
|
||||
|
||||
let is_repair = pkt.header.is_repair;
|
||||
let is_repair = pkt.header.is_repair();
|
||||
let pkt_block = pkt.header.fec_block;
|
||||
let pkt_symbol = pkt.header.fec_symbol;
|
||||
let pkt_symbol = (pkt.header.fec_block >> 8) as u16;
|
||||
let pkt_is_opus = pkt.header.codec_id.is_opus();
|
||||
|
||||
// Phase 2: Opus packets bypass RaptorQ entirely — DRED
|
||||
@@ -927,12 +959,7 @@ async fn run_call(
|
||||
// would accumulate block_id=0 duplicates that never
|
||||
// decode. Codec2 packets still feed RaptorQ.
|
||||
if !pkt_is_opus {
|
||||
let _ = fec_dec.add_symbol(
|
||||
pkt_block,
|
||||
pkt_symbol,
|
||||
is_repair,
|
||||
&pkt.payload,
|
||||
);
|
||||
let _ = fec_dec.add_symbol(pkt_block, pkt_symbol, is_repair, &pkt.payload);
|
||||
}
|
||||
|
||||
// Source packets: decode directly
|
||||
@@ -951,8 +978,12 @@ async fn run_call(
|
||||
fec_ratio: 0.5,
|
||||
frame_duration_ms: 20,
|
||||
frames_per_block: 5,
|
||||
..QualityProfile::GOOD
|
||||
},
|
||||
other => QualityProfile {
|
||||
codec: other,
|
||||
..QualityProfile::GOOD
|
||||
},
|
||||
other => QualityProfile { codec: other, ..QualityProfile::GOOD },
|
||||
};
|
||||
info!(from = ?decoder.codec_id(), to = ?pkt.header.codec_id, "recv: switching decoder");
|
||||
let _ = decoder.set_profile(switch_profile);
|
||||
@@ -984,10 +1015,7 @@ async fn run_call(
|
||||
// Update DRED state from the current packet.
|
||||
match dred_decoder.parse_into(&mut dred_parse_scratch, &pkt.payload) {
|
||||
Ok(available) if available > 0 => {
|
||||
std::mem::swap(
|
||||
&mut dred_parse_scratch,
|
||||
&mut last_good_dred,
|
||||
);
|
||||
std::mem::swap(&mut dred_parse_scratch, &mut last_good_dred);
|
||||
last_good_dred_seq = Some(pkt.header.seq);
|
||||
}
|
||||
Ok(_) => {
|
||||
@@ -999,15 +1027,14 @@ async fn run_call(
|
||||
}
|
||||
|
||||
// Detect and fill gap from last-expected to this packet.
|
||||
const MAX_GAP_FRAMES: u16 = 16;
|
||||
const MAX_GAP_FRAMES: u32 = 16;
|
||||
if let Some(expected) = expected_seq {
|
||||
let gap = pkt.header.seq.wrapping_sub(expected);
|
||||
if gap > 0 && gap <= MAX_GAP_FRAMES {
|
||||
let current_profile_frame_samples =
|
||||
(48_000 * profile.frame_duration_ms as i32) / 1000;
|
||||
let available = last_good_dred.samples_available();
|
||||
let pcm_slice_len =
|
||||
current_profile_frame_samples as usize;
|
||||
let pcm_slice_len = current_profile_frame_samples as usize;
|
||||
|
||||
for gap_idx in 0..gap {
|
||||
let missing_seq = expected.wrapping_add(gap_idx);
|
||||
@@ -1026,28 +1053,24 @@ async fn run_call(
|
||||
None => -1,
|
||||
};
|
||||
|
||||
let reconstructed = if offset_samples > 0
|
||||
&& offset_samples <= available
|
||||
{
|
||||
decoder
|
||||
.reconstruct_from_dred(
|
||||
&last_good_dred,
|
||||
offset_samples,
|
||||
&mut decode_buf[..pcm_slice_len],
|
||||
)
|
||||
.ok()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let reconstructed =
|
||||
if offset_samples > 0 && offset_samples <= available {
|
||||
decoder
|
||||
.reconstruct_from_dred(
|
||||
&last_good_dred,
|
||||
offset_samples,
|
||||
&mut decode_buf[..pcm_slice_len],
|
||||
)
|
||||
.ok()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
match reconstructed {
|
||||
Some(samples) => {
|
||||
playout_agc.process_frame(
|
||||
&mut decode_buf[..samples],
|
||||
);
|
||||
state
|
||||
.playout_ring
|
||||
.write(&decode_buf[..samples]);
|
||||
playout_agc
|
||||
.process_frame(&mut decode_buf[..samples]);
|
||||
state.playout_ring.write(&decode_buf[..samples]);
|
||||
dred_reconstructions += 1;
|
||||
frames_decoded += 1;
|
||||
}
|
||||
@@ -1144,7 +1167,10 @@ async fn run_call(
|
||||
}
|
||||
}
|
||||
Ok(None) => {
|
||||
info!(frames_decoded, fec_recovered, "relay disconnected (stream ended)");
|
||||
info!(
|
||||
frames_decoded,
|
||||
fec_recovered, "relay disconnected (stream ended)"
|
||||
);
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
@@ -1162,7 +1188,10 @@ async fn run_call(
|
||||
}
|
||||
}
|
||||
}
|
||||
info!(frames_decoded, fec_recovered, recv_errors, "recv task ended");
|
||||
info!(
|
||||
frames_decoded,
|
||||
fec_recovered, recv_errors, "recv task ended"
|
||||
);
|
||||
};
|
||||
|
||||
// Stats task — polls path quality + quinn RTT every 500ms
|
||||
@@ -1195,7 +1224,11 @@ async fn run_call(
|
||||
let signal_task = async {
|
||||
loop {
|
||||
match transport_signal.recv_signal().await {
|
||||
Ok(Some(SignalMessage::RoomUpdate { count, participants })) => {
|
||||
Ok(Some(SignalMessage::RoomUpdate {
|
||||
count,
|
||||
participants,
|
||||
..
|
||||
})) => {
|
||||
info!(count, "RoomUpdate received");
|
||||
let members: Vec<crate::stats::RoomMember> = participants
|
||||
.iter()
|
||||
@@ -1209,6 +1242,19 @@ async fn run_call(
|
||||
stats.room_participant_count = count;
|
||||
stats.room_participants = members;
|
||||
}
|
||||
Ok(Some(SignalMessage::QualityDirective {
|
||||
recommended_profile,
|
||||
reason,
|
||||
..
|
||||
})) => {
|
||||
let idx = profile_to_index(&recommended_profile);
|
||||
info!(
|
||||
codec = ?recommended_profile.codec,
|
||||
reason = reason.as_deref().unwrap_or(""),
|
||||
"relay quality directive: switching profile"
|
||||
);
|
||||
pending_profile_recv.store(idx, Ordering::Release);
|
||||
}
|
||||
Ok(Some(msg)) => {
|
||||
info!("signal received: {:?}", std::mem::discriminant(&msg));
|
||||
}
|
||||
@@ -1238,7 +1284,9 @@ async fn run_call(
|
||||
match tokio::time::timeout(
|
||||
std::time::Duration::from_millis(500),
|
||||
transport.connection().closed(),
|
||||
).await {
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(_) => info!("QUIC connection closed cleanly"),
|
||||
Err(_) => info!("QUIC close timed out (relay may not have ack'd)"),
|
||||
}
|
||||
|
||||
@@ -3,9 +3,9 @@
|
||||
use std::panic;
|
||||
use std::sync::Once;
|
||||
|
||||
use jni::JNIEnv;
|
||||
use jni::objects::{JClass, JObject, JString};
|
||||
use jni::sys::{jboolean, jint, jlong, jstring};
|
||||
use jni::JNIEnv;
|
||||
use tracing::{error, info};
|
||||
use wzp_proto::QualityProfile;
|
||||
|
||||
@@ -26,19 +26,21 @@ const PROFILE_AUTO: jint = 7;
|
||||
|
||||
fn profile_from_int(value: jint) -> QualityProfile {
|
||||
match value {
|
||||
0 => QualityProfile::GOOD, // Opus 24k
|
||||
1 => QualityProfile::DEGRADED, // Opus 6k
|
||||
2 => QualityProfile::CATASTROPHIC, // Codec2 1.2k
|
||||
3 => QualityProfile { // Codec2 3.2k
|
||||
0 => QualityProfile::GOOD, // Opus 24k
|
||||
1 => QualityProfile::DEGRADED, // Opus 6k
|
||||
2 => QualityProfile::CATASTROPHIC, // Codec2 1.2k
|
||||
3 => QualityProfile {
|
||||
// Codec2 3.2k
|
||||
codec: wzp_proto::CodecId::Codec2_3200,
|
||||
fec_ratio: 0.5,
|
||||
frame_duration_ms: 20,
|
||||
frames_per_block: 5,
|
||||
..QualityProfile::GOOD
|
||||
},
|
||||
4 => QualityProfile::STUDIO_32K, // Opus 32k
|
||||
5 => QualityProfile::STUDIO_48K, // Opus 48k
|
||||
6 => QualityProfile::STUDIO_64K, // Opus 64k
|
||||
_ => QualityProfile::GOOD, // auto falls back to GOOD
|
||||
4 => QualityProfile::STUDIO_32K, // Opus 32k
|
||||
5 => QualityProfile::STUDIO_48K, // Opus 48k
|
||||
6 => QualityProfile::STUDIO_64K, // Opus 64k
|
||||
_ => QualityProfile::GOOD, // auto falls back to GOOD
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,25 +50,33 @@ static INIT_LOGGING: Once = Once::new();
|
||||
/// Safe to call multiple times — only the first call takes effect.
|
||||
fn init_logging() {
|
||||
INIT_LOGGING.call_once(|| {
|
||||
// Wrap in catch_unwind — sharded_slab allocation inside
|
||||
// tracing_subscriber::registry() can crash on some Android
|
||||
// devices if scudo malloc fails during early initialization.
|
||||
let _ = std::panic::catch_unwind(|| {
|
||||
use tracing_subscriber::layer::SubscriberExt;
|
||||
use tracing_subscriber::util::SubscriberInitExt;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
if let Ok(layer) = tracing_android::layer("wzp_android") {
|
||||
// Filter: INFO for our crates, WARN for everything else.
|
||||
// The jni crate emits VERBOSE logs for every method lookup
|
||||
// (~10 lines per JNI call, 100+ calls/sec) which floods logcat
|
||||
// and causes the system to kill the app.
|
||||
let filter = EnvFilter::new("warn,wzp_android=info,wzp_proto=info,wzp_transport=info,wzp_codec=info,wzp_fec=info,wzp_crypto=info");
|
||||
let _ = tracing_subscriber::registry()
|
||||
.with(layer)
|
||||
.with(filter)
|
||||
.try_init();
|
||||
}
|
||||
});
|
||||
#[cfg(target_os = "android")]
|
||||
{
|
||||
// Wrap in catch_unwind — sharded_slab allocation inside
|
||||
// tracing_subscriber::registry() can crash on some Android
|
||||
// devices if scudo malloc fails during early initialization.
|
||||
let _ = std::panic::catch_unwind(|| {
|
||||
use tracing_subscriber::layer::SubscriberExt;
|
||||
use tracing_subscriber::util::SubscriberInitExt;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
if let Ok(layer) = tracing_android::layer("wzp_android") {
|
||||
// Filter: INFO for our crates, WARN for everything else.
|
||||
// The jni crate emits VERBOSE logs for every method lookup
|
||||
// (~10 lines per JNI call, 100+ calls/sec) which floods logcat
|
||||
// and causes the system to kill the app.
|
||||
let filter = EnvFilter::new("warn,wzp_android=info,wzp_proto=info,wzp_transport=info,wzp_codec=info,wzp_fec=info,wzp_crypto=info");
|
||||
let _ = tracing_subscriber::registry()
|
||||
.with(layer)
|
||||
.with(filter)
|
||||
.try_init();
|
||||
}
|
||||
});
|
||||
}
|
||||
#[cfg(not(target_os = "android"))]
|
||||
{
|
||||
// On non-Android targets tracing-android is unavailable.
|
||||
let _ = tracing_subscriber::fmt::try_init();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -101,11 +111,26 @@ pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeStartCall(
|
||||
profile_j: jint,
|
||||
) -> jint {
|
||||
let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
|
||||
let relay_addr: String = env.get_string(&relay_addr_j).map(|s| s.into()).unwrap_or_default();
|
||||
let room: String = env.get_string(&room_j).map(|s| s.into()).unwrap_or_default();
|
||||
let seed_hex: String = env.get_string(&seed_hex_j).map(|s| s.into()).unwrap_or_default();
|
||||
let token: String = env.get_string(&token_j).map(|s| s.into()).unwrap_or_default();
|
||||
let alias: String = env.get_string(&alias_j).map(|s| s.into()).unwrap_or_default();
|
||||
let relay_addr: String = env
|
||||
.get_string(&relay_addr_j)
|
||||
.map(|s| s.into())
|
||||
.unwrap_or_default();
|
||||
let room: String = env
|
||||
.get_string(&room_j)
|
||||
.map(|s| s.into())
|
||||
.unwrap_or_default();
|
||||
let seed_hex: String = env
|
||||
.get_string(&seed_hex_j)
|
||||
.map(|s| s.into())
|
||||
.unwrap_or_default();
|
||||
let token: String = env
|
||||
.get_string(&token_j)
|
||||
.map(|s| s.into())
|
||||
.unwrap_or_default();
|
||||
let alias: String = env
|
||||
.get_string(&alias_j)
|
||||
.map(|s| s.into())
|
||||
.unwrap_or_default();
|
||||
|
||||
let h = unsafe { handle_ref(handle) };
|
||||
|
||||
@@ -128,7 +153,11 @@ pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeStartCall(
|
||||
auto_profile: profile_j == PROFILE_AUTO,
|
||||
relay_addr,
|
||||
room,
|
||||
auth_token: if token.is_empty() { Vec::new() } else { token.into_bytes() },
|
||||
auth_token: if token.is_empty() {
|
||||
Vec::new()
|
||||
} else {
|
||||
token.into_bytes()
|
||||
},
|
||||
identity_seed,
|
||||
alias: if alias.is_empty() { None } else { Some(alias) },
|
||||
};
|
||||
@@ -241,7 +270,8 @@ pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeOnNetworkChang
|
||||
) {
|
||||
let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| {
|
||||
let h = unsafe { handle_ref(handle) };
|
||||
h.engine.on_network_changed(network_type as u8, bandwidth_kbps as u32);
|
||||
h.engine
|
||||
.on_network_changed(network_type as u8, bandwidth_kbps as u32);
|
||||
}));
|
||||
}
|
||||
|
||||
@@ -307,13 +337,14 @@ pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeWriteAudioDire
|
||||
) -> jint {
|
||||
let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
|
||||
let h = unsafe { handle_ref(handle) };
|
||||
let ptr = env.get_direct_buffer_address(&buffer).unwrap_or(std::ptr::null_mut());
|
||||
let ptr = env
|
||||
.get_direct_buffer_address(&buffer)
|
||||
.unwrap_or(std::ptr::null_mut());
|
||||
if ptr.is_null() || sample_count <= 0 {
|
||||
return 0;
|
||||
}
|
||||
let samples = unsafe {
|
||||
std::slice::from_raw_parts(ptr as *const i16, sample_count as usize)
|
||||
};
|
||||
let samples =
|
||||
unsafe { std::slice::from_raw_parts(ptr as *const i16, sample_count as usize) };
|
||||
h.engine.write_audio(samples) as jint
|
||||
}));
|
||||
result.unwrap_or(0)
|
||||
@@ -332,13 +363,14 @@ pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeReadAudioDirec
|
||||
) -> jint {
|
||||
let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
|
||||
let h = unsafe { handle_ref(handle) };
|
||||
let ptr = env.get_direct_buffer_address(&buffer).unwrap_or(std::ptr::null_mut());
|
||||
let ptr = env
|
||||
.get_direct_buffer_address(&buffer)
|
||||
.unwrap_or(std::ptr::null_mut());
|
||||
if ptr.is_null() || max_samples <= 0 {
|
||||
return 0;
|
||||
}
|
||||
let samples = unsafe {
|
||||
std::slice::from_raw_parts_mut(ptr as *mut i16, max_samples as usize)
|
||||
};
|
||||
let samples =
|
||||
unsafe { std::slice::from_raw_parts_mut(ptr as *mut i16, max_samples as usize) };
|
||||
h.engine.read_audio(samples) as jint
|
||||
}));
|
||||
result.unwrap_or(0)
|
||||
@@ -367,7 +399,10 @@ pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativePingRelay<'a>(
|
||||
) -> jstring {
|
||||
let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
|
||||
let h = unsafe { handle_ref(handle) };
|
||||
let relay: String = env.get_string(&relay_j).map(|s| s.into()).unwrap_or_default();
|
||||
let relay: String = env
|
||||
.get_string(&relay_j)
|
||||
.map(|s| s.into())
|
||||
.unwrap_or_default();
|
||||
match h.engine.ping_relay(&relay) {
|
||||
Ok(json) => Some(json),
|
||||
Err(_) => None,
|
||||
@@ -399,10 +434,22 @@ pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeStartSignaling
|
||||
) -> jint {
|
||||
let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
|
||||
let h = unsafe { handle_ref(handle) };
|
||||
let relay_addr: String = env.get_string(&relay_addr_j).map(|s| s.into()).unwrap_or_default();
|
||||
let seed_hex: String = env.get_string(&seed_hex_j).map(|s| s.into()).unwrap_or_default();
|
||||
let token: String = env.get_string(&token_j).map(|s| s.into()).unwrap_or_default();
|
||||
let alias: String = env.get_string(&alias_j).map(|s| s.into()).unwrap_or_default();
|
||||
let relay_addr: String = env
|
||||
.get_string(&relay_addr_j)
|
||||
.map(|s| s.into())
|
||||
.unwrap_or_default();
|
||||
let seed_hex: String = env
|
||||
.get_string(&seed_hex_j)
|
||||
.map(|s| s.into())
|
||||
.unwrap_or_default();
|
||||
let token: String = env
|
||||
.get_string(&token_j)
|
||||
.map(|s| s.into())
|
||||
.unwrap_or_default();
|
||||
let alias: String = env
|
||||
.get_string(&alias_j)
|
||||
.map(|s| s.into())
|
||||
.unwrap_or_default();
|
||||
|
||||
h.engine.start_signaling(
|
||||
&relay_addr,
|
||||
@@ -414,8 +461,14 @@ pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeStartSignaling
|
||||
|
||||
match result {
|
||||
Ok(Ok(())) => 0,
|
||||
Ok(Err(e)) => { error!("start_signaling failed: {e}"); -1 }
|
||||
Err(_) => { error!("start_signaling panicked"); -1 }
|
||||
Ok(Err(e)) => {
|
||||
error!("start_signaling failed: {e}");
|
||||
-1
|
||||
}
|
||||
Err(_) => {
|
||||
error!("start_signaling panicked");
|
||||
-1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -430,14 +483,23 @@ pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativePlaceCall<'a>(
|
||||
) -> jint {
|
||||
let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
|
||||
let h = unsafe { handle_ref(handle) };
|
||||
let target: String = env.get_string(&target_fp_j).map(|s| s.into()).unwrap_or_default();
|
||||
let target: String = env
|
||||
.get_string(&target_fp_j)
|
||||
.map(|s| s.into())
|
||||
.unwrap_or_default();
|
||||
h.engine.place_call(&target)
|
||||
}));
|
||||
|
||||
match result {
|
||||
Ok(Ok(())) => 0,
|
||||
Ok(Err(e)) => { error!("place_call failed: {e}"); -1 }
|
||||
Err(_) => { error!("place_call panicked"); -1 }
|
||||
Ok(Err(e)) => {
|
||||
error!("place_call failed: {e}");
|
||||
-1
|
||||
}
|
||||
Err(_) => {
|
||||
error!("place_call panicked");
|
||||
-1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -453,7 +515,10 @@ pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeAnswerCall<'a>
|
||||
) -> jint {
|
||||
let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
|
||||
let h = unsafe { handle_ref(handle) };
|
||||
let call_id: String = env.get_string(&call_id_j).map(|s| s.into()).unwrap_or_default();
|
||||
let call_id: String = env
|
||||
.get_string(&call_id_j)
|
||||
.map(|s| s.into())
|
||||
.unwrap_or_default();
|
||||
let accept_mode = match mode {
|
||||
0 => wzp_proto::CallAcceptMode::Reject,
|
||||
1 => wzp_proto::CallAcceptMode::AcceptTrusted,
|
||||
@@ -464,7 +529,13 @@ pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeAnswerCall<'a>
|
||||
|
||||
match result {
|
||||
Ok(Ok(())) => 0,
|
||||
Ok(Err(e)) => { error!("answer_call failed: {e}"); -1 }
|
||||
Err(_) => { error!("answer_call panicked"); -1 }
|
||||
Ok(Err(e)) => {
|
||||
error!("answer_call failed: {e}");
|
||||
-1
|
||||
}
|
||||
Err(_) => {
|
||||
error!("answer_call panicked");
|
||||
-1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,6 +26,6 @@ pub mod audio_android;
|
||||
pub mod audio_ring;
|
||||
pub mod commands;
|
||||
pub mod engine;
|
||||
pub mod jni_bridge;
|
||||
pub mod pipeline;
|
||||
pub mod stats;
|
||||
pub mod jni_bridge;
|
||||
|
||||
@@ -9,8 +9,8 @@ use wzp_codec::{AdaptiveDecoder, AdaptiveEncoder, AutoGainControl, EchoCanceller
|
||||
use wzp_fec::{RaptorQFecDecoder, RaptorQFecEncoder};
|
||||
use wzp_proto::jitter::{JitterBuffer, PlayoutResult};
|
||||
use wzp_proto::quality::AdaptiveQualityController;
|
||||
use wzp_proto::traits::{AudioDecoder, AudioEncoder, FecDecoder, FecEncoder};
|
||||
use wzp_proto::traits::QualityController;
|
||||
use wzp_proto::traits::{AudioDecoder, AudioEncoder, FecDecoder, FecEncoder};
|
||||
use wzp_proto::{MediaPacket, QualityProfile};
|
||||
|
||||
use crate::audio_android::FRAME_SAMPLES;
|
||||
@@ -58,14 +58,12 @@ pub struct Pipeline {
|
||||
impl Pipeline {
|
||||
/// Create a new pipeline configured for the given quality profile.
|
||||
pub fn new(profile: QualityProfile) -> Result<Self, anyhow::Error> {
|
||||
let encoder = AdaptiveEncoder::new(profile)
|
||||
.map_err(|e| anyhow::anyhow!("encoder init: {e}"))?;
|
||||
let decoder = AdaptiveDecoder::new(profile)
|
||||
.map_err(|e| anyhow::anyhow!("decoder init: {e}"))?;
|
||||
let fec_encoder =
|
||||
RaptorQFecEncoder::with_defaults(profile.frames_per_block as usize);
|
||||
let fec_decoder =
|
||||
RaptorQFecDecoder::with_defaults(profile.frames_per_block as usize);
|
||||
let encoder =
|
||||
AdaptiveEncoder::new(profile).map_err(|e| anyhow::anyhow!("encoder init: {e}"))?;
|
||||
let decoder =
|
||||
AdaptiveDecoder::new(profile).map_err(|e| anyhow::anyhow!("decoder init: {e}"))?;
|
||||
let fec_encoder = RaptorQFecEncoder::with_defaults(profile.frames_per_block as usize);
|
||||
let fec_decoder = RaptorQFecDecoder::with_defaults(profile.frames_per_block as usize);
|
||||
let jitter_buffer = JitterBuffer::new(10, 250, 3);
|
||||
let quality_ctrl = AdaptiveQualityController::new();
|
||||
|
||||
@@ -136,11 +134,11 @@ impl Pipeline {
|
||||
pub fn feed_packet(&mut self, packet: MediaPacket) {
|
||||
// Feed FEC symbols if present
|
||||
let header = &packet.header;
|
||||
if header.fec_block != 0 || header.fec_symbol != 0 {
|
||||
let is_repair = header.is_repair;
|
||||
if header.fec_block != 0 {
|
||||
let is_repair = header.is_repair();
|
||||
if let Err(e) = self.fec_decoder.add_symbol(
|
||||
header.fec_block,
|
||||
header.fec_symbol,
|
||||
header.fec_block >> 8,
|
||||
is_repair,
|
||||
&packet.payload,
|
||||
) {
|
||||
@@ -211,10 +209,7 @@ impl Pipeline {
|
||||
///
|
||||
/// Returns a new profile if a tier transition occurred.
|
||||
#[allow(unused)]
|
||||
pub fn observe_quality(
|
||||
&mut self,
|
||||
report: &wzp_proto::QualityReport,
|
||||
) -> Option<QualityProfile> {
|
||||
pub fn observe_quality(&mut self, report: &wzp_proto::QualityReport) -> Option<QualityProfile> {
|
||||
let new_profile = self.quality_ctrl.observe(report);
|
||||
if let Some(ref profile) = new_profile {
|
||||
if let Err(e) = self.encoder.set_profile(*profile) {
|
||||
|
||||
@@ -12,6 +12,7 @@ wzp-codec = { workspace = true }
|
||||
wzp-fec = { workspace = true }
|
||||
wzp-crypto = { workspace = true }
|
||||
wzp-transport = { workspace = true }
|
||||
wzp-video = { path = "../wzp-video" }
|
||||
tokio = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
@@ -21,6 +22,9 @@ anyhow = "1"
|
||||
serde = { workspace = true }
|
||||
serde_json = "1"
|
||||
chrono = "0.4"
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
ratatui = "0.29"
|
||||
crossterm = "0.28"
|
||||
rustls = { version = "0.23", default-features = false, features = ["ring", "std"] }
|
||||
cpal = { version = "0.15", optional = true }
|
||||
libc = "0.2"
|
||||
@@ -30,6 +34,8 @@ libc = "0.2"
|
||||
# through the WAN reflex addr (which many consumer NATs, including
|
||||
# MikroTik's default masquerade, don't support).
|
||||
if-addrs = "0.13"
|
||||
rand = { workspace = true }
|
||||
socket2 = "0.5"
|
||||
|
||||
# coreaudio-rs is Apple-framework-only; gate it to macOS so enabling
|
||||
# the `vpio` feature from a non-macOS target builds cleanly instead of
|
||||
@@ -99,6 +105,10 @@ linux-aec = ["dep:webrtc-audio-processing"]
|
||||
name = "wzp-client"
|
||||
path = "src/cli.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "wzp-analyzer"
|
||||
path = "src/analyzer.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "wzp-bench"
|
||||
path = "src/bench_cli.rs"
|
||||
|
||||
1309
crates/wzp-client/src/analyzer.rs
Normal file
1309
crates/wzp-client/src/analyzer.rs
Normal file
File diff suppressed because it is too large
Load Diff
@@ -6,10 +6,10 @@
|
||||
//! Audio callbacks are **lock-free**: they read/write directly to an `AudioRing`
|
||||
//! (atomic SPSC ring buffer). No Mutex, no channel, no allocation on the hot path.
|
||||
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
use anyhow::{Context, anyhow};
|
||||
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
|
||||
use cpal::{SampleFormat, SampleRate, StreamConfig};
|
||||
use tracing::{info, warn};
|
||||
@@ -78,7 +78,10 @@ impl AudioCapture {
|
||||
return;
|
||||
}
|
||||
if !logged.swap(true, Ordering::Relaxed) {
|
||||
eprintln!("[audio] capture callback: {} f32 samples", data.len());
|
||||
eprintln!(
|
||||
"[audio] capture callback: {} f32 samples",
|
||||
data.len()
|
||||
);
|
||||
}
|
||||
let mut tmp = [0i16; FRAME_SAMPLES];
|
||||
for chunk in data.chunks(FRAME_SAMPLES) {
|
||||
@@ -103,7 +106,10 @@ impl AudioCapture {
|
||||
return;
|
||||
}
|
||||
if !logged.swap(true, Ordering::Relaxed) {
|
||||
eprintln!("[audio] capture callback: {} i16 samples", data.len());
|
||||
eprintln!(
|
||||
"[audio] capture callback: {} i16 samples",
|
||||
data.len()
|
||||
);
|
||||
}
|
||||
ring.write(data);
|
||||
},
|
||||
|
||||
@@ -54,13 +54,13 @@
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::{Arc, Mutex, OnceLock};
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
use anyhow::{Context, anyhow};
|
||||
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
|
||||
use cpal::{SampleFormat, SampleRate, StreamConfig};
|
||||
use tracing::{info, warn};
|
||||
use webrtc_audio_processing::{
|
||||
Config, EchoCancellation, EchoCancellationSuppressionLevel, InitializationConfig,
|
||||
NoiseSuppression, NoiseSuppressionLevel, Processor, NUM_SAMPLES_PER_FRAME,
|
||||
NUM_SAMPLES_PER_FRAME, NoiseSuppression, NoiseSuppressionLevel, Processor,
|
||||
};
|
||||
|
||||
use crate::audio_ring::AudioRing;
|
||||
@@ -97,8 +97,8 @@ fn get_or_init_processor() -> anyhow::Result<Arc<Mutex<Processor>>> {
|
||||
num_render_channels: APM_NUM_CHANNELS as i32,
|
||||
..Default::default()
|
||||
};
|
||||
let mut processor = Processor::new(&init_config)
|
||||
.map_err(|e| anyhow!("webrtc APM init failed: {e:?}"))?;
|
||||
let mut processor =
|
||||
Processor::new(&init_config).map_err(|e| anyhow!("webrtc APM init failed: {e:?}"))?;
|
||||
|
||||
let config = Config {
|
||||
echo_cancellation: Some(EchoCancellation {
|
||||
|
||||
@@ -5,8 +5,8 @@
|
||||
//! to the speaker, so it can cancel the echo from the mic signal internally.
|
||||
//! This is the same engine FaceTime and other Apple apps use.
|
||||
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||
|
||||
use anyhow::Context;
|
||||
use coreaudio::audio_unit::audio_format::LinearPcmFlags;
|
||||
@@ -28,6 +28,60 @@ pub struct VpioAudio {
|
||||
playout_ring: Arc<AudioRing>,
|
||||
_audio_unit: AudioUnit,
|
||||
running: Arc<AtomicBool>,
|
||||
stats: Arc<VpioStats>,
|
||||
}
|
||||
|
||||
/// Render/capture counters for diagnosing macOS VoiceProcessingIO.
|
||||
///
|
||||
/// These are atomics because CoreAudio callbacks run on realtime audio
|
||||
/// threads. The Tauri engine polls snapshots from a normal async task and
|
||||
/// emits them to the call debug log.
|
||||
#[derive(Default)]
|
||||
pub struct VpioStats {
|
||||
capture_callbacks: AtomicU64,
|
||||
capture_samples: AtomicU64,
|
||||
render_callbacks: AtomicU64,
|
||||
render_requested_samples: AtomicU64,
|
||||
render_read_samples: AtomicU64,
|
||||
render_underrun_callbacks: AtomicU64,
|
||||
render_nonzero_callbacks: AtomicU64,
|
||||
render_last_requested: AtomicU64,
|
||||
render_last_read: AtomicU64,
|
||||
render_last_rms: AtomicU64,
|
||||
render_last_ring_available: AtomicU64,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct VpioStatsSnapshot {
|
||||
pub capture_callbacks: u64,
|
||||
pub capture_samples: u64,
|
||||
pub render_callbacks: u64,
|
||||
pub render_requested_samples: u64,
|
||||
pub render_read_samples: u64,
|
||||
pub render_underrun_callbacks: u64,
|
||||
pub render_nonzero_callbacks: u64,
|
||||
pub render_last_requested: u64,
|
||||
pub render_last_read: u64,
|
||||
pub render_last_rms: u64,
|
||||
pub render_last_ring_available: u64,
|
||||
}
|
||||
|
||||
impl VpioStats {
|
||||
pub fn snapshot(&self) -> VpioStatsSnapshot {
|
||||
VpioStatsSnapshot {
|
||||
capture_callbacks: self.capture_callbacks.load(Ordering::Relaxed),
|
||||
capture_samples: self.capture_samples.load(Ordering::Relaxed),
|
||||
render_callbacks: self.render_callbacks.load(Ordering::Relaxed),
|
||||
render_requested_samples: self.render_requested_samples.load(Ordering::Relaxed),
|
||||
render_read_samples: self.render_read_samples.load(Ordering::Relaxed),
|
||||
render_underrun_callbacks: self.render_underrun_callbacks.load(Ordering::Relaxed),
|
||||
render_nonzero_callbacks: self.render_nonzero_callbacks.load(Ordering::Relaxed),
|
||||
render_last_requested: self.render_last_requested.load(Ordering::Relaxed),
|
||||
render_last_read: self.render_last_read.load(Ordering::Relaxed),
|
||||
render_last_rms: self.render_last_rms.load(Ordering::Relaxed),
|
||||
render_last_ring_available: self.render_last_ring_available.load(Ordering::Relaxed),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl VpioAudio {
|
||||
@@ -36,6 +90,7 @@ impl VpioAudio {
|
||||
let capture_ring = Arc::new(AudioRing::new());
|
||||
let playout_ring = Arc::new(AudioRing::new());
|
||||
let running = Arc::new(AtomicBool::new(true));
|
||||
let stats = Arc::new(VpioStats::default());
|
||||
|
||||
let mut au = AudioUnit::new(IOType::VoiceProcessingIO)
|
||||
.context("failed to create VoiceProcessingIO audio unit")?;
|
||||
@@ -98,6 +153,7 @@ impl VpioAudio {
|
||||
// Set up input callback (mic capture with AEC applied)
|
||||
let cap_ring = capture_ring.clone();
|
||||
let cap_running = running.clone();
|
||||
let cap_stats = stats.clone();
|
||||
let logged = Arc::new(AtomicBool::new(false));
|
||||
au.set_input_callback(
|
||||
move |args: render_callback::Args<data::NonInterleaved<f32>>| {
|
||||
@@ -106,6 +162,10 @@ impl VpioAudio {
|
||||
}
|
||||
let mut buffers = args.data.channels();
|
||||
if let Some(ch) = buffers.next() {
|
||||
cap_stats.capture_callbacks.fetch_add(1, Ordering::Relaxed);
|
||||
cap_stats
|
||||
.capture_samples
|
||||
.fetch_add(ch.len() as u64, Ordering::Relaxed);
|
||||
if !logged.swap(true, Ordering::Relaxed) {
|
||||
eprintln!("[vpio] capture callback: {} f32 samples", ch.len());
|
||||
}
|
||||
@@ -125,28 +185,80 @@ impl VpioAudio {
|
||||
|
||||
// Set up output callback (speaker playback — AEC uses this as reference)
|
||||
let play_ring = playout_ring.clone();
|
||||
let render_stats = stats.clone();
|
||||
let logged_render = Arc::new(AtomicBool::new(false));
|
||||
au.set_render_callback(
|
||||
move |mut args: render_callback::Args<data::NonInterleaved<f32>>| {
|
||||
let mut buffers = args.data.channels_mut();
|
||||
if let Some(ch) = buffers.next() {
|
||||
render_stats
|
||||
.render_callbacks
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
render_stats
|
||||
.render_requested_samples
|
||||
.fetch_add(ch.len() as u64, Ordering::Relaxed);
|
||||
render_stats
|
||||
.render_last_requested
|
||||
.store(ch.len() as u64, Ordering::Relaxed);
|
||||
let mut tmp = [0i16; FRAME_SAMPLES];
|
||||
let mut total_read = 0usize;
|
||||
let mut sum_sq = 0u64;
|
||||
let ring_available = play_ring.available();
|
||||
for chunk in ch.chunks_mut(FRAME_SAMPLES) {
|
||||
let n = chunk.len();
|
||||
let read = play_ring.read(&mut tmp[..n]);
|
||||
total_read += read;
|
||||
for i in 0..read {
|
||||
let s = tmp[i] as i64;
|
||||
sum_sq = sum_sq.saturating_add((s * s) as u64);
|
||||
chunk[i] = tmp[i] as f32 / i16::MAX as f32;
|
||||
}
|
||||
for i in read..n {
|
||||
chunk[i] = 0.0;
|
||||
}
|
||||
}
|
||||
render_stats
|
||||
.render_read_samples
|
||||
.fetch_add(total_read as u64, Ordering::Relaxed);
|
||||
render_stats
|
||||
.render_last_read
|
||||
.store(total_read as u64, Ordering::Relaxed);
|
||||
render_stats
|
||||
.render_last_ring_available
|
||||
.store(ring_available as u64, Ordering::Relaxed);
|
||||
if total_read == 0 {
|
||||
render_stats
|
||||
.render_underrun_callbacks
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
let rms = if total_read > 0 {
|
||||
((sum_sq as f64 / total_read as f64).sqrt()) as u64
|
||||
} else {
|
||||
0
|
||||
};
|
||||
render_stats.render_last_rms.store(rms, Ordering::Relaxed);
|
||||
if rms > 0 {
|
||||
render_stats
|
||||
.render_nonzero_callbacks
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
if !logged_render.swap(true, Ordering::Relaxed) {
|
||||
eprintln!(
|
||||
"[vpio] render callback: {} f32 samples, ring_available={}, ring_read={}, rms={}",
|
||||
ch.len(),
|
||||
ring_available,
|
||||
total_read,
|
||||
rms
|
||||
);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
},
|
||||
)
|
||||
.context("failed to set render callback")?;
|
||||
|
||||
au.initialize().context("failed to initialize VoiceProcessingIO")?;
|
||||
au.initialize()
|
||||
.context("failed to initialize VoiceProcessingIO")?;
|
||||
au.start().context("failed to start VoiceProcessingIO")?;
|
||||
|
||||
info!("VoiceProcessingIO started (OS-level AEC enabled)");
|
||||
@@ -156,6 +268,7 @@ impl VpioAudio {
|
||||
playout_ring,
|
||||
_audio_unit: au,
|
||||
running,
|
||||
stats,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -167,6 +280,10 @@ impl VpioAudio {
|
||||
&self.playout_ring
|
||||
}
|
||||
|
||||
pub fn stats(&self) -> Arc<VpioStats> {
|
||||
self.stats.clone()
|
||||
}
|
||||
|
||||
pub fn stop(&self) {
|
||||
self.running.store(false, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
@@ -15,24 +15,24 @@
|
||||
//! `wzp-client`'s lib.rs can transparently re-export either one as
|
||||
//! `AudioCapture`.
|
||||
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
use anyhow::{Context, anyhow};
|
||||
use tracing::{info, warn};
|
||||
use windows::core::{Interface, GUID};
|
||||
use windows::Win32::Foundation::{CloseHandle, BOOL, WAIT_OBJECT_0};
|
||||
use windows::Win32::Foundation::{BOOL, CloseHandle, WAIT_OBJECT_0};
|
||||
use windows::Win32::Media::Audio::{
|
||||
eCapture, eCommunications, AudioCategory_Communications, AudioClientProperties,
|
||||
IAudioCaptureClient, IAudioClient, IAudioClient2, IMMDeviceEnumerator, MMDeviceEnumerator,
|
||||
AUDCLNT_SHAREMODE_SHARED, AUDCLNT_STREAMFLAGS_AUTOCONVERTPCM,
|
||||
AUDCLNT_STREAMFLAGS_EVENTCALLBACK, AUDCLNT_STREAMFLAGS_SRC_DEFAULT_QUALITY, WAVEFORMATEX,
|
||||
WAVE_FORMAT_PCM,
|
||||
AUDCLNT_STREAMFLAGS_EVENTCALLBACK, AUDCLNT_STREAMFLAGS_SRC_DEFAULT_QUALITY,
|
||||
AudioCategory_Communications, AudioClientProperties, IAudioCaptureClient, IAudioClient,
|
||||
IAudioClient2, IMMDeviceEnumerator, MMDeviceEnumerator, WAVE_FORMAT_PCM, WAVEFORMATEX,
|
||||
eCapture, eCommunications,
|
||||
};
|
||||
use windows::Win32::System::Com::{
|
||||
CoCreateInstance, CoInitializeEx, CoUninitialize, CLSCTX_ALL, COINIT_MULTITHREADED,
|
||||
CLSCTX_ALL, COINIT_MULTITHREADED, CoCreateInstance, CoInitializeEx, CoUninitialize,
|
||||
};
|
||||
use windows::Win32::System::Threading::{CreateEventW, WaitForSingleObject, INFINITE};
|
||||
use windows::Win32::System::Threading::{CreateEventW, INFINITE, WaitForSingleObject};
|
||||
use windows::core::{GUID, Interface};
|
||||
|
||||
use crate::audio_ring::AudioRing;
|
||||
|
||||
@@ -138,9 +138,8 @@ unsafe fn capture_thread_main(
|
||||
}
|
||||
let _com_guard = ComGuard;
|
||||
|
||||
let enumerator: IMMDeviceEnumerator =
|
||||
CoCreateInstance(&MMDeviceEnumerator, None, CLSCTX_ALL)
|
||||
.context("CoCreateInstance(MMDeviceEnumerator) failed")?;
|
||||
let enumerator: IMMDeviceEnumerator = CoCreateInstance(&MMDeviceEnumerator, None, CLSCTX_ALL)
|
||||
.context("CoCreateInstance(MMDeviceEnumerator) failed")?;
|
||||
|
||||
// eCommunications role (not eConsole) — this picks the device the user
|
||||
// has designated for communications in Sound Settings. It's the one
|
||||
@@ -206,12 +205,13 @@ unsafe fn capture_thread_main(
|
||||
&wave_format,
|
||||
Some(&GUID::zeroed()),
|
||||
)
|
||||
.context("IAudioClient::Initialize failed — Windows rejected communications-mode 48k mono i16")?;
|
||||
.context(
|
||||
"IAudioClient::Initialize failed — Windows rejected communications-mode 48k mono i16",
|
||||
)?;
|
||||
|
||||
// Event-driven capture: Windows signals this handle each time a new
|
||||
// audio packet is available. We wait on it from the loop below.
|
||||
let event = CreateEventW(None, false, false, None)
|
||||
.context("CreateEventW failed")?;
|
||||
let event = CreateEventW(None, false, false, None).context("CreateEventW failed")?;
|
||||
audio_client
|
||||
.SetEventHandle(event)
|
||||
.context("SetEventHandle failed")?;
|
||||
@@ -285,10 +285,8 @@ unsafe fn capture_thread_main(
|
||||
// Because we asked for 48 kHz mono i16, each frame is
|
||||
// exactly one i16. Windows's AUTOCONVERTPCM handles the
|
||||
// conversion from whatever the engine mix format is.
|
||||
let samples = std::slice::from_raw_parts(
|
||||
buffer_ptr as *const i16,
|
||||
num_frames as usize,
|
||||
);
|
||||
let samples =
|
||||
std::slice::from_raw_parts(buffer_ptr as *const i16, num_frames as usize);
|
||||
ring.write(samples);
|
||||
}
|
||||
|
||||
|
||||
@@ -6,8 +6,8 @@ use std::time::{Duration, Instant};
|
||||
|
||||
use wzp_crypto::ChaChaSession;
|
||||
use wzp_fec::{RaptorQFecDecoder, RaptorQFecEncoder};
|
||||
use wzp_proto::traits::{CryptoSession, FecDecoder, FecEncoder};
|
||||
use wzp_proto::QualityProfile;
|
||||
use wzp_proto::traits::{CryptoSession, FecDecoder, FecEncoder};
|
||||
|
||||
use crate::call::{CallConfig, CallDecoder, CallEncoder};
|
||||
|
||||
@@ -151,7 +151,7 @@ pub fn bench_fec_recovery(loss_pct: f32) -> FecResult {
|
||||
let mut total_repair_bytes = 0usize;
|
||||
|
||||
for block_idx in 0..num_blocks {
|
||||
let block_id = (block_idx % 256) as u8;
|
||||
let block_id = (block_idx % 65536) as u16;
|
||||
|
||||
// Create fresh encoder and decoder for each block
|
||||
let mut fec_enc = RaptorQFecEncoder::new(frames_per_block, 256);
|
||||
@@ -170,7 +170,7 @@ pub fn bench_fec_recovery(loss_pct: f32) -> FecResult {
|
||||
|
||||
// Collect all symbols: source + repair
|
||||
struct Symbol {
|
||||
index: u8,
|
||||
index: u16,
|
||||
is_repair: bool,
|
||||
data: Vec<u8>,
|
||||
}
|
||||
@@ -180,7 +180,7 @@ pub fn bench_fec_recovery(loss_pct: f32) -> FecResult {
|
||||
// For add_symbol we need to provide the raw data; the decoder pads internally
|
||||
total_source_bytes += sym.len();
|
||||
all_symbols.push(Symbol {
|
||||
index: i as u8,
|
||||
index: i as u16,
|
||||
is_repair: false,
|
||||
data: sym.clone(),
|
||||
});
|
||||
@@ -201,9 +201,13 @@ pub fn bench_fec_recovery(loss_pct: f32) -> FecResult {
|
||||
// Deterministic shuffle for reproducibility using a simple seed
|
||||
// We use a basic Fisher-Yates with a fixed-per-block seed
|
||||
let mut indices: Vec<usize> = (0..all_symbols.len()).collect();
|
||||
let mut seed = (block_idx as u64).wrapping_mul(6364136223846793005).wrapping_add(1);
|
||||
let mut seed = (block_idx as u64)
|
||||
.wrapping_mul(6364136223846793005)
|
||||
.wrapping_add(1);
|
||||
for i in (1..indices.len()).rev() {
|
||||
seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
|
||||
seed = seed
|
||||
.wrapping_mul(6364136223846793005)
|
||||
.wrapping_add(1442695040888963407);
|
||||
let j = (seed >> 33) as usize % (i + 1);
|
||||
indices.swap(i, j);
|
||||
}
|
||||
@@ -259,17 +263,36 @@ pub fn bench_encrypt_decrypt() -> CryptoResult {
|
||||
})
|
||||
.collect();
|
||||
|
||||
let header = b"bench-header";
|
||||
// Build valid v2 MediaHeader bytes — encrypt/decrypt now derive nonces from
|
||||
// header.seq and require a parseable MediaHeader (WIRE_SIZE bytes minimum).
|
||||
use wzp_proto::packet::MediaHeader;
|
||||
use wzp_proto::{CodecId, MediaType};
|
||||
let mut total_bytes: usize = 0;
|
||||
|
||||
let start = Instant::now();
|
||||
for payload in &payloads {
|
||||
for (i, payload) in payloads.iter().enumerate() {
|
||||
let hdr = MediaHeader {
|
||||
version: 2,
|
||||
flags: 0,
|
||||
media_type: MediaType::Audio,
|
||||
codec_id: CodecId::Opus24k,
|
||||
stream_id: 0,
|
||||
fec_ratio: 0,
|
||||
seq: i as u32,
|
||||
timestamp: (i as u32).wrapping_mul(20),
|
||||
fec_block: 0,
|
||||
};
|
||||
let mut header_bytes = Vec::with_capacity(MediaHeader::WIRE_SIZE);
|
||||
hdr.write_to(&mut header_bytes);
|
||||
|
||||
let mut ciphertext = Vec::with_capacity(payload.len() + 16);
|
||||
encryptor.encrypt(header, payload, &mut ciphertext).unwrap();
|
||||
encryptor
|
||||
.encrypt(&header_bytes, payload, &mut ciphertext)
|
||||
.unwrap();
|
||||
|
||||
let mut plaintext = Vec::with_capacity(payload.len());
|
||||
decryptor
|
||||
.decrypt(header, &ciphertext, &mut plaintext)
|
||||
.decrypt(&header_bytes, &ciphertext, &mut plaintext)
|
||||
.unwrap();
|
||||
|
||||
total_bytes += payload.len();
|
||||
|
||||
@@ -24,8 +24,14 @@ fn run_codec() {
|
||||
print_header("Codec Roundtrip (Opus 24kbps)");
|
||||
let r = bench::bench_codec_roundtrip();
|
||||
print_row("Frames", &format!("{}", r.frames));
|
||||
print_row("Encode total", &format!("{:.2} ms", r.total_encode.as_secs_f64() * 1000.0));
|
||||
print_row("Decode total", &format!("{:.2} ms", r.total_decode.as_secs_f64() * 1000.0));
|
||||
print_row(
|
||||
"Encode total",
|
||||
&format!("{:.2} ms", r.total_encode.as_secs_f64() * 1000.0),
|
||||
);
|
||||
print_row(
|
||||
"Decode total",
|
||||
&format!("{:.2} ms", r.total_decode.as_secs_f64() * 1000.0),
|
||||
);
|
||||
print_row("Avg encode", &format!("{:.1} us", r.avg_encode_us));
|
||||
print_row("Avg decode", &format!("{:.1} us", r.avg_decode_us));
|
||||
print_row("Throughput", &format!("{:.0} frames/sec", r.frames_per_sec));
|
||||
@@ -41,7 +47,10 @@ fn run_fec(loss_pct: f32) {
|
||||
print_row("Recovery rate", &format!("{:.1}%", r.recovery_rate_pct));
|
||||
print_row("Source bytes", &format!("{}", r.total_source_bytes));
|
||||
print_row("Repair (overhead) bytes", &format!("{}", r.overhead_bytes));
|
||||
print_row("Total time", &format!("{:.2} ms", r.total_time.as_secs_f64() * 1000.0));
|
||||
print_row(
|
||||
"Total time",
|
||||
&format!("{:.2} ms", r.total_time.as_secs_f64() * 1000.0),
|
||||
);
|
||||
print_footer();
|
||||
}
|
||||
|
||||
@@ -49,7 +58,10 @@ fn run_crypto() {
|
||||
print_header("Crypto (ChaCha20-Poly1305)");
|
||||
let r = bench::bench_encrypt_decrypt();
|
||||
print_row("Packets", &format!("{}", r.packets));
|
||||
print_row("Total time", &format!("{:.2} ms", r.total_time.as_secs_f64() * 1000.0));
|
||||
print_row(
|
||||
"Total time",
|
||||
&format!("{:.2} ms", r.total_time.as_secs_f64() * 1000.0),
|
||||
);
|
||||
print_row("Throughput", &format!("{:.0} pkt/sec", r.packets_per_sec));
|
||||
print_row("Bandwidth", &format!("{:.2} MB/sec", r.megabytes_per_sec));
|
||||
print_row("Avg latency", &format!("{:.2} us", r.avg_latency_us));
|
||||
@@ -60,9 +72,18 @@ fn run_pipeline() {
|
||||
print_header("Full Pipeline (E2E)");
|
||||
let r = bench::bench_full_pipeline();
|
||||
print_row("Frames", &format!("{}", r.frames));
|
||||
print_row("Encode pipeline", &format!("{:.2} ms", r.total_encode_pipeline.as_secs_f64() * 1000.0));
|
||||
print_row("Decode pipeline", &format!("{:.2} ms", r.total_decode_pipeline.as_secs_f64() * 1000.0));
|
||||
print_row("Avg E2E latency", &format!("{:.1} us/frame", r.avg_e2e_latency_us));
|
||||
print_row(
|
||||
"Encode pipeline",
|
||||
&format!("{:.2} ms", r.total_encode_pipeline.as_secs_f64() * 1000.0),
|
||||
);
|
||||
print_row(
|
||||
"Decode pipeline",
|
||||
&format!("{:.2} ms", r.total_decode_pipeline.as_secs_f64() * 1000.0),
|
||||
);
|
||||
print_row(
|
||||
"Avg E2E latency",
|
||||
&format!("{:.1} us/frame", r.avg_e2e_latency_us),
|
||||
);
|
||||
print_row("PCM in", &format!("{} bytes", r.pcm_bytes_in));
|
||||
print_row("Wire out", &format!("{} bytes", r.wire_bytes_out));
|
||||
print_row("Overhead ratio", &format!("{:.3}x", r.overhead_ratio));
|
||||
|
||||
347
crates/wzp-client/src/birthday.rs
Normal file
347
crates/wzp-client/src/birthday.rs
Normal file
@@ -0,0 +1,347 @@
|
||||
//! Birthday attack for hard NAT traversal.
|
||||
//!
|
||||
//! When both peers are behind symmetric NATs with random port
|
||||
//! allocation, standard hole-punching fails because neither side
|
||||
//! can predict the other's external port. This module implements
|
||||
//! the birthday-paradox approach:
|
||||
//!
|
||||
//! 1. **Acceptor** opens N sockets, STUN-probes each to learn
|
||||
//! their external ports, reports them to the Dialer.
|
||||
//! 2. **Dialer** sprays QUIC connect attempts to the Acceptor's
|
||||
//! reported ports + random ports on the Acceptor's IP.
|
||||
//! 3. Birthday paradox: with N=64 ports and M=256 probes across
|
||||
//! 65536 ports, collision probability is high.
|
||||
//!
|
||||
//! In practice, the Acceptor's STUN-probed ports are known
|
||||
//! exactly (not random), so the Dialer targets them first —
|
||||
//! making this more like "spray-and-pray with a hit list" than
|
||||
//! a pure birthday attack.
|
||||
|
||||
use std::net::{Ipv4Addr, SocketAddr};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use crate::stun;
|
||||
|
||||
/// Configuration for the birthday attack.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BirthdayConfig {
|
||||
/// Number of sockets the Acceptor opens (default: 32).
|
||||
/// Each socket gets STUN-probed to learn its external port.
|
||||
/// More = higher chance of collision, but more resource usage.
|
||||
pub acceptor_ports: u16,
|
||||
/// Number of QUIC connect attempts the Dialer makes (default: 128).
|
||||
/// Spread across the Acceptor's known ports + random ports.
|
||||
pub dialer_probes: u16,
|
||||
/// Rate limit: ms between consecutive probes (default: 20ms = 50/s).
|
||||
pub probe_interval_ms: u16,
|
||||
/// Overall timeout for the birthday attack phase.
|
||||
pub timeout: Duration,
|
||||
/// STUN config for probing external ports.
|
||||
pub stun_config: stun::StunConfig,
|
||||
}
|
||||
|
||||
impl Default for BirthdayConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
acceptor_ports: 32,
|
||||
dialer_probes: 128,
|
||||
probe_interval_ms: 20,
|
||||
timeout: Duration::from_secs(8),
|
||||
stun_config: stun::StunConfig {
|
||||
servers: vec!["stun.l.google.com:19302".into()],
|
||||
timeout: Duration::from_secs(2),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of the Acceptor's port-opening phase.
|
||||
#[derive(Debug, Clone, serde::Serialize)]
|
||||
pub struct AcceptorPorts {
|
||||
/// External IP (from STUN).
|
||||
pub external_ip: Option<Ipv4Addr>,
|
||||
/// List of (local_port, external_port) for each opened socket.
|
||||
pub ports: Vec<PortMapping>,
|
||||
/// How many sockets we attempted to open.
|
||||
pub attempted: u16,
|
||||
/// How many STUN probes succeeded.
|
||||
pub succeeded: u16,
|
||||
}
|
||||
|
||||
/// A single socket's local↔external port mapping.
|
||||
#[derive(Debug, Clone, serde::Serialize)]
|
||||
pub struct PortMapping {
|
||||
pub local_port: u16,
|
||||
pub external_port: u16,
|
||||
}
|
||||
|
||||
/// Open N sockets and STUN-probe each to discover external ports.
|
||||
///
|
||||
/// Returns the set of known external ports that the Dialer should
|
||||
/// target. Each socket stays open (bound) so the NAT mapping
|
||||
/// remains active until the returned `PortGuard` is dropped.
|
||||
///
|
||||
/// The sockets are returned so the caller can keep them alive
|
||||
/// during the attack. Dropping them closes the NAT pinholes.
|
||||
pub async fn open_acceptor_ports(
|
||||
config: &BirthdayConfig,
|
||||
) -> (AcceptorPorts, Vec<tokio::net::UdpSocket>) {
|
||||
let mut sockets = Vec::new();
|
||||
let mut mappings = Vec::new();
|
||||
let mut external_ip: Option<Ipv4Addr> = None;
|
||||
let mut succeeded: u16 = 0;
|
||||
|
||||
let stun_server = match config.stun_config.servers.first() {
|
||||
Some(s) => match stun::resolve_stun_server(s).await {
|
||||
Ok(a) => Some(a),
|
||||
Err(_) => None,
|
||||
},
|
||||
None => None,
|
||||
};
|
||||
|
||||
for _ in 0..config.acceptor_ports {
|
||||
// Bind to random port
|
||||
let sock = match tokio::net::UdpSocket::bind("0.0.0.0:0").await {
|
||||
Ok(s) => s,
|
||||
Err(_) => continue,
|
||||
};
|
||||
let local_port = match sock.local_addr() {
|
||||
Ok(a) => a.port(),
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
// STUN probe to learn external port
|
||||
if let Some(stun_addr) = stun_server {
|
||||
match stun::stun_reflect(&sock, stun_addr, config.stun_config.timeout).await {
|
||||
Ok(ext_addr) => {
|
||||
if external_ip.is_none() {
|
||||
if let std::net::IpAddr::V4(ip) = ext_addr.ip() {
|
||||
external_ip = Some(ip);
|
||||
}
|
||||
}
|
||||
mappings.push(PortMapping {
|
||||
local_port,
|
||||
external_port: ext_addr.port(),
|
||||
});
|
||||
succeeded += 1;
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::debug!(local_port, error = %e, "birthday: STUN probe failed for socket");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sockets.push(sock);
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
attempted = config.acceptor_ports,
|
||||
succeeded,
|
||||
external_ip = ?external_ip,
|
||||
"birthday: acceptor ports opened"
|
||||
);
|
||||
|
||||
let result = AcceptorPorts {
|
||||
external_ip,
|
||||
ports: mappings,
|
||||
attempted: config.acceptor_ports,
|
||||
succeeded,
|
||||
};
|
||||
|
||||
(result, sockets)
|
||||
}
|
||||
|
||||
/// Generate the list of target addresses for the Dialer to spray.
|
||||
///
|
||||
/// Priority order:
|
||||
/// 1. Acceptor's known external ports (from STUN probes) — highest hit rate
|
||||
/// 2. Random ports on the Acceptor's IP — birthday paradox fill
|
||||
pub fn generate_dialer_targets(
|
||||
acceptor_ip: Ipv4Addr,
|
||||
known_ports: &[u16],
|
||||
total_probes: u16,
|
||||
) -> Vec<SocketAddr> {
|
||||
let mut targets = Vec::with_capacity(total_probes as usize);
|
||||
|
||||
// First: all known ports (guaranteed targets)
|
||||
for &port in known_ports {
|
||||
targets.push(SocketAddr::new(std::net::IpAddr::V4(acceptor_ip), port));
|
||||
}
|
||||
|
||||
// Fill remaining with random ports (birthday attack)
|
||||
let remaining = total_probes.saturating_sub(known_ports.len() as u16);
|
||||
if remaining > 0 {
|
||||
use rand::Rng;
|
||||
let mut rng = rand::thread_rng();
|
||||
for _ in 0..remaining {
|
||||
let port = rng.gen_range(1024..=65535u16);
|
||||
let addr = SocketAddr::new(std::net::IpAddr::V4(acceptor_ip), port);
|
||||
if !targets.contains(&addr) {
|
||||
targets.push(addr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
targets
|
||||
}
|
||||
|
||||
/// Run the Dialer side of the birthday attack.
|
||||
///
|
||||
/// Sprays QUIC connection attempts at the target addresses.
|
||||
/// Returns the first successful connection, or None on timeout.
|
||||
pub async fn spray_dialer(
|
||||
endpoint: &wzp_transport::Endpoint,
|
||||
targets: &[SocketAddr],
|
||||
call_sni: &str,
|
||||
probe_interval: Duration,
|
||||
timeout: Duration,
|
||||
) -> Option<wzp_transport::QuinnTransport> {
|
||||
let start = Instant::now();
|
||||
let mut set = tokio::task::JoinSet::new();
|
||||
|
||||
tracing::info!(
|
||||
target_count = targets.len(),
|
||||
interval_ms = probe_interval.as_millis(),
|
||||
timeout_s = timeout.as_secs(),
|
||||
"birthday: dialer starting spray"
|
||||
);
|
||||
|
||||
// Spray connects with rate limiting
|
||||
for (idx, &target) in targets.iter().enumerate() {
|
||||
if start.elapsed() >= timeout {
|
||||
break;
|
||||
}
|
||||
|
||||
let ep = endpoint.clone();
|
||||
let sni = call_sni.to_string();
|
||||
let client_cfg = wzp_transport::client_config();
|
||||
set.spawn(async move {
|
||||
let result = wzp_transport::connect(&ep, target, &sni, client_cfg).await;
|
||||
(idx, target, result)
|
||||
});
|
||||
|
||||
// Rate limit — don't blast the NAT
|
||||
if idx < targets.len() - 1 {
|
||||
tokio::time::sleep(probe_interval).await;
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
spawned = set.len(),
|
||||
elapsed_ms = start.elapsed().as_millis(),
|
||||
"birthday: all probes spawned, waiting for first success"
|
||||
);
|
||||
|
||||
// Wait for first success or all failures
|
||||
let deadline = start + timeout;
|
||||
while let Some(join_res) = tokio::select! {
|
||||
r = set.join_next() => r,
|
||||
_ = tokio::time::sleep_until(tokio::time::Instant::from_std(deadline)) => None,
|
||||
} {
|
||||
match join_res {
|
||||
Ok((idx, target, Ok(conn))) => {
|
||||
tracing::info!(
|
||||
idx,
|
||||
%target,
|
||||
remote = %conn.remote_address(),
|
||||
elapsed_ms = start.elapsed().as_millis(),
|
||||
"birthday: HIT! QUIC handshake succeeded"
|
||||
);
|
||||
set.abort_all();
|
||||
return Some(wzp_transport::QuinnTransport::new(conn));
|
||||
}
|
||||
Ok((idx, target, Err(e))) => {
|
||||
tracing::debug!(
|
||||
idx,
|
||||
%target,
|
||||
error = %e,
|
||||
"birthday: probe failed"
|
||||
);
|
||||
}
|
||||
Err(_) => {}
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
elapsed_ms = start.elapsed().as_millis(),
|
||||
"birthday: all probes failed or timed out"
|
||||
);
|
||||
None
|
||||
}
|
||||
|
||||
// ── Tests ──────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn generate_targets_known_ports_first() {
|
||||
let ip = Ipv4Addr::new(203, 0, 113, 5);
|
||||
let known = vec![10000, 10001, 10002];
|
||||
let targets = generate_dialer_targets(ip, &known, 10);
|
||||
|
||||
// Known ports should be first
|
||||
assert_eq!(targets[0].port(), 10000);
|
||||
assert_eq!(targets[1].port(), 10001);
|
||||
assert_eq!(targets[2].port(), 10002);
|
||||
// Rest are random
|
||||
assert!(targets.len() <= 10);
|
||||
// All target the right IP
|
||||
assert!(targets.iter().all(|a| a.ip() == std::net::IpAddr::V4(ip)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generate_targets_no_known_all_random() {
|
||||
let ip = Ipv4Addr::new(10, 0, 0, 1);
|
||||
let targets = generate_dialer_targets(ip, &[], 50);
|
||||
assert!(!targets.is_empty());
|
||||
assert!(targets.len() <= 50);
|
||||
// All ports in valid range
|
||||
assert!(targets.iter().all(|a| a.port() >= 1024));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generate_targets_more_known_than_total() {
|
||||
let ip = Ipv4Addr::new(10, 0, 0, 1);
|
||||
let known: Vec<u16> = (10000..10100).collect();
|
||||
let targets = generate_dialer_targets(ip, &known, 50);
|
||||
// All 100 known ports included even though total=50
|
||||
assert_eq!(targets.len(), 100);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generate_targets_dedup() {
|
||||
let ip = Ipv4Addr::new(10, 0, 0, 1);
|
||||
let targets = generate_dialer_targets(ip, &[], 100);
|
||||
// No duplicates
|
||||
let mut sorted = targets.clone();
|
||||
sorted.sort();
|
||||
sorted.dedup();
|
||||
assert_eq!(sorted.len(), targets.len());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_config() {
|
||||
let cfg = BirthdayConfig::default();
|
||||
assert_eq!(cfg.acceptor_ports, 32);
|
||||
assert_eq!(cfg.dialer_probes, 128);
|
||||
assert!(cfg.timeout.as_secs() > 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn acceptor_ports_serializes() {
|
||||
let result = AcceptorPorts {
|
||||
external_ip: Some(Ipv4Addr::new(203, 0, 113, 5)),
|
||||
ports: vec![PortMapping {
|
||||
local_port: 12345,
|
||||
external_port: 54321,
|
||||
}],
|
||||
attempted: 32,
|
||||
succeeded: 1,
|
||||
};
|
||||
let json = serde_json::to_string(&result).unwrap();
|
||||
assert!(json.contains("54321"));
|
||||
assert!(json.contains("203.0.113.5"));
|
||||
}
|
||||
}
|
||||
@@ -13,11 +13,11 @@ use wzp_codec::{
|
||||
};
|
||||
use wzp_fec::{RaptorQFecDecoder, RaptorQFecEncoder};
|
||||
use wzp_proto::jitter::{JitterBuffer, PlayoutResult};
|
||||
use wzp_proto::packet::QualityReport;
|
||||
use wzp_proto::packet::{MediaHeader, MediaPacket, MiniFrameContext};
|
||||
use wzp_proto::quality::AdaptiveQualityController;
|
||||
use wzp_proto::traits::{AudioDecoder, AudioEncoder, FecDecoder, FecEncoder};
|
||||
use wzp_proto::packet::QualityReport;
|
||||
use wzp_proto::{CodecId, QualityProfile};
|
||||
use wzp_proto::{CodecId, MediaType, QualityProfile};
|
||||
|
||||
/// Configuration for a call session.
|
||||
pub struct CallConfig {
|
||||
@@ -205,7 +205,7 @@ pub struct CallEncoder {
|
||||
/// Current profile.
|
||||
profile: QualityProfile,
|
||||
/// Outbound sequence counter.
|
||||
seq: u16,
|
||||
seq: u32,
|
||||
/// Current FEC block.
|
||||
block_id: u8,
|
||||
/// Frame index within current block.
|
||||
@@ -234,6 +234,8 @@ pub struct CallEncoder {
|
||||
mini_frames_enabled: bool,
|
||||
/// Frames encoded since the last full header was emitted.
|
||||
frames_since_full: u32,
|
||||
/// Pending quality report to attach to the next source packet.
|
||||
pending_quality_report: Option<QualityReport>,
|
||||
}
|
||||
|
||||
impl CallEncoder {
|
||||
@@ -264,6 +266,7 @@ impl CallEncoder {
|
||||
mini_context: MiniFrameContext::default(),
|
||||
mini_frames_enabled: config.mini_frames_enabled,
|
||||
frames_since_full: 0,
|
||||
pending_quality_report: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -315,17 +318,15 @@ impl CallEncoder {
|
||||
if self.cn_counter % 10 == 0 {
|
||||
let cn_pkt = MediaPacket {
|
||||
header: MediaHeader {
|
||||
version: 0,
|
||||
is_repair: false,
|
||||
version: 2,
|
||||
flags: 0,
|
||||
media_type: MediaType::Audio,
|
||||
codec_id: CodecId::ComfortNoise,
|
||||
has_quality_report: false,
|
||||
fec_ratio_encoded: 0,
|
||||
stream_id: 0,
|
||||
fec_ratio: 0,
|
||||
seq: self.seq,
|
||||
timestamp: self.timestamp_ms,
|
||||
fec_block: self.block_id,
|
||||
fec_symbol: 0,
|
||||
reserved: 0,
|
||||
csrc_count: 0,
|
||||
fec_block: u16::from(self.block_id),
|
||||
},
|
||||
payload: Bytes::from(vec![self.cn_level as u8]),
|
||||
quality_report: None,
|
||||
@@ -351,33 +352,34 @@ impl CallEncoder {
|
||||
// can cleanly identify "no RaptorQ block to assemble" and new
|
||||
// receivers can short-circuit their FEC ingest path.
|
||||
let is_opus = self.profile.codec.is_opus();
|
||||
let (fec_block, fec_symbol, fec_ratio_encoded) = if is_opus {
|
||||
(0u8, 0u8, 0u8)
|
||||
let (fec_block, fec_ratio) = if is_opus {
|
||||
(0u16, 0u8)
|
||||
} else {
|
||||
(
|
||||
self.block_id,
|
||||
self.frame_in_block,
|
||||
u16::from(self.block_id) | (u16::from(self.frame_in_block) << 8),
|
||||
MediaHeader::encode_fec_ratio(self.profile.fec_ratio),
|
||||
)
|
||||
};
|
||||
|
||||
// Build source media packet
|
||||
let mut flags = 0u8;
|
||||
if self.pending_quality_report.is_some() {
|
||||
flags |= MediaHeader::FLAG_QUALITY;
|
||||
}
|
||||
let source_pkt = MediaPacket {
|
||||
header: MediaHeader {
|
||||
version: 0,
|
||||
is_repair: false,
|
||||
version: 2,
|
||||
flags,
|
||||
media_type: MediaType::Audio,
|
||||
codec_id: self.profile.codec,
|
||||
has_quality_report: false,
|
||||
fec_ratio_encoded,
|
||||
stream_id: 0,
|
||||
fec_ratio,
|
||||
seq: self.seq,
|
||||
timestamp: self.timestamp_ms,
|
||||
fec_block,
|
||||
fec_symbol,
|
||||
reserved: 0,
|
||||
csrc_count: 0,
|
||||
},
|
||||
payload: Bytes::from(encoded.clone()),
|
||||
quality_report: None,
|
||||
quality_report: self.pending_quality_report.take(),
|
||||
};
|
||||
|
||||
self.seq = self.seq.wrapping_add(1);
|
||||
@@ -399,19 +401,15 @@ impl CallEncoder {
|
||||
for (sym_idx, repair_data) in repairs {
|
||||
output.push(MediaPacket {
|
||||
header: MediaHeader {
|
||||
version: 0,
|
||||
is_repair: true,
|
||||
version: 2,
|
||||
flags: MediaHeader::FLAG_REPAIR,
|
||||
media_type: MediaType::Audio,
|
||||
codec_id: self.profile.codec,
|
||||
has_quality_report: false,
|
||||
fec_ratio_encoded: MediaHeader::encode_fec_ratio(
|
||||
self.profile.fec_ratio,
|
||||
),
|
||||
stream_id: 0,
|
||||
fec_ratio: MediaHeader::encode_fec_ratio(self.profile.fec_ratio),
|
||||
seq: self.seq,
|
||||
timestamp: self.timestamp_ms,
|
||||
fec_block: self.block_id,
|
||||
fec_symbol: sym_idx,
|
||||
reserved: 0,
|
||||
csrc_count: 0,
|
||||
fec_block: u16::from(self.block_id) | (sym_idx << 8),
|
||||
},
|
||||
payload: Bytes::from(repair_data),
|
||||
quality_report: None,
|
||||
@@ -454,6 +452,13 @@ impl CallEncoder {
|
||||
self.audio_enc.set_expected_loss(tuning.expected_loss_pct);
|
||||
}
|
||||
|
||||
/// Queue a quality report for attachment to the next source packet.
|
||||
/// Used by the send task to embed locally-observed path quality so
|
||||
/// the peer can drive adaptive quality switching.
|
||||
pub fn set_pending_quality_report(&mut self, report: QualityReport) {
|
||||
self.pending_quality_report = Some(report);
|
||||
}
|
||||
|
||||
/// Enable or disable acoustic echo cancellation.
|
||||
pub fn set_aec_enabled(&mut self, enabled: bool) {
|
||||
self.aec.set_enabled(enabled);
|
||||
@@ -498,7 +503,7 @@ pub struct CallDecoder {
|
||||
last_good_dred: DredState,
|
||||
/// Sequence number of the packet that produced `last_good_dred`. `None`
|
||||
/// if no packet has yielded DRED state yet (cold start or legacy sender).
|
||||
last_good_dred_seq: Option<u16>,
|
||||
last_good_dred_seq: Option<u32>,
|
||||
/// Phase 4 telemetry counter: gaps recovered via DRED reconstruction.
|
||||
pub dred_reconstructions: u64,
|
||||
/// Phase 4 telemetry counter: gaps filled via classical Opus PLC
|
||||
@@ -561,8 +566,8 @@ impl CallDecoder {
|
||||
if !packet.header.codec_id.is_opus() {
|
||||
let _ = self.fec_dec.add_symbol(
|
||||
packet.header.fec_block,
|
||||
packet.header.fec_symbol,
|
||||
packet.header.is_repair,
|
||||
packet.header.fec_block >> 8,
|
||||
packet.header.is_repair(),
|
||||
&packet.payload,
|
||||
);
|
||||
}
|
||||
@@ -572,7 +577,7 @@ impl CallDecoder {
|
||||
// swap with the cached `last_good_dred` so later gap reconstruction
|
||||
// has fresh neural redundancy to draw from. Parsing happens before
|
||||
// the jitter push because the jitter buffer consumes the packet.
|
||||
if packet.header.codec_id.is_opus() && !packet.header.is_repair {
|
||||
if packet.header.codec_id.is_opus() && !packet.header.is_repair() {
|
||||
match self
|
||||
.dred_decoder
|
||||
.parse_into(&mut self.dred_parse_scratch, &packet.payload)
|
||||
@@ -601,7 +606,7 @@ impl CallDecoder {
|
||||
// Source packets (Opus or Codec2) go to the jitter buffer for decode.
|
||||
// Repair packets never reach the jitter buffer; for Codec2 they're
|
||||
// used by the FEC decoder above, for Opus they're dropped here.
|
||||
if !packet.header.is_repair {
|
||||
if !packet.header.is_repair() {
|
||||
self.jitter.push(packet);
|
||||
}
|
||||
}
|
||||
@@ -636,6 +641,7 @@ impl CallDecoder {
|
||||
fec_ratio: 0.3,
|
||||
frame_duration_ms: 20,
|
||||
frames_per_block: 5,
|
||||
..QualityProfile::GOOD
|
||||
},
|
||||
CodecId::Opus6k => QualityProfile::DEGRADED,
|
||||
CodecId::Opus32k => QualityProfile::STUDIO_32K,
|
||||
@@ -646,9 +652,13 @@ impl CallDecoder {
|
||||
fec_ratio: 0.5,
|
||||
frame_duration_ms: 20,
|
||||
frames_per_block: 5,
|
||||
..QualityProfile::GOOD
|
||||
},
|
||||
CodecId::Codec2_1200 => QualityProfile::CATASTROPHIC,
|
||||
CodecId::ComfortNoise => QualityProfile::GOOD,
|
||||
CodecId::H264Baseline | CodecId::H265Main | CodecId::Av1Main => {
|
||||
panic!("video codec passed to audio decoder")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -701,12 +711,12 @@ impl CallDecoder {
|
||||
if let Some(last_seq) = self.last_good_dred_seq {
|
||||
// How many frames ahead of the missing seq is the
|
||||
// last-good packet? Use wrapping arithmetic for the
|
||||
// u16 seq space.
|
||||
// u32 seq space.
|
||||
let seq_delta = last_seq.wrapping_sub(seq);
|
||||
// Reject stale or backward state. u16 wraparound
|
||||
// Reject stale or backward state. u32 wraparound
|
||||
// would make a "seq went backward" delta very large;
|
||||
// cap at a sane forward-looking window.
|
||||
const MAX_SEQ_DELTA: u16 = 128;
|
||||
const MAX_SEQ_DELTA: u32 = 128;
|
||||
if seq_delta > 0 && seq_delta <= MAX_SEQ_DELTA {
|
||||
let frame_samples =
|
||||
(48_000 * self.profile.frame_duration_ms as i32) / 1000;
|
||||
@@ -775,7 +785,7 @@ impl CallDecoder {
|
||||
/// Phase 3b introspection: sequence number of the most recently parsed
|
||||
/// valid DRED state, or `None` if no Opus packet has yielded DRED data
|
||||
/// yet. Used by tests to debug reconstruction eligibility.
|
||||
pub fn last_good_dred_seq(&self) -> Option<u16> {
|
||||
pub fn last_good_dred_seq(&self) -> Option<u32> {
|
||||
self.last_good_dred_seq
|
||||
}
|
||||
|
||||
@@ -842,7 +852,7 @@ mod tests {
|
||||
let packets = enc.encode_frame(&pcm).unwrap();
|
||||
assert!(!packets.is_empty());
|
||||
assert_eq!(packets[0].header.seq, 0);
|
||||
assert!(!packets[0].header.is_repair);
|
||||
assert!(!packets[0].header.is_repair());
|
||||
}
|
||||
|
||||
/// Phase 2: Opus packets have zero FEC header fields — no block, no
|
||||
@@ -865,10 +875,9 @@ mod tests {
|
||||
assert_eq!(packets.len(), 1, "Opus must emit exactly 1 source packet");
|
||||
let hdr = &packets[0].header;
|
||||
assert!(hdr.codec_id.is_opus());
|
||||
assert!(!hdr.is_repair);
|
||||
assert!(!hdr.is_repair());
|
||||
assert_eq!(hdr.fec_block, 0, "Opus fec_block must be 0");
|
||||
assert_eq!(hdr.fec_symbol, 0, "Opus fec_symbol must be 0");
|
||||
assert_eq!(hdr.fec_ratio_encoded, 0, "Opus fec_ratio_encoded must be 0");
|
||||
assert_eq!(hdr.fec_ratio, 0, "Opus fec_ratio must be 0");
|
||||
}
|
||||
|
||||
/// Phase 2: Opus never emits repair packets, regardless of how many
|
||||
@@ -892,7 +901,7 @@ mod tests {
|
||||
for _ in 0..20 {
|
||||
let packets = enc.encode_frame(&pcm).unwrap();
|
||||
total_packets += packets.len();
|
||||
repair_count += packets.iter().filter(|p| p.header.is_repair).count();
|
||||
repair_count += packets.iter().filter(|p| p.header.is_repair()).count();
|
||||
}
|
||||
assert_eq!(repair_count, 0, "Opus must emit zero repair packets");
|
||||
assert_eq!(
|
||||
@@ -924,7 +933,7 @@ mod tests {
|
||||
for _ in 0..16 {
|
||||
let packets = enc.encode_frame(&pcm).unwrap();
|
||||
for p in &packets {
|
||||
if p.header.is_repair {
|
||||
if p.header.is_repair() {
|
||||
repair_count += 1;
|
||||
}
|
||||
}
|
||||
@@ -943,17 +952,15 @@ mod tests {
|
||||
|
||||
let pkt = MediaPacket {
|
||||
header: MediaHeader {
|
||||
version: 0,
|
||||
is_repair: false,
|
||||
version: 2,
|
||||
flags: 0,
|
||||
media_type: MediaType::Audio,
|
||||
codec_id: CodecId::Opus24k,
|
||||
has_quality_report: false,
|
||||
fec_ratio_encoded: 0,
|
||||
stream_id: 0,
|
||||
fec_ratio: 0,
|
||||
seq: 0,
|
||||
timestamp: 0,
|
||||
fec_block: 0,
|
||||
fec_symbol: 0,
|
||||
reserved: 0,
|
||||
csrc_count: 0,
|
||||
},
|
||||
payload: Bytes::from(vec![0u8; 60]),
|
||||
quality_report: None,
|
||||
@@ -1015,17 +1022,15 @@ mod tests {
|
||||
encoded.truncate(n);
|
||||
let pkt = MediaPacket {
|
||||
header: MediaHeader {
|
||||
version: 0,
|
||||
is_repair: false,
|
||||
version: 2,
|
||||
flags: 0,
|
||||
media_type: MediaType::Audio,
|
||||
codec_id: CodecId::Opus24k,
|
||||
has_quality_report: false,
|
||||
fec_ratio_encoded: 0,
|
||||
seq: i,
|
||||
stream_id: 0,
|
||||
fec_ratio: 0,
|
||||
seq: i as u32,
|
||||
timestamp: (i as u32) * 20,
|
||||
fec_block: 0,
|
||||
fec_symbol: 0,
|
||||
reserved: 0,
|
||||
csrc_count: 0,
|
||||
},
|
||||
payload: Bytes::from(encoded),
|
||||
quality_report: None,
|
||||
@@ -1095,9 +1100,7 @@ mod tests {
|
||||
|
||||
let dred_delta = dec.dred_reconstructions - baseline_dred;
|
||||
let plc_delta = dec.classical_plc_invocations - baseline_plc;
|
||||
eprintln!(
|
||||
"[phase3b probe] post-drain: dred_delta={dred_delta} plc_delta={plc_delta}"
|
||||
);
|
||||
eprintln!("[phase3b probe] post-drain: dred_delta={dred_delta} plc_delta={plc_delta}");
|
||||
assert!(
|
||||
dred_delta >= 1,
|
||||
"expected ≥1 DRED reconstruction on single-packet loss, \
|
||||
@@ -1158,7 +1161,7 @@ mod tests {
|
||||
let packets = enc.encode_frame(&pcm).unwrap();
|
||||
for pkt in packets {
|
||||
// Drop every 5th source packet to simulate loss.
|
||||
if !pkt.header.is_repair && i % 5 == 3 {
|
||||
if !pkt.header.is_repair() && i % 5 == 3 {
|
||||
continue;
|
||||
}
|
||||
dec.ingest(pkt);
|
||||
@@ -1312,20 +1315,18 @@ mod tests {
|
||||
|
||||
// ---- JitterStats telemetry tests ----
|
||||
|
||||
fn make_test_packet(seq: u16) -> MediaPacket {
|
||||
fn make_test_packet(seq: u32) -> MediaPacket {
|
||||
MediaPacket {
|
||||
header: MediaHeader {
|
||||
version: 0,
|
||||
is_repair: false,
|
||||
version: 2,
|
||||
flags: 0,
|
||||
media_type: MediaType::Audio,
|
||||
codec_id: CodecId::Opus24k,
|
||||
has_quality_report: false,
|
||||
fec_ratio_encoded: 0,
|
||||
stream_id: 0,
|
||||
fec_ratio: 0,
|
||||
seq,
|
||||
timestamp: seq as u32 * 20,
|
||||
timestamp: seq * 20,
|
||||
fec_block: 0,
|
||||
fec_symbol: seq as u8,
|
||||
reserved: 0,
|
||||
csrc_count: 0,
|
||||
},
|
||||
payload: Bytes::from(vec![0u8; 60]),
|
||||
quality_report: None,
|
||||
@@ -1337,7 +1338,7 @@ mod tests {
|
||||
let config = CallConfig::default();
|
||||
let mut dec = CallDecoder::new(&config);
|
||||
|
||||
for i in 0..5u16 {
|
||||
for i in 0..5u32 {
|
||||
dec.ingest(make_test_packet(i));
|
||||
}
|
||||
|
||||
@@ -1367,7 +1368,7 @@ mod tests {
|
||||
let mut dec = CallDecoder::new(&config);
|
||||
|
||||
// Generate some stats: ingest packets and trigger underruns on empty buffer
|
||||
for i in 0..3u16 {
|
||||
for i in 0..3u32 {
|
||||
dec.ingest(make_test_packet(i));
|
||||
}
|
||||
// Also call decode on empty decoder to get underruns
|
||||
@@ -1446,10 +1447,7 @@ mod tests {
|
||||
cn_packets >= 1,
|
||||
"should have at least one CN packet, got {cn_packets}"
|
||||
);
|
||||
assert!(
|
||||
enc.frames_suppressed > 0,
|
||||
"frames_suppressed should be > 0"
|
||||
);
|
||||
assert!(enc.frames_suppressed > 0, "frames_suppressed should be > 0");
|
||||
}
|
||||
|
||||
// ---- DredTuner integration tests ----
|
||||
@@ -1496,7 +1494,10 @@ mod tests {
|
||||
// Verify the encoder still works after tuning.
|
||||
let pcm = voice_frame_20ms(0);
|
||||
let packets = enc.encode_frame(&pcm).unwrap();
|
||||
assert!(!packets.is_empty(), "encoder must still produce packets after DRED tuning");
|
||||
assert!(
|
||||
!packets.is_empty(),
|
||||
"encoder must still produce packets after DRED tuning"
|
||||
);
|
||||
}
|
||||
|
||||
/// DredTuner jitter spike triggers pre-emptive DRED boost to ceiling.
|
||||
@@ -1514,11 +1515,15 @@ mod tests {
|
||||
|
||||
// Jitter spikes to 40ms (8x baseline of ~5ms).
|
||||
let tuning = tuner.update(0.0, 50, 40);
|
||||
assert!(tuner.spike_boost_active(), "jitter spike should activate boost");
|
||||
assert!(
|
||||
tuner.spike_boost_active(),
|
||||
"jitter spike should activate boost"
|
||||
);
|
||||
assert!(tuning.is_some());
|
||||
// Ceiling for Opus24k is 50 frames = 500 ms.
|
||||
assert_eq!(
|
||||
tuning.unwrap().dred_frames, 50,
|
||||
tuning.unwrap().dred_frames,
|
||||
50,
|
||||
"spike should push to ceiling"
|
||||
);
|
||||
}
|
||||
@@ -1578,4 +1583,89 @@ mod tests {
|
||||
let packets = enc.encode_frame(&pcm).unwrap();
|
||||
assert!(!packets.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encoder_attaches_quality_report() {
|
||||
let mut enc = CallEncoder::new(&CallConfig {
|
||||
profile: QualityProfile::GOOD,
|
||||
suppression_enabled: false,
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
// Set a quality report
|
||||
enc.set_pending_quality_report(QualityReport::from_path_stats(5.0, 80, 10));
|
||||
|
||||
// Encode a frame — should have quality_report attached
|
||||
let pcm = voice_frame_20ms(0);
|
||||
let packets = enc.encode_frame(&pcm).unwrap();
|
||||
assert!(!packets.is_empty());
|
||||
assert!(
|
||||
packets[0].header.has_quality(),
|
||||
"first packet should have quality report"
|
||||
);
|
||||
assert!(packets[0].quality_report.is_some());
|
||||
|
||||
// Next frame should NOT have quality_report (it was consumed)
|
||||
let packets2 = enc.encode_frame(&voice_frame_20ms(960)).unwrap();
|
||||
assert!(
|
||||
!packets2[0].header.has_quality(),
|
||||
"second packet should not have quality report"
|
||||
);
|
||||
assert!(packets2[0].quality_report.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quality_report_aead_tamper_fails_decrypt() {
|
||||
use wzp_crypto::ChaChaSession;
|
||||
use wzp_proto::CryptoSession;
|
||||
|
||||
// Build a packet with a QualityReport trailer.
|
||||
let pkt = MediaPacket {
|
||||
header: MediaHeader {
|
||||
version: 2,
|
||||
flags: MediaHeader::FLAG_QUALITY,
|
||||
media_type: MediaType::Audio,
|
||||
codec_id: CodecId::Opus24k,
|
||||
stream_id: 0,
|
||||
fec_ratio: 10,
|
||||
seq: 42,
|
||||
timestamp: 1000,
|
||||
fec_block: 0,
|
||||
},
|
||||
payload: Bytes::from(vec![0xAB; 60]),
|
||||
quality_report: Some(QualityReport::from_path_stats(5.0, 80, 10)),
|
||||
};
|
||||
|
||||
// Serialize: header || payload || quality_report
|
||||
let wire = pkt.to_bytes();
|
||||
assert_eq!(
|
||||
wire.len(),
|
||||
MediaHeader::WIRE_SIZE + pkt.payload.len() + QualityReport::WIRE_SIZE
|
||||
);
|
||||
|
||||
let header_bytes = &wire[..MediaHeader::WIRE_SIZE];
|
||||
let plaintext = &wire[MediaHeader::WIRE_SIZE..];
|
||||
|
||||
// Encrypt with ChaCha20-Poly1305 (header as AAD, payload+QR as plaintext).
|
||||
let mut alice = ChaChaSession::new([0xAA; 32]);
|
||||
let mut bob = ChaChaSession::new([0xAA; 32]);
|
||||
let mut ciphertext = Vec::new();
|
||||
alice
|
||||
.encrypt(header_bytes, plaintext, &mut ciphertext)
|
||||
.unwrap();
|
||||
|
||||
// Tamper with a byte in the QualityReport region (last 4 bytes of plaintext
|
||||
// → last 4 bytes of ciphertext for ChaCha20 stream cipher).
|
||||
let qr_offset_in_plaintext = plaintext.len() - QualityReport::WIRE_SIZE;
|
||||
let tamper_idx = qr_offset_in_plaintext;
|
||||
ciphertext[tamper_idx] ^= 0xFF;
|
||||
|
||||
// Decryption must fail because the AEAD tag no longer matches.
|
||||
let mut decrypted = Vec::new();
|
||||
let result = bob.decrypt(header_bytes, &ciphertext, &mut decrypted);
|
||||
assert!(
|
||||
result.is_err(),
|
||||
"tampering with QualityReport inside AEAD payload must cause decryption failure"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ use std::sync::Arc;
|
||||
use tracing::{error, info};
|
||||
|
||||
use wzp_client::call::{CallConfig, CallDecoder, CallEncoder};
|
||||
use wzp_proto::MediaTransport;
|
||||
use wzp_proto::{MediaTransport, default_signal_version};
|
||||
|
||||
const FRAME_SAMPLES: usize = 960; // 20ms @ 48kHz
|
||||
|
||||
@@ -52,6 +52,8 @@ struct CliArgs {
|
||||
signal: bool,
|
||||
/// Place a direct call to a fingerprint (requires --signal).
|
||||
call_target: Option<String>,
|
||||
/// Run network diagnostic (STUN, port mapping, relay latencies).
|
||||
netcheck: bool,
|
||||
}
|
||||
|
||||
impl CliArgs {
|
||||
@@ -97,6 +99,7 @@ fn parse_args() -> CliArgs {
|
||||
let mut relay_str = None;
|
||||
let mut signal = false;
|
||||
let mut call_target = None;
|
||||
let mut netcheck = false;
|
||||
|
||||
let mut i = 1;
|
||||
while i < args.len() {
|
||||
@@ -105,7 +108,11 @@ fn parse_args() -> CliArgs {
|
||||
"--signal" => signal = true,
|
||||
"--call" => {
|
||||
i += 1;
|
||||
call_target = Some(args.get(i).expect("--call requires a fingerprint").to_string());
|
||||
call_target = Some(
|
||||
args.get(i)
|
||||
.expect("--call requires a fingerprint")
|
||||
.to_string(),
|
||||
);
|
||||
}
|
||||
"--send-tone" => {
|
||||
i += 1;
|
||||
@@ -182,7 +189,12 @@ fn parse_args() -> CliArgs {
|
||||
);
|
||||
}
|
||||
"--sweep" => sweep = true,
|
||||
"--version-check" => { version_check = true; }
|
||||
"--netcheck" => {
|
||||
netcheck = true;
|
||||
}
|
||||
"--version-check" => {
|
||||
version_check = true;
|
||||
}
|
||||
"--help" | "-h" => {
|
||||
eprintln!("Usage: wzp-client [options] [relay-addr]");
|
||||
eprintln!();
|
||||
@@ -193,13 +205,19 @@ fn parse_args() -> CliArgs {
|
||||
eprintln!(" --record <file.raw> Record received audio to raw PCM file");
|
||||
eprintln!(" --echo-test <secs> Run automated echo quality test");
|
||||
eprintln!(" --drift-test <secs> Run automated clock-drift measurement");
|
||||
eprintln!(" --sweep Run jitter buffer parameter sweep (local, no network)");
|
||||
eprintln!(" --seed <hex> Identity seed (64 hex chars, featherChat compatible)");
|
||||
eprintln!(
|
||||
" --sweep Run jitter buffer parameter sweep (local, no network)"
|
||||
);
|
||||
eprintln!(
|
||||
" --seed <hex> Identity seed (64 hex chars, featherChat compatible)"
|
||||
);
|
||||
eprintln!(" --mnemonic <words...> Identity seed as BIP39 mnemonic (24 words)");
|
||||
eprintln!(" --room <name> Room name (hashed for privacy before sending)");
|
||||
eprintln!(" --token <token> featherChat bearer token for relay auth");
|
||||
eprintln!(" --metrics-file <path> Write JSONL telemetry to file (1 line/sec)");
|
||||
eprintln!(" (48kHz mono s16le, play with ffplay -f s16le -ar 48000 -ch_layout mono file.raw)");
|
||||
eprintln!(
|
||||
" (48kHz mono s16le, play with ffplay -f s16le -ar 48000 -ch_layout mono file.raw)"
|
||||
);
|
||||
eprintln!();
|
||||
eprintln!("Default relay: 127.0.0.1:4433");
|
||||
std::process::exit(0);
|
||||
@@ -238,6 +256,7 @@ fn parse_args() -> CliArgs {
|
||||
version_check,
|
||||
signal,
|
||||
call_target,
|
||||
netcheck,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -256,12 +275,28 @@ async fn main() -> anyhow::Result<()> {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// --netcheck: run network diagnostic and exit
|
||||
if cli.netcheck {
|
||||
let config = wzp_client::netcheck::NetcheckConfig {
|
||||
stun_config: wzp_client::stun::StunConfig::default(),
|
||||
relays: vec![("relay".into(), cli.relay_addr)],
|
||||
timeout: std::time::Duration::from_secs(5),
|
||||
test_portmap: true,
|
||||
test_ipv6: true,
|
||||
local_port: 0,
|
||||
};
|
||||
let report = wzp_client::netcheck::run_netcheck(&config).await;
|
||||
print!("{}", wzp_client::netcheck::format_report(&report));
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// --version-check: query relay version over QUIC and exit
|
||||
if cli.version_check {
|
||||
let client_config = wzp_transport::client_config();
|
||||
let bind_addr: SocketAddr = "0.0.0.0:0".parse()?;
|
||||
let endpoint = wzp_transport::create_endpoint(bind_addr, None)?;
|
||||
let conn = wzp_transport::connect(&endpoint, cli.relay_addr, "version", client_config).await?;
|
||||
let conn =
|
||||
wzp_transport::connect(&endpoint, cli.relay_addr, "version", client_config).await?;
|
||||
match conn.accept_uni().await {
|
||||
Ok(mut recv) => {
|
||||
let data = recv.read_to_end(256).await.unwrap_or_default();
|
||||
@@ -269,7 +304,10 @@ async fn main() -> anyhow::Result<()> {
|
||||
println!("{} {}", cli.relay_addr, version.trim());
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("relay {} does not support version query: {e}", cli.relay_addr);
|
||||
eprintln!(
|
||||
"relay {} does not support version query: {e}",
|
||||
cli.relay_addr
|
||||
);
|
||||
}
|
||||
}
|
||||
endpoint.close(0u32.into(), b"done");
|
||||
@@ -309,8 +347,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
"0.0.0.0:0".parse()?
|
||||
};
|
||||
let endpoint = wzp_transport::create_endpoint(bind_addr, None)?;
|
||||
let connection =
|
||||
wzp_transport::connect(&endpoint, cli.relay_addr, &sni, client_config).await?;
|
||||
let connection = wzp_transport::connect(&endpoint, cli.relay_addr, &sni, client_config).await?;
|
||||
|
||||
info!("Connected to relay");
|
||||
|
||||
@@ -321,10 +358,12 @@ async fn main() -> anyhow::Result<()> {
|
||||
{
|
||||
let shutdown_transport = transport.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
|
||||
.expect("failed to register SIGTERM handler");
|
||||
let mut sigint = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt())
|
||||
.expect("failed to register SIGINT handler");
|
||||
let mut sigterm =
|
||||
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
|
||||
.expect("failed to register SIGTERM handler");
|
||||
let mut sigint =
|
||||
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt())
|
||||
.expect("failed to register SIGINT handler");
|
||||
tokio::select! {
|
||||
_ = sigterm.recv() => { info!("SIGTERM received, closing connection..."); }
|
||||
_ = sigint.recv() => { info!("SIGINT received, closing connection..."); }
|
||||
@@ -332,13 +371,16 @@ async fn main() -> anyhow::Result<()> {
|
||||
// Close the QUIC connection immediately (APPLICATION_CLOSE frame).
|
||||
// Don't call process::exit — let the main task detect the closed
|
||||
// connection and perform clean shutdown (e.g., save recordings).
|
||||
shutdown_transport.connection().close(0u32.into(), b"shutdown");
|
||||
shutdown_transport
|
||||
.connection()
|
||||
.close(0u32.into(), b"shutdown");
|
||||
});
|
||||
}
|
||||
|
||||
// Send auth token if provided (relay with --auth-url expects this first)
|
||||
if let Some(ref token) = cli.token {
|
||||
let auth = wzp_proto::SignalMessage::AuthToken {
|
||||
version: default_signal_version(),
|
||||
token: token.clone(),
|
||||
};
|
||||
transport.send_signal(&auth).await?;
|
||||
@@ -346,21 +388,29 @@ async fn main() -> anyhow::Result<()> {
|
||||
}
|
||||
|
||||
// Crypto handshake — establishes verified identity + session key
|
||||
let _crypto_session = wzp_client::handshake::perform_handshake(
|
||||
let hs = wzp_client::handshake::perform_handshake(
|
||||
&*transport,
|
||||
&seed.0,
|
||||
None, // alias — desktop client doesn't set one yet
|
||||
).await?;
|
||||
info!("crypto handshake complete");
|
||||
)
|
||||
.await?;
|
||||
info!(video_codec = ?hs.video_codec, "crypto handshake complete");
|
||||
|
||||
// Wrap the transport so all media I/O goes through AEAD encryption.
|
||||
let enc_transport: Arc<dyn wzp_proto::MediaTransport> = Arc::new(
|
||||
wzp_client::encrypted_transport::EncryptingTransport::new(transport.clone(), hs.session),
|
||||
);
|
||||
|
||||
if cli.live {
|
||||
#[cfg(feature = "audio")]
|
||||
{
|
||||
return run_live(transport).await;
|
||||
return run_live(enc_transport).await;
|
||||
}
|
||||
#[cfg(not(feature = "audio"))]
|
||||
{
|
||||
anyhow::bail!("--live requires the 'audio' feature (build with: cargo build --features audio)");
|
||||
anyhow::bail!(
|
||||
"--live requires the 'audio' feature (build with: cargo build --features audio)"
|
||||
);
|
||||
}
|
||||
} else if let Some(secs) = cli.echo_test_secs {
|
||||
let result = wzp_client::echo_test::run_echo_test(&*transport, secs, 5.0).await?;
|
||||
@@ -377,14 +427,20 @@ async fn main() -> anyhow::Result<()> {
|
||||
transport.close().await?;
|
||||
Ok(())
|
||||
} else if cli.send_tone_secs.is_some() || cli.send_file.is_some() || cli.record_file.is_some() {
|
||||
run_file_mode(transport, cli.send_tone_secs, cli.send_file, cli.record_file).await
|
||||
run_file_mode(
|
||||
enc_transport,
|
||||
cli.send_tone_secs,
|
||||
cli.send_file,
|
||||
cli.record_file,
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
run_silence(transport).await
|
||||
run_silence(enc_transport).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Send silence frames (connectivity test).
|
||||
async fn run_silence(transport: Arc<wzp_transport::QuinnTransport>) -> anyhow::Result<()> {
|
||||
async fn run_silence(transport: Arc<dyn wzp_proto::MediaTransport>) -> anyhow::Result<()> {
|
||||
let config = CallConfig::default();
|
||||
let mut encoder = CallEncoder::new(&config);
|
||||
|
||||
@@ -398,7 +454,7 @@ async fn run_silence(transport: Arc<wzp_transport::QuinnTransport>) -> anyhow::R
|
||||
for i in 0..250u32 {
|
||||
let packets = encoder.encode_frame(&pcm)?;
|
||||
for pkt in &packets {
|
||||
if pkt.header.is_repair {
|
||||
if pkt.header.is_repair() {
|
||||
total_repair += 1;
|
||||
} else {
|
||||
total_source += 1;
|
||||
@@ -423,6 +479,7 @@ async fn run_silence(transport: Arc<wzp_transport::QuinnTransport>) -> anyhow::R
|
||||
|
||||
info!(total_source, total_repair, total_bytes, "done — closing");
|
||||
let hangup = wzp_proto::SignalMessage::Hangup {
|
||||
version: default_signal_version(),
|
||||
reason: wzp_proto::HangupReason::Normal,
|
||||
call_id: None,
|
||||
};
|
||||
@@ -433,7 +490,7 @@ async fn run_silence(transport: Arc<wzp_transport::QuinnTransport>) -> anyhow::R
|
||||
|
||||
/// File/tone mode: send a test tone or audio file, and/or record received audio.
|
||||
async fn run_file_mode(
|
||||
transport: Arc<wzp_transport::QuinnTransport>,
|
||||
transport: Arc<dyn wzp_proto::MediaTransport>,
|
||||
send_tone_secs: Option<u32>,
|
||||
send_file: Option<String>,
|
||||
record_file: Option<String>,
|
||||
@@ -448,21 +505,28 @@ async fn run_file_mode(
|
||||
// Read raw PCM file (48kHz mono s16le)
|
||||
let bytes = match std::fs::read(path) {
|
||||
Ok(b) => b,
|
||||
Err(e) => { error!("read {path}: {e}"); return; }
|
||||
Err(e) => {
|
||||
error!("read {path}: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
let samples: Vec<i16> = bytes.chunks_exact(2)
|
||||
let samples: Vec<i16> = bytes
|
||||
.chunks_exact(2)
|
||||
.map(|c| i16::from_le_bytes([c[0], c[1]]))
|
||||
.collect();
|
||||
let duration = samples.len() as f64 / 48_000.0;
|
||||
info!(file = %path, duration = format!("{:.1}s", duration), "sending audio file");
|
||||
samples.chunks(FRAME_SAMPLES)
|
||||
samples
|
||||
.chunks(FRAME_SAMPLES)
|
||||
.filter(|c| c.len() == FRAME_SAMPLES)
|
||||
.map(|c| c.to_vec())
|
||||
.collect()
|
||||
} else if let Some(secs) = send_tone_secs {
|
||||
let total = (secs as u64) * 50;
|
||||
info!(seconds = secs, frames = total, "sending 440Hz tone");
|
||||
(0..total).map(|i| generate_sine_frame(440.0, 48_000, i)).collect()
|
||||
(0..total)
|
||||
.map(|i| generate_sine_frame(440.0, 48_000, i))
|
||||
.collect()
|
||||
} else {
|
||||
// No sending, just wait
|
||||
tokio::signal::ctrl_c().await.ok();
|
||||
@@ -486,7 +550,7 @@ async fn run_file_mode(
|
||||
}
|
||||
};
|
||||
for pkt in &packets {
|
||||
if pkt.header.is_repair {
|
||||
if pkt.header.is_repair() {
|
||||
total_repair += 1;
|
||||
} else {
|
||||
total_source += 1;
|
||||
@@ -534,7 +598,7 @@ async fn run_file_mode(
|
||||
result = recv_transport.recv_media() => {
|
||||
match result {
|
||||
Ok(Some(pkt)) => {
|
||||
let is_repair = pkt.header.is_repair;
|
||||
let is_repair = pkt.header.is_repair();
|
||||
decoder.ingest(pkt);
|
||||
if !is_repair {
|
||||
if let Some(n) = decoder.decode_next(&mut pcm_buf) {
|
||||
@@ -575,6 +639,7 @@ async fn run_file_mode(
|
||||
|
||||
// Send Hangup signal so the relay knows we're done
|
||||
let hangup = wzp_proto::SignalMessage::Hangup {
|
||||
version: default_signal_version(),
|
||||
reason: wzp_proto::HangupReason::Normal,
|
||||
call_id: None,
|
||||
};
|
||||
@@ -614,7 +679,7 @@ async fn run_file_mode(
|
||||
|
||||
/// Live mode: capture from mic, encode, send; receive, decode, play.
|
||||
#[cfg(feature = "audio")]
|
||||
async fn run_live(transport: Arc<wzp_transport::QuinnTransport>) -> anyhow::Result<()> {
|
||||
async fn run_live(transport: Arc<dyn wzp_proto::MediaTransport>) -> anyhow::Result<()> {
|
||||
use wzp_client::audio_io::{AudioCapture, AudioPlayback};
|
||||
|
||||
let capture = AudioCapture::start()?;
|
||||
@@ -667,7 +732,7 @@ async fn run_live(transport: Arc<wzp_transport::QuinnTransport>) -> anyhow::Resu
|
||||
loop {
|
||||
match recv_transport.recv_media().await {
|
||||
Ok(Some(pkt)) => {
|
||||
let is_repair = pkt.header.is_repair;
|
||||
let is_repair = pkt.header.is_repair();
|
||||
decoder.ingest(pkt);
|
||||
// Only decode for source packets (1 source = 1 audio frame).
|
||||
// Repair packets feed the FEC decoder but don't produce audio.
|
||||
@@ -712,7 +777,7 @@ async fn run_signal_mode(
|
||||
token: Option<String>,
|
||||
call_target: Option<String>,
|
||||
) -> anyhow::Result<()> {
|
||||
use wzp_proto::SignalMessage;
|
||||
use wzp_proto::{SignalMessage, default_signal_version};
|
||||
|
||||
let identity = seed.derive_identity();
|
||||
let pub_id = identity.public_identity();
|
||||
@@ -734,22 +799,34 @@ async fn run_signal_mode(
|
||||
|
||||
// Auth if token provided
|
||||
if let Some(ref tok) = token {
|
||||
transport.send_signal(&SignalMessage::AuthToken { token: tok.clone() }).await?;
|
||||
transport
|
||||
.send_signal(&SignalMessage::AuthToken {
|
||||
version: default_signal_version(),
|
||||
token: tok.clone(),
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
|
||||
// Register presence (signature not verified in Phase 1)
|
||||
transport.send_signal(&SignalMessage::RegisterPresence {
|
||||
identity_pub,
|
||||
signature: vec![], // Phase 1: not verified
|
||||
alias: None,
|
||||
}).await?;
|
||||
transport
|
||||
.send_signal(&SignalMessage::RegisterPresence {
|
||||
version: default_signal_version(),
|
||||
identity_pub,
|
||||
signature: vec![], // Phase 1: not verified
|
||||
alias: None,
|
||||
})
|
||||
.await?;
|
||||
|
||||
// Wait for ack
|
||||
match transport.recv_signal().await? {
|
||||
Some(SignalMessage::RegisterPresenceAck { success: true, .. }) => {
|
||||
info!(fingerprint = %fp, "registered on relay — waiting for calls");
|
||||
}
|
||||
Some(SignalMessage::RegisterPresenceAck { success: false, error, .. }) => {
|
||||
Some(SignalMessage::RegisterPresenceAck {
|
||||
success: false,
|
||||
error,
|
||||
..
|
||||
}) => {
|
||||
anyhow::bail!("registration failed: {}", error.unwrap_or_default());
|
||||
}
|
||||
other => {
|
||||
@@ -760,24 +837,33 @@ async fn run_signal_mode(
|
||||
// If --call specified, place the call
|
||||
if let Some(ref target) = call_target {
|
||||
info!(target = %target, "placing direct call...");
|
||||
let call_id = format!("{:016x}", std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH).unwrap().as_nanos());
|
||||
let call_id = format!(
|
||||
"{:016x}",
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_nanos()
|
||||
);
|
||||
|
||||
transport.send_signal(&SignalMessage::DirectCallOffer {
|
||||
caller_fingerprint: fp.clone(),
|
||||
caller_alias: None,
|
||||
target_fingerprint: target.clone(),
|
||||
call_id: call_id.clone(),
|
||||
identity_pub,
|
||||
ephemeral_pub: [0u8; 32], // Phase 1: not used for key exchange
|
||||
signature: vec![],
|
||||
supported_profiles: vec![wzp_proto::QualityProfile::GOOD],
|
||||
// CLI client doesn't attempt hole-punching; always
|
||||
// relay-path.
|
||||
caller_reflexive_addr: None,
|
||||
caller_local_addrs: Vec::new(),
|
||||
caller_build_version: None,
|
||||
}).await?;
|
||||
transport
|
||||
.send_signal(&SignalMessage::DirectCallOffer {
|
||||
version: default_signal_version(),
|
||||
caller_fingerprint: fp.clone(),
|
||||
caller_alias: None,
|
||||
target_fingerprint: target.clone(),
|
||||
call_id: call_id.clone(),
|
||||
identity_pub,
|
||||
ephemeral_pub: [0u8; 32], // Phase 1: not used for key exchange
|
||||
signature: vec![],
|
||||
supported_profiles: vec![wzp_proto::QualityProfile::GOOD],
|
||||
// CLI client doesn't attempt hole-punching; always
|
||||
// relay-path.
|
||||
caller_reflexive_addr: None,
|
||||
caller_local_addrs: Vec::new(),
|
||||
caller_mapped_addr: None,
|
||||
caller_build_version: None,
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
|
||||
// Signal recv loop — handle incoming signals
|
||||
@@ -788,10 +874,15 @@ async fn run_signal_mode(
|
||||
loop {
|
||||
match signal_transport.recv_signal().await {
|
||||
Ok(Some(msg)) => match msg {
|
||||
SignalMessage::CallRinging { call_id } => {
|
||||
SignalMessage::CallRinging { call_id, .. } => {
|
||||
info!(call_id = %call_id, "ringing...");
|
||||
}
|
||||
SignalMessage::DirectCallOffer { caller_fingerprint, caller_alias, call_id, .. } => {
|
||||
SignalMessage::DirectCallOffer {
|
||||
caller_fingerprint,
|
||||
caller_alias,
|
||||
call_id,
|
||||
..
|
||||
} => {
|
||||
info!(
|
||||
from = %caller_fingerprint,
|
||||
alias = ?caller_alias,
|
||||
@@ -799,24 +890,40 @@ async fn run_signal_mode(
|
||||
"incoming call — auto-accepting (generic)"
|
||||
);
|
||||
// Auto-accept for CLI testing
|
||||
let _ = signal_transport.send_signal(&SignalMessage::DirectCallAnswer {
|
||||
call_id,
|
||||
accept_mode: wzp_proto::CallAcceptMode::AcceptGeneric,
|
||||
identity_pub: Some(identity_pub),
|
||||
ephemeral_pub: None,
|
||||
signature: None,
|
||||
chosen_profile: Some(wzp_proto::QualityProfile::GOOD),
|
||||
// CLI auto-accept uses generic (privacy) mode,
|
||||
// so callee addr stays hidden from the caller.
|
||||
callee_reflexive_addr: None,
|
||||
callee_local_addrs: Vec::new(),
|
||||
callee_build_version: None,
|
||||
}).await;
|
||||
let _ = signal_transport
|
||||
.send_signal(&SignalMessage::DirectCallAnswer {
|
||||
version: default_signal_version(),
|
||||
call_id,
|
||||
accept_mode: wzp_proto::CallAcceptMode::AcceptGeneric,
|
||||
identity_pub: Some(identity_pub),
|
||||
ephemeral_pub: None,
|
||||
signature: None,
|
||||
chosen_profile: Some(wzp_proto::QualityProfile::GOOD),
|
||||
// CLI auto-accept uses generic (privacy) mode,
|
||||
// so callee addr stays hidden from the caller.
|
||||
callee_reflexive_addr: None,
|
||||
callee_local_addrs: Vec::new(),
|
||||
callee_mapped_addr: None,
|
||||
callee_build_version: None,
|
||||
})
|
||||
.await;
|
||||
}
|
||||
SignalMessage::DirectCallAnswer { call_id, accept_mode, .. } => {
|
||||
SignalMessage::DirectCallAnswer {
|
||||
call_id,
|
||||
accept_mode,
|
||||
..
|
||||
} => {
|
||||
info!(call_id = %call_id, mode = ?accept_mode, "call answered");
|
||||
}
|
||||
SignalMessage::CallSetup { call_id, room, relay_addr: setup_relay, peer_direct_addr: _, peer_local_addrs: _ } => {
|
||||
SignalMessage::CallSetup {
|
||||
call_id,
|
||||
room,
|
||||
relay_addr: setup_relay,
|
||||
peer_direct_addr: _,
|
||||
peer_local_addrs: _,
|
||||
peer_mapped_addr: _,
|
||||
..
|
||||
} => {
|
||||
info!(call_id = %call_id, room = %room, relay = %setup_relay, "call setup — connecting to media room");
|
||||
|
||||
// Connect to the media room
|
||||
@@ -824,18 +931,28 @@ async fn run_signal_mode(
|
||||
let media_cfg = wzp_transport::client_config();
|
||||
match wzp_transport::connect(&endpoint, media_relay, &room, media_cfg).await {
|
||||
Ok(media_conn) => {
|
||||
let media_transport = Arc::new(wzp_transport::QuinnTransport::new(media_conn));
|
||||
let media_transport =
|
||||
Arc::new(wzp_transport::QuinnTransport::new(media_conn));
|
||||
|
||||
// Crypto handshake
|
||||
match wzp_client::handshake::perform_handshake(&*media_transport, &my_seed, None).await {
|
||||
Ok(_session) => {
|
||||
info!("media connected — sending tone (press Ctrl+C to hang up)");
|
||||
match wzp_client::handshake::perform_handshake(
|
||||
&*media_transport,
|
||||
&my_seed,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(_hs) => {
|
||||
info!(
|
||||
"media connected — sending tone (press Ctrl+C to hang up)"
|
||||
);
|
||||
|
||||
// Simple tone sender for testing
|
||||
let mt = media_transport.clone();
|
||||
let send_task = tokio::spawn(async move {
|
||||
let config = wzp_client::call::CallConfig::default();
|
||||
let mut encoder = wzp_client::call::CallEncoder::new(&config);
|
||||
let mut encoder =
|
||||
wzp_client::call::CallEncoder::new(&config);
|
||||
let duration = tokio::time::Duration::from_millis(20);
|
||||
loop {
|
||||
let pcm: Vec<i16> = (0..FRAME_SAMPLES)
|
||||
@@ -843,7 +960,9 @@ async fn run_signal_mode(
|
||||
.collect();
|
||||
if let Ok(pkts) = encoder.encode_frame(&pcm) {
|
||||
for pkt in &pkts {
|
||||
if mt.send_media(pkt).await.is_err() { return; }
|
||||
if mt.send_media(pkt).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
tokio::time::sleep(duration).await;
|
||||
@@ -866,6 +985,7 @@ async fn run_signal_mode(
|
||||
_ = tokio::signal::ctrl_c() => {
|
||||
info!("hanging up...");
|
||||
let _ = signal_transport.send_signal(&SignalMessage::Hangup {
|
||||
version: default_signal_version(),
|
||||
reason: wzp_proto::HangupReason::Normal,
|
||||
call_id: None,
|
||||
}).await;
|
||||
|
||||
@@ -144,7 +144,7 @@ pub async fn run_drift_test(
|
||||
}
|
||||
match tokio::time::timeout(Duration::from_millis(2), transport.recv_media()).await {
|
||||
Ok(Ok(Some(pkt))) => {
|
||||
let is_repair = pkt.header.is_repair;
|
||||
let is_repair = pkt.header.is_repair();
|
||||
decoder.ingest(pkt);
|
||||
if !is_repair {
|
||||
if let Some(_n) = decoder.decode_next(&mut pcm_buf) {
|
||||
@@ -180,7 +180,7 @@ pub async fn run_drift_test(
|
||||
while Instant::now() < drain_deadline {
|
||||
match tokio::time::timeout(Duration::from_millis(100), transport.recv_media()).await {
|
||||
Ok(Ok(Some(pkt))) => {
|
||||
let is_repair = pkt.header.is_repair;
|
||||
let is_repair = pkt.header.is_repair();
|
||||
decoder.ingest(pkt);
|
||||
if !is_repair {
|
||||
if let Some(_n) = decoder.decode_next(&mut pcm_buf) {
|
||||
@@ -234,7 +234,10 @@ pub fn print_drift_report(result: &DriftResult) {
|
||||
println!();
|
||||
println!("Expected duration: {} ms", result.expected_duration_ms);
|
||||
println!("Actual duration: {} ms", result.actual_duration_ms);
|
||||
println!("Drift: {} ms ({:+.4}%)", result.drift_ms, result.drift_pct);
|
||||
println!(
|
||||
"Drift: {} ms ({:+.4}%)",
|
||||
result.drift_ms, result.drift_pct
|
||||
);
|
||||
println!();
|
||||
|
||||
// Interpretation
|
||||
@@ -246,9 +249,15 @@ pub fn print_drift_report(result: &DriftResult) {
|
||||
} else if abs_drift < 20 {
|
||||
println!("Result: GOOD -- drift is within acceptable bounds (<20 ms).");
|
||||
} else if abs_drift < 100 {
|
||||
println!("Result: FAIR -- noticeable drift ({} ms). Clock sync may be needed.", abs_drift);
|
||||
println!(
|
||||
"Result: FAIR -- noticeable drift ({} ms). Clock sync may be needed.",
|
||||
abs_drift
|
||||
);
|
||||
} else {
|
||||
println!("Result: POOR -- significant drift ({} ms). Investigate clock sources.", abs_drift);
|
||||
println!(
|
||||
"Result: POOR -- significant drift ({} ms). Investigate clock sources.",
|
||||
abs_drift
|
||||
);
|
||||
}
|
||||
println!();
|
||||
}
|
||||
|
||||
@@ -38,6 +38,15 @@ pub enum WinningPath {
|
||||
Relay,
|
||||
}
|
||||
|
||||
/// Diagnostic info for a single candidate dial attempt.
|
||||
#[derive(Debug, Clone, serde::Serialize)]
|
||||
pub struct CandidateDiag {
|
||||
pub index: usize,
|
||||
pub addr: String,
|
||||
pub result: String, // "ok", "skipped:ipv6", "error:..."
|
||||
pub elapsed_ms: Option<u32>,
|
||||
}
|
||||
|
||||
/// Phase 6: the race now returns BOTH transports (when available)
|
||||
/// so the connect command can negotiate with the peer before
|
||||
/// committing. The negotiation decides which transport to use
|
||||
@@ -54,6 +63,8 @@ pub struct RaceResult {
|
||||
/// Informational — the actual path used is decided by the
|
||||
/// Phase 6 negotiation after both sides exchange reports.
|
||||
pub local_winner: WinningPath,
|
||||
/// Per-candidate diagnostic info for debugging.
|
||||
pub candidate_diags: Vec<CandidateDiag>,
|
||||
}
|
||||
|
||||
/// Attempt a direct QUIC connection to the peer in parallel with
|
||||
@@ -88,19 +99,30 @@ pub struct PeerCandidates {
|
||||
/// same-LAN pairs — direct dials to these bypass the NAT
|
||||
/// entirely.
|
||||
pub local: Vec<SocketAddr>,
|
||||
/// Phase 8 (Tailscale-inspired): peer's port-mapped external
|
||||
/// address from NAT-PMP/PCP/UPnP. When the router supports
|
||||
/// port mapping, this gives a stable external address even
|
||||
/// behind symmetric NATs.
|
||||
pub mapped: Option<SocketAddr>,
|
||||
}
|
||||
|
||||
impl PeerCandidates {
|
||||
/// Flatten into the list of addrs the D-role should dial.
|
||||
/// Order: LAN host candidates first (fastest when they
|
||||
/// work), then reflexive (covers the non-LAN case).
|
||||
/// work), then port-mapped (stable even behind symmetric
|
||||
/// NATs), then reflexive (covers the non-LAN case).
|
||||
pub fn dial_order(&self) -> Vec<SocketAddr> {
|
||||
let mut out = Vec::with_capacity(self.local.len() + 1);
|
||||
let mut out = Vec::with_capacity(self.local.len() + 2);
|
||||
out.extend(self.local.iter().copied());
|
||||
// Port-mapped address goes before reflexive — it's
|
||||
// more reliable on symmetric NATs where the reflexive
|
||||
// addr might not match what the peer actually sees.
|
||||
if let Some(a) = self.mapped {
|
||||
if !out.contains(&a) {
|
||||
out.push(a);
|
||||
}
|
||||
}
|
||||
if let Some(a) = self.reflexive {
|
||||
// Only add if it's not already in the list (some
|
||||
// edge cases on same-LAN could have the same addr
|
||||
// in both).
|
||||
if !out.contains(&a) {
|
||||
out.push(a);
|
||||
}
|
||||
@@ -108,10 +130,54 @@ impl PeerCandidates {
|
||||
out
|
||||
}
|
||||
|
||||
/// Smart dial order: filters out candidates that can't possibly
|
||||
/// work given our own reflexive address.
|
||||
///
|
||||
/// - **LAN candidates**: only included if peer's public IP
|
||||
/// matches ours (same network). Private IPs are unreachable
|
||||
/// cross-network.
|
||||
/// - **IPv6 candidates**: stripped entirely (Phase 7 disabled).
|
||||
/// - **Reflexive + mapped**: always included.
|
||||
pub fn smart_dial_order(&self, own_reflexive: Option<&SocketAddr>) -> Vec<SocketAddr> {
|
||||
let own_public_ip = own_reflexive.map(|a| a.ip());
|
||||
let peer_public_ip = self.reflexive.map(|a| a.ip());
|
||||
let same_network = match (own_public_ip, peer_public_ip) {
|
||||
(Some(a), Some(b)) => a == b,
|
||||
_ => false,
|
||||
};
|
||||
|
||||
let mut out = Vec::with_capacity(self.local.len() + 2);
|
||||
|
||||
// LAN candidates only when on the same network.
|
||||
if same_network {
|
||||
for addr in &self.local {
|
||||
if !addr.is_ipv6() {
|
||||
out.push(*addr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Port-mapped (always useful — it's a public addr).
|
||||
if let Some(a) = self.mapped {
|
||||
if !a.is_ipv6() && !out.contains(&a) {
|
||||
out.push(a);
|
||||
}
|
||||
}
|
||||
|
||||
// Reflexive (always useful — it's the peer's public addr).
|
||||
if let Some(a) = self.reflexive {
|
||||
if !a.is_ipv6() && !out.contains(&a) {
|
||||
out.push(a);
|
||||
}
|
||||
}
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
/// Is there anything for the D-role to dial? If not, the
|
||||
/// race reduces to relay-only.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.reflexive.is_none() && self.local.is_empty()
|
||||
self.reflexive.is_none() && self.local.is_empty() && self.mapped.is_none()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -122,6 +188,9 @@ pub async fn race(
|
||||
relay_addr: SocketAddr,
|
||||
room_sni: String,
|
||||
call_sni: String,
|
||||
// Our own reflexive address — used to filter LAN candidates
|
||||
// that can't work cross-network.
|
||||
own_reflexive: Option<SocketAddr>,
|
||||
// Phase 5: when `Some`, reuse this endpoint for BOTH the
|
||||
// direct-path branch AND the relay dial. Pass the signal
|
||||
// endpoint. The endpoint MUST be server-capable (created
|
||||
@@ -141,6 +210,10 @@ pub async fn race(
|
||||
// is created. Install attempt is idempotent.
|
||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||
|
||||
// Shared diagnostic collector for per-candidate results.
|
||||
let diags_collector: Arc<std::sync::Mutex<Vec<CandidateDiag>>> =
|
||||
Arc::new(std::sync::Mutex::new(Vec::new()));
|
||||
|
||||
// Build the direct-path endpoint + future based on role.
|
||||
//
|
||||
// A-role: one accept future on the shared endpoint. The
|
||||
@@ -196,7 +269,84 @@ pub async fn race(
|
||||
// as dial — IPv6 connections die on datagram send).
|
||||
// Accept on IPv4 shared endpoint only.
|
||||
let _v6_ep_unused = ipv6_endpoint.clone();
|
||||
// Collect peer addrs for NAT tickle (Acceptor-side).
|
||||
let tickle_addrs: Vec<SocketAddr> = peer_candidates
|
||||
.smart_dial_order(own_reflexive.as_ref())
|
||||
.into_iter()
|
||||
.filter(|a| !a.ip().is_loopback() && !a.ip().is_unspecified())
|
||||
.collect();
|
||||
direct_fut = Box::pin(async move {
|
||||
// NAT tickle: send a small UDP packet to each of the
|
||||
// Dialer's candidate addresses FROM our shared endpoint.
|
||||
// This opens our NAT's pinhole for return traffic from
|
||||
// those IPs — critical for address-restricted NATs that
|
||||
// only allow inbound from IPs they've seen outbound
|
||||
// traffic to. Without this, the Dialer's QUIC Initial
|
||||
// gets dropped by our NAT.
|
||||
if !tickle_addrs.is_empty() {
|
||||
if let Ok(local_addr) = ep_for_fut.local_addr() {
|
||||
// Send a tickle to each peer candidate address
|
||||
// to open our NAT for return traffic from that IP.
|
||||
//
|
||||
// We use a socket2 socket with SO_REUSEADDR +
|
||||
// SO_REUSEPORT on the SAME port as the quinn
|
||||
// endpoint. This is necessary because quinn
|
||||
// already holds the port — a plain bind() would
|
||||
// fail with EADDRINUSE.
|
||||
let tickle_result: Result<(), String> = (|| {
|
||||
use std::net::UdpSocket as StdUdpSocket;
|
||||
let sock = socket2::Socket::new(
|
||||
socket2::Domain::IPV4,
|
||||
socket2::Type::DGRAM,
|
||||
Some(socket2::Protocol::UDP),
|
||||
)
|
||||
.map_err(|e| format!("socket: {e}"))?;
|
||||
sock.set_reuse_address(true)
|
||||
.map_err(|e| format!("reuseaddr: {e}"))?;
|
||||
// macOS/BSD/Linux also need SO_REUSEPORT
|
||||
#[cfg(any(
|
||||
target_os = "macos",
|
||||
target_os = "linux",
|
||||
target_os = "android"
|
||||
))]
|
||||
{
|
||||
// socket2 exposes set_reuse_port on unix
|
||||
unsafe {
|
||||
let optval: libc::c_int = 1;
|
||||
libc::setsockopt(
|
||||
std::os::unix::io::AsRawFd::as_raw_fd(&sock),
|
||||
libc::SOL_SOCKET,
|
||||
libc::SO_REUSEPORT,
|
||||
&optval as *const _ as *const libc::c_void,
|
||||
std::mem::size_of::<libc::c_int>() as libc::socklen_t,
|
||||
);
|
||||
}
|
||||
}
|
||||
sock.set_nonblocking(true)
|
||||
.map_err(|e| format!("nonblock: {e}"))?;
|
||||
let bind_addr: SocketAddr = SocketAddr::new(
|
||||
std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED),
|
||||
local_addr.port(),
|
||||
);
|
||||
sock.bind(&bind_addr.into())
|
||||
.map_err(|e| format!("bind :{}: {e}", local_addr.port()))?;
|
||||
let std_sock: StdUdpSocket = sock.into();
|
||||
for addr in &tickle_addrs {
|
||||
let _ = std_sock.send_to(&[0u8; 1], addr);
|
||||
tracing::info!(
|
||||
%addr,
|
||||
local_port = local_addr.port(),
|
||||
"dual_path: A-role sent NAT tickle"
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
})();
|
||||
if let Err(e) = tickle_result {
|
||||
tracing::warn!(error = %e, "dual_path: A-role NAT tickle failed");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Accept loop: retry if we get a stale/closed
|
||||
// connection from a previous call. Max 3 retries
|
||||
// to avoid spinning until the race timeout.
|
||||
@@ -270,8 +420,9 @@ pub async fn race(
|
||||
};
|
||||
let ep_for_fut = ep.clone();
|
||||
let _v6_ep_for_dial = ipv6_endpoint.clone();
|
||||
let dial_order = peer_candidates.dial_order();
|
||||
let dial_order = peer_candidates.smart_dial_order(own_reflexive.as_ref());
|
||||
let sni = call_sni.clone();
|
||||
let diags = diags_collector.clone();
|
||||
direct_fut = Box::pin(async move {
|
||||
if dial_order.is_empty() {
|
||||
// No candidates — the race reduces to
|
||||
@@ -300,24 +451,47 @@ pub async fn race(
|
||||
// Re-enable once IPv6 datagram delivery is
|
||||
// verified on target networks.
|
||||
if candidate.is_ipv6() {
|
||||
tracing::debug!(
|
||||
tracing::info!(
|
||||
%candidate,
|
||||
candidate_idx = idx,
|
||||
"dual_path: skipping IPv6 candidate (disabled)"
|
||||
);
|
||||
if let Ok(mut d) = diags.lock() {
|
||||
d.push(CandidateDiag {
|
||||
index: idx,
|
||||
addr: candidate.to_string(),
|
||||
result: "skipped:ipv6".into(),
|
||||
elapsed_ms: None,
|
||||
});
|
||||
}
|
||||
continue;
|
||||
}
|
||||
let ep = ep_for_fut.clone();
|
||||
let client_cfg = wzp_transport::client_config();
|
||||
let sni = sni.clone();
|
||||
let diags_inner = diags.clone();
|
||||
set.spawn(async move {
|
||||
let result = wzp_transport::connect(
|
||||
&ep,
|
||||
candidate,
|
||||
&sni,
|
||||
client_cfg,
|
||||
)
|
||||
.await;
|
||||
let start = std::time::Instant::now();
|
||||
tracing::info!(
|
||||
%candidate,
|
||||
candidate_idx = idx,
|
||||
"dual_path: dialing candidate"
|
||||
);
|
||||
let result =
|
||||
wzp_transport::connect(&ep, candidate, &sni, client_cfg).await;
|
||||
let elapsed = start.elapsed().as_millis() as u32;
|
||||
let diag_result = match &result {
|
||||
Ok(_) => "ok".to_string(),
|
||||
Err(e) => format!("error:{e}"),
|
||||
};
|
||||
if let Ok(mut d) = diags_inner.lock() {
|
||||
d.push(CandidateDiag {
|
||||
index: idx,
|
||||
addr: candidate.to_string(),
|
||||
result: diag_result,
|
||||
elapsed_ms: Some(elapsed),
|
||||
});
|
||||
}
|
||||
(idx, candidate, result)
|
||||
});
|
||||
}
|
||||
@@ -346,7 +520,7 @@ pub async fn race(
|
||||
return Ok(QuinnTransport::new(conn));
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::debug!(
|
||||
tracing::info!(
|
||||
%candidate,
|
||||
candidate_idx = idx,
|
||||
error = %e,
|
||||
@@ -423,16 +597,17 @@ pub async fn race(
|
||||
// RaceResult with both transports (when available) and uses the
|
||||
// Phase 6 MediaPathReport exchange to decide which one to
|
||||
// actually use for media.
|
||||
let smart_order = peer_candidates.smart_dial_order(own_reflexive.as_ref());
|
||||
tracing::info!(
|
||||
?role,
|
||||
candidates = ?peer_candidates.dial_order(),
|
||||
raw_candidates = ?peer_candidates.dial_order(),
|
||||
filtered_candidates = ?smart_order,
|
||||
?own_reflexive,
|
||||
%relay_addr,
|
||||
"dual_path: racing direct vs relay"
|
||||
);
|
||||
|
||||
let mut direct_task = tokio::spawn(
|
||||
tokio::time::timeout(Duration::from_secs(2), direct_fut),
|
||||
);
|
||||
let mut direct_task = tokio::spawn(tokio::time::timeout(Duration::from_secs(4), direct_fut));
|
||||
let mut relay_task = tokio::spawn(async move {
|
||||
// Keep the 500ms head start so direct has a chance
|
||||
tokio::time::sleep(Duration::from_millis(500)).await;
|
||||
@@ -464,9 +639,25 @@ pub async fn race(
|
||||
local_winner = WinningPath::Relay; // direct failed → relay is our only hope
|
||||
}
|
||||
Ok(Err(_)) => {
|
||||
tracing::warn!("dual_path: direct timed out (2s)");
|
||||
tracing::warn!("dual_path: direct timed out (4s)");
|
||||
direct_result = Some(Err(anyhow::anyhow!("direct timeout")));
|
||||
local_winner = WinningPath::Relay;
|
||||
// Record timeout diag for candidates that were
|
||||
// still in-flight when the timeout fired.
|
||||
if let Ok(mut d) = diags_collector.lock() {
|
||||
let recorded_indices: std::collections::HashSet<usize> =
|
||||
d.iter().map(|diag| diag.index).collect();
|
||||
for (idx, addr) in smart_order.iter().enumerate() {
|
||||
if !recorded_indices.contains(&idx) {
|
||||
d.push(CandidateDiag {
|
||||
index: idx,
|
||||
addr: addr.to_string(),
|
||||
result: "timeout:4s".into(),
|
||||
elapsed_ms: Some(4000),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "dual_path: direct task panicked");
|
||||
@@ -505,16 +696,43 @@ pub async fn race(
|
||||
// If it doesn't, we still proceed with just the winner.
|
||||
if direct_result.is_none() {
|
||||
match tokio::time::timeout(Duration::from_secs(1), direct_task).await {
|
||||
Ok(Ok(Ok(Ok(t)))) => { direct_result = Some(Ok(t)); }
|
||||
Ok(Ok(Ok(Err(e)))) => { direct_result = Some(Err(anyhow::anyhow!("{e}"))); }
|
||||
_ => { direct_result = Some(Err(anyhow::anyhow!("direct: no result in grace period"))); }
|
||||
Ok(Ok(Ok(Ok(t)))) => {
|
||||
direct_result = Some(Ok(t));
|
||||
}
|
||||
Ok(Ok(Ok(Err(e)))) => {
|
||||
direct_result = Some(Err(anyhow::anyhow!("{e}")));
|
||||
}
|
||||
_ => {
|
||||
direct_result = Some(Err(anyhow::anyhow!("direct: no result in grace period")));
|
||||
// Fill timeout diags for candidates that never reported.
|
||||
if let Ok(mut d) = diags_collector.lock() {
|
||||
let recorded: std::collections::HashSet<usize> =
|
||||
d.iter().map(|diag| diag.index).collect();
|
||||
for (idx, addr) in smart_order.iter().enumerate() {
|
||||
if !recorded.contains(&idx) {
|
||||
d.push(CandidateDiag {
|
||||
index: idx,
|
||||
addr: addr.to_string(),
|
||||
result: "timeout:grace".into(),
|
||||
elapsed_ms: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if relay_result.is_none() {
|
||||
match tokio::time::timeout(Duration::from_secs(1), relay_task).await {
|
||||
Ok(Ok(Ok(Ok(t)))) => { relay_result = Some(Ok(t)); }
|
||||
Ok(Ok(Ok(Err(e)))) => { relay_result = Some(Err(anyhow::anyhow!("{e}"))); }
|
||||
_ => { relay_result = Some(Err(anyhow::anyhow!("relay: no result in grace period"))); }
|
||||
Ok(Ok(Ok(Ok(t)))) => {
|
||||
relay_result = Some(Ok(t));
|
||||
}
|
||||
Ok(Ok(Ok(Err(e)))) => {
|
||||
relay_result = Some(Err(anyhow::anyhow!("{e}")));
|
||||
}
|
||||
_ => {
|
||||
relay_result = Some(Err(anyhow::anyhow!("relay: no result in grace period")));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -529,18 +747,230 @@ pub async fn race(
|
||||
);
|
||||
|
||||
if !direct_ok && !relay_ok {
|
||||
return Err(anyhow::anyhow!("both paths failed: no media transport available"));
|
||||
return Err(anyhow::anyhow!(
|
||||
"both paths failed: no media transport available"
|
||||
));
|
||||
}
|
||||
|
||||
let _ = (direct_ep, relay_ep, ipv6_endpoint);
|
||||
|
||||
let candidate_diags = diags_collector
|
||||
.lock()
|
||||
.map(|d| d.clone())
|
||||
.unwrap_or_default();
|
||||
|
||||
Ok(RaceResult {
|
||||
direct_transport: direct_result
|
||||
.and_then(|r| r.ok())
|
||||
.map(|t| Arc::new(t)),
|
||||
relay_transport: relay_result
|
||||
.and_then(|r| r.ok())
|
||||
.map(|t| Arc::new(t)),
|
||||
direct_transport: direct_result.and_then(|r| r.ok()).map(|t| Arc::new(t)),
|
||||
relay_transport: relay_result.and_then(|r| r.ok()).map(|t| Arc::new(t)),
|
||||
local_winner,
|
||||
candidate_diags,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn peer_candidates_dial_order_all_types() {
|
||||
let candidates = PeerCandidates {
|
||||
reflexive: Some("203.0.113.5:4433".parse().unwrap()),
|
||||
local: vec![
|
||||
"192.168.1.10:4433".parse().unwrap(),
|
||||
"10.0.0.5:4433".parse().unwrap(),
|
||||
],
|
||||
mapped: Some("198.51.100.42:12345".parse().unwrap()),
|
||||
};
|
||||
|
||||
let order = candidates.dial_order();
|
||||
// Order: local first, then mapped, then reflexive
|
||||
assert_eq!(order.len(), 4);
|
||||
assert_eq!(order[0], "192.168.1.10:4433".parse::<SocketAddr>().unwrap());
|
||||
assert_eq!(order[1], "10.0.0.5:4433".parse::<SocketAddr>().unwrap());
|
||||
assert_eq!(
|
||||
order[2],
|
||||
"198.51.100.42:12345".parse::<SocketAddr>().unwrap()
|
||||
);
|
||||
assert_eq!(order[3], "203.0.113.5:4433".parse::<SocketAddr>().unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn peer_candidates_dial_order_no_mapped() {
|
||||
let candidates = PeerCandidates {
|
||||
reflexive: Some("203.0.113.5:4433".parse().unwrap()),
|
||||
local: vec!["192.168.1.10:4433".parse().unwrap()],
|
||||
mapped: None,
|
||||
};
|
||||
|
||||
let order = candidates.dial_order();
|
||||
assert_eq!(order.len(), 2);
|
||||
assert_eq!(order[0], "192.168.1.10:4433".parse::<SocketAddr>().unwrap());
|
||||
assert_eq!(order[1], "203.0.113.5:4433".parse::<SocketAddr>().unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn peer_candidates_dial_order_only_mapped() {
|
||||
let candidates = PeerCandidates {
|
||||
reflexive: None,
|
||||
local: vec![],
|
||||
mapped: Some("198.51.100.42:12345".parse().unwrap()),
|
||||
};
|
||||
|
||||
let order = candidates.dial_order();
|
||||
assert_eq!(order.len(), 1);
|
||||
assert_eq!(
|
||||
order[0],
|
||||
"198.51.100.42:12345".parse::<SocketAddr>().unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn peer_candidates_dial_order_dedup_mapped_equals_reflexive() {
|
||||
let addr: SocketAddr = "203.0.113.5:4433".parse().unwrap();
|
||||
let candidates = PeerCandidates {
|
||||
reflexive: Some(addr),
|
||||
local: vec![],
|
||||
mapped: Some(addr), // same as reflexive
|
||||
};
|
||||
|
||||
let order = candidates.dial_order();
|
||||
// Should be deduped to 1
|
||||
assert_eq!(order.len(), 1);
|
||||
assert_eq!(order[0], addr);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn peer_candidates_dial_order_dedup_mapped_in_local() {
|
||||
let addr: SocketAddr = "192.168.1.10:4433".parse().unwrap();
|
||||
let candidates = PeerCandidates {
|
||||
reflexive: None,
|
||||
local: vec![addr],
|
||||
mapped: Some(addr), // same as a local addr
|
||||
};
|
||||
|
||||
let order = candidates.dial_order();
|
||||
assert_eq!(order.len(), 1);
|
||||
assert_eq!(order[0], addr);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn peer_candidates_is_empty() {
|
||||
let empty = PeerCandidates::default();
|
||||
assert!(empty.is_empty());
|
||||
|
||||
let with_reflexive = PeerCandidates {
|
||||
reflexive: Some("1.2.3.4:5".parse().unwrap()),
|
||||
..Default::default()
|
||||
};
|
||||
assert!(!with_reflexive.is_empty());
|
||||
|
||||
let with_local = PeerCandidates {
|
||||
local: vec!["10.0.0.1:5".parse().unwrap()],
|
||||
..Default::default()
|
||||
};
|
||||
assert!(!with_local.is_empty());
|
||||
|
||||
let with_mapped = PeerCandidates {
|
||||
mapped: Some("1.2.3.4:5".parse().unwrap()),
|
||||
..Default::default()
|
||||
};
|
||||
assert!(!with_mapped.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn peer_candidates_empty_dial_order() {
|
||||
let empty = PeerCandidates::default();
|
||||
assert!(empty.dial_order().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn winning_path_debug() {
|
||||
// Just verify Debug impl doesn't panic
|
||||
let _ = format!("{:?}", WinningPath::Direct);
|
||||
let _ = format!("{:?}", WinningPath::Relay);
|
||||
}
|
||||
|
||||
// ── smart_dial_order tests ─────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn smart_dial_order_same_network_includes_lan() {
|
||||
let candidates = PeerCandidates {
|
||||
reflexive: Some("203.0.113.5:4433".parse().unwrap()),
|
||||
local: vec![
|
||||
"192.168.1.10:4433".parse().unwrap(),
|
||||
"10.0.0.5:4433".parse().unwrap(),
|
||||
],
|
||||
mapped: None,
|
||||
};
|
||||
let own: SocketAddr = "203.0.113.5:12345".parse().unwrap();
|
||||
let order = candidates.smart_dial_order(Some(&own));
|
||||
// Same public IP → LAN candidates included
|
||||
assert!(order.contains(&"192.168.1.10:4433".parse().unwrap()));
|
||||
assert!(order.contains(&"10.0.0.5:4433".parse().unwrap()));
|
||||
assert!(order.contains(&"203.0.113.5:4433".parse().unwrap()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn smart_dial_order_different_network_strips_lan() {
|
||||
let candidates = PeerCandidates {
|
||||
reflexive: Some("150.228.49.65:4433".parse().unwrap()),
|
||||
local: vec![
|
||||
"172.16.81.126:4433".parse().unwrap(),
|
||||
"10.0.0.5:4433".parse().unwrap(),
|
||||
],
|
||||
mapped: None,
|
||||
};
|
||||
// Different public IP → LAN candidates stripped
|
||||
let own: SocketAddr = "185.115.4.212:12345".parse().unwrap();
|
||||
let order = candidates.smart_dial_order(Some(&own));
|
||||
assert!(!order.contains(&"172.16.81.126:4433".parse().unwrap()));
|
||||
assert!(!order.contains(&"10.0.0.5:4433".parse().unwrap()));
|
||||
// Reflexive still included
|
||||
assert!(order.contains(&"150.228.49.65:4433".parse().unwrap()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn smart_dial_order_strips_ipv6() {
|
||||
let candidates = PeerCandidates {
|
||||
reflexive: Some("150.228.49.65:4433".parse().unwrap()),
|
||||
local: vec![
|
||||
"[2a0d:3344:692c::1]:4433".parse().unwrap(),
|
||||
"172.16.81.126:4433".parse().unwrap(),
|
||||
],
|
||||
mapped: None,
|
||||
};
|
||||
// Same network, but IPv6 should be stripped
|
||||
let own: SocketAddr = "150.228.49.65:5555".parse().unwrap();
|
||||
let order = candidates.smart_dial_order(Some(&own));
|
||||
assert!(!order.iter().any(|a| a.is_ipv6()));
|
||||
assert!(order.contains(&"172.16.81.126:4433".parse().unwrap()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn smart_dial_order_no_own_reflexive_strips_lan() {
|
||||
let candidates = PeerCandidates {
|
||||
reflexive: Some("150.228.49.65:4433".parse().unwrap()),
|
||||
local: vec!["172.16.81.126:4433".parse().unwrap()],
|
||||
mapped: Some("198.51.100.42:12345".parse().unwrap()),
|
||||
};
|
||||
// No own reflexive → can't determine same network → strip LAN
|
||||
let order = candidates.smart_dial_order(None);
|
||||
assert!(!order.contains(&"172.16.81.126:4433".parse().unwrap()));
|
||||
assert!(order.contains(&"198.51.100.42:12345".parse().unwrap()));
|
||||
assert!(order.contains(&"150.228.49.65:4433".parse().unwrap()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn smart_dial_order_mapped_always_included() {
|
||||
let candidates = PeerCandidates {
|
||||
reflexive: Some("150.228.49.65:4433".parse().unwrap()),
|
||||
local: vec![],
|
||||
mapped: Some("198.51.100.42:12345".parse().unwrap()),
|
||||
};
|
||||
let own: SocketAddr = "185.115.4.212:12345".parse().unwrap();
|
||||
let order = candidates.smart_dial_order(Some(&own));
|
||||
assert_eq!(order.len(), 2); // mapped + reflexive
|
||||
assert!(order.contains(&"198.51.100.42:12345".parse().unwrap()));
|
||||
assert!(order.contains(&"150.228.49.65:4433".parse().unwrap()));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -166,7 +166,7 @@ pub async fn run_echo_test(
|
||||
match tokio::time::timeout(Duration::from_millis(2), transport.recv_media()).await {
|
||||
Ok(Ok(Some(pkt))) => {
|
||||
total_packets_received += 1;
|
||||
let is_repair = pkt.header.is_repair;
|
||||
let is_repair = pkt.header.is_repair();
|
||||
decoder.ingest(pkt);
|
||||
if !is_repair {
|
||||
if let Some(n) = decoder.decode_next(&mut pcm_buf) {
|
||||
@@ -184,7 +184,8 @@ pub async fn run_echo_test(
|
||||
let time_offset = start.elapsed().as_secs_f64();
|
||||
|
||||
// Compare sent vs received for this window
|
||||
let sent_start = (window_idx as u64 * frames_per_window * FRAME_SAMPLES as u64) as usize;
|
||||
let sent_start =
|
||||
(window_idx as u64 * frames_per_window * FRAME_SAMPLES as u64) as usize;
|
||||
let sent_end = sent_start + (window_frames_sent as usize * FRAME_SAMPLES);
|
||||
let sent_window = if sent_end <= sent_pcm.len() {
|
||||
&sent_pcm[sent_start..sent_end]
|
||||
@@ -192,7 +193,9 @@ pub async fn run_echo_test(
|
||||
&sent_pcm[sent_start..]
|
||||
};
|
||||
|
||||
let recv_start = recv_pcm.len().saturating_sub(window_frames_received as usize * FRAME_SAMPLES);
|
||||
let recv_start = recv_pcm
|
||||
.len()
|
||||
.saturating_sub(window_frames_received as usize * FRAME_SAMPLES);
|
||||
let recv_window = &recv_pcm[recv_start..];
|
||||
|
||||
let peak = recv_window.iter().map(|s| s.abs()).max().unwrap_or(0);
|
||||
@@ -256,7 +259,7 @@ pub async fn run_echo_test(
|
||||
match tokio::time::timeout(Duration::from_millis(100), transport.recv_media()).await {
|
||||
Ok(Ok(Some(pkt))) => {
|
||||
total_packets_received += 1;
|
||||
let is_repair = pkt.header.is_repair;
|
||||
let is_repair = pkt.header.is_repair();
|
||||
decoder.ingest(pkt);
|
||||
if !is_repair {
|
||||
decoder.decode_next(&mut pcm_buf);
|
||||
@@ -310,8 +313,14 @@ pub fn print_report(result: &EchoTestResult) {
|
||||
let status = if w.is_silent { " !" } else { " " };
|
||||
println!(
|
||||
"│ {:>3}{} │ {:>5.1}s │ {:>4} │ {:>4} │ {:>5.1}% │ {:>5.1} │ {:.3} │",
|
||||
w.index, status, w.time_offset_secs, w.frames_sent, w.frames_received,
|
||||
w.loss_pct, w.snr_db, w.correlation
|
||||
w.index,
|
||||
status,
|
||||
w.time_offset_secs,
|
||||
w.frames_sent,
|
||||
w.frames_received,
|
||||
w.loss_pct,
|
||||
w.snr_db,
|
||||
w.correlation
|
||||
);
|
||||
}
|
||||
println!("└───────┴─────────┴──────┴──────┴─────────┴───────┴───────┘");
|
||||
@@ -321,18 +330,28 @@ pub fn print_report(result: &EchoTestResult) {
|
||||
let first_half: Vec<_> = result.windows[..result.windows.len() / 2].to_vec();
|
||||
let second_half: Vec<_> = result.windows[result.windows.len() / 2..].to_vec();
|
||||
|
||||
let avg_loss_first = first_half.iter().map(|w| w.loss_pct).sum::<f32>() / first_half.len() as f32;
|
||||
let avg_loss_second = second_half.iter().map(|w| w.loss_pct).sum::<f32>() / second_half.len() as f32;
|
||||
let avg_corr_first = first_half.iter().map(|w| w.correlation).sum::<f32>() / first_half.len() as f32;
|
||||
let avg_corr_second = second_half.iter().map(|w| w.correlation).sum::<f32>() / second_half.len() as f32;
|
||||
let avg_loss_first =
|
||||
first_half.iter().map(|w| w.loss_pct).sum::<f32>() / first_half.len() as f32;
|
||||
let avg_loss_second =
|
||||
second_half.iter().map(|w| w.loss_pct).sum::<f32>() / second_half.len() as f32;
|
||||
let avg_corr_first =
|
||||
first_half.iter().map(|w| w.correlation).sum::<f32>() / first_half.len() as f32;
|
||||
let avg_corr_second =
|
||||
second_half.iter().map(|w| w.correlation).sum::<f32>() / second_half.len() as f32;
|
||||
|
||||
println!();
|
||||
if avg_loss_second > avg_loss_first + 5.0 {
|
||||
println!("WARNING: Quality degradation detected!");
|
||||
println!(" Loss increased from {:.1}% to {:.1}% over time", avg_loss_first, avg_loss_second);
|
||||
println!(
|
||||
" Loss increased from {:.1}% to {:.1}% over time",
|
||||
avg_loss_first, avg_loss_second
|
||||
);
|
||||
}
|
||||
if avg_corr_second < avg_corr_first - 0.1 {
|
||||
println!("WARNING: Signal correlation dropped from {:.3} to {:.3}", avg_corr_first, avg_corr_second);
|
||||
println!(
|
||||
"WARNING: Signal correlation dropped from {:.3} to {:.3}",
|
||||
avg_corr_first, avg_corr_second
|
||||
);
|
||||
}
|
||||
if avg_loss_second <= avg_loss_first + 5.0 && avg_corr_second >= avg_corr_first - 0.1 {
|
||||
println!("Quality is STABLE over the test duration.");
|
||||
|
||||
213
crates/wzp-client/src/encrypted_transport.rs
Normal file
213
crates/wzp-client/src/encrypted_transport.rs
Normal file
@@ -0,0 +1,213 @@
|
||||
//! `EncryptingTransport` — wraps any `MediaTransport` with a `CryptoSession`.
|
||||
//!
|
||||
//! All outbound `send_media` calls encrypt the payload before handing off to
|
||||
//! the inner transport; all inbound `recv_media` calls decrypt after receiving.
|
||||
//! Signal, quality, and close are forwarded unchanged.
|
||||
//!
|
||||
//! The quality report travels in plaintext so the relay can make QoS decisions
|
||||
//! without being able to decrypt media content.
|
||||
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use bytes::Bytes;
|
||||
use wzp_proto::{
|
||||
CryptoSession, MediaHeader, MediaPacket, MediaTransport, PathQuality, SignalMessage,
|
||||
TransportError,
|
||||
};
|
||||
|
||||
/// Wraps a `MediaTransport` and applies AEAD encryption/decryption to media payloads.
|
||||
pub struct EncryptingTransport {
|
||||
inner: Arc<dyn MediaTransport>,
|
||||
session: Mutex<Box<dyn CryptoSession>>,
|
||||
}
|
||||
|
||||
impl EncryptingTransport {
|
||||
pub fn new(inner: Arc<dyn MediaTransport>, session: Box<dyn CryptoSession>) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
session: Mutex::new(session),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl MediaTransport for EncryptingTransport {
|
||||
async fn send_media(&self, packet: &MediaPacket) -> Result<(), TransportError> {
|
||||
let mut header_bytes = Vec::with_capacity(MediaHeader::WIRE_SIZE);
|
||||
packet.header.write_to(&mut header_bytes);
|
||||
|
||||
let mut ciphertext = Vec::new();
|
||||
self.session
|
||||
.lock()
|
||||
.unwrap()
|
||||
.encrypt(&header_bytes, &packet.payload, &mut ciphertext)
|
||||
.map_err(|e| TransportError::Internal(format!("encrypt: {e}")))?;
|
||||
|
||||
let encrypted = MediaPacket {
|
||||
header: packet.header,
|
||||
payload: Bytes::from(ciphertext),
|
||||
quality_report: packet.quality_report.clone(),
|
||||
};
|
||||
self.inner.send_media(&encrypted).await
|
||||
}
|
||||
|
||||
async fn recv_media(&self) -> Result<Option<MediaPacket>, TransportError> {
|
||||
let packet = match self.inner.recv_media().await? {
|
||||
Some(p) => p,
|
||||
None => return Ok(None),
|
||||
};
|
||||
|
||||
let mut header_bytes = Vec::with_capacity(MediaHeader::WIRE_SIZE);
|
||||
packet.header.write_to(&mut header_bytes);
|
||||
|
||||
let mut plaintext = Vec::new();
|
||||
self.session
|
||||
.lock()
|
||||
.unwrap()
|
||||
.decrypt(&header_bytes, &packet.payload, &mut plaintext)
|
||||
.map_err(|e| TransportError::Internal(format!("decrypt: {e}")))?;
|
||||
|
||||
Ok(Some(MediaPacket {
|
||||
header: packet.header,
|
||||
payload: Bytes::from(plaintext),
|
||||
quality_report: packet.quality_report,
|
||||
}))
|
||||
}
|
||||
|
||||
async fn send_signal(&self, msg: &SignalMessage) -> Result<(), TransportError> {
|
||||
self.inner.send_signal(msg).await
|
||||
}
|
||||
|
||||
async fn recv_signal(&self) -> Result<Option<SignalMessage>, TransportError> {
|
||||
self.inner.recv_signal().await
|
||||
}
|
||||
|
||||
fn path_quality(&self) -> PathQuality {
|
||||
self.inner.path_quality()
|
||||
}
|
||||
|
||||
async fn close(&self) -> Result<(), TransportError> {
|
||||
self.inner.close().await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::sync::Mutex as StdMutex;
|
||||
use wzp_crypto::ChaChaSession;
|
||||
use wzp_proto::{CodecId, MediaType};
|
||||
|
||||
struct LoopbackTransport {
|
||||
sent: StdMutex<Vec<MediaPacket>>,
|
||||
}
|
||||
|
||||
impl LoopbackTransport {
|
||||
fn new() -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
sent: StdMutex::new(Vec::new()),
|
||||
})
|
||||
}
|
||||
fn take_sent(&self) -> Vec<MediaPacket> {
|
||||
self.sent.lock().unwrap().drain(..).collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl MediaTransport for LoopbackTransport {
|
||||
async fn send_media(&self, packet: &MediaPacket) -> Result<(), TransportError> {
|
||||
self.sent.lock().unwrap().push(packet.clone());
|
||||
Ok(())
|
||||
}
|
||||
async fn recv_media(&self) -> Result<Option<MediaPacket>, TransportError> {
|
||||
Ok(None)
|
||||
}
|
||||
async fn send_signal(&self, _msg: &SignalMessage) -> Result<(), TransportError> {
|
||||
Ok(())
|
||||
}
|
||||
async fn recv_signal(&self) -> Result<Option<SignalMessage>, TransportError> {
|
||||
Ok(None)
|
||||
}
|
||||
fn path_quality(&self) -> PathQuality {
|
||||
PathQuality::default()
|
||||
}
|
||||
async fn close(&self) -> Result<(), TransportError> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn make_header(seq: u32) -> MediaHeader {
|
||||
MediaHeader {
|
||||
version: 2,
|
||||
flags: 0,
|
||||
media_type: MediaType::Audio,
|
||||
codec_id: CodecId::Opus24k,
|
||||
stream_id: 0,
|
||||
fec_ratio: 0,
|
||||
seq,
|
||||
timestamp: seq * 20,
|
||||
fec_block: 0,
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn payload_is_encrypted_on_wire() {
|
||||
let key = [0x42u8; 32];
|
||||
let session: Box<dyn CryptoSession> = Box::new(ChaChaSession::new(key));
|
||||
let loopback = LoopbackTransport::new();
|
||||
let enc = EncryptingTransport::new(loopback.clone(), session);
|
||||
|
||||
let header = make_header(1);
|
||||
let plaintext = b"secret audio frame";
|
||||
let pkt = MediaPacket {
|
||||
header,
|
||||
payload: Bytes::from_static(plaintext),
|
||||
quality_report: None,
|
||||
};
|
||||
|
||||
enc.send_media(&pkt).await.unwrap();
|
||||
|
||||
let sent = loopback.take_sent();
|
||||
assert_eq!(sent.len(), 1);
|
||||
assert_eq!(sent[0].header, header, "header must be preserved");
|
||||
assert_ne!(
|
||||
sent[0].payload.as_ref(),
|
||||
plaintext.as_ref(),
|
||||
"plaintext must not appear on wire"
|
||||
);
|
||||
// Ciphertext is longer by exactly the AEAD tag (16 bytes)
|
||||
assert_eq!(sent[0].payload.len(), plaintext.len() + 16);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn encrypt_then_decrypt_roundtrip() {
|
||||
let key = [0x42u8; 32];
|
||||
let send_session: Box<dyn CryptoSession> = Box::new(ChaChaSession::new(key));
|
||||
let mut recv_session = ChaChaSession::new(key);
|
||||
|
||||
let loopback = LoopbackTransport::new();
|
||||
let enc = EncryptingTransport::new(loopback.clone(), send_session);
|
||||
|
||||
let header = make_header(5);
|
||||
let plaintext = b"hello encrypted world";
|
||||
let pkt = MediaPacket {
|
||||
header,
|
||||
payload: Bytes::from_static(plaintext),
|
||||
quality_report: None,
|
||||
};
|
||||
|
||||
enc.send_media(&pkt).await.unwrap();
|
||||
|
||||
let sent = loopback.take_sent();
|
||||
let wire_pkt = &sent[0];
|
||||
|
||||
let mut header_bytes = Vec::new();
|
||||
header.write_to(&mut header_bytes);
|
||||
let mut decrypted = Vec::new();
|
||||
recv_session
|
||||
.decrypt(&header_bytes, &wire_pkt.payload, &mut decrypted)
|
||||
.expect("decrypt should succeed with matching key");
|
||||
assert_eq!(&decrypted[..], plaintext);
|
||||
}
|
||||
}
|
||||
@@ -99,14 +99,15 @@ pub fn signal_to_call_type(signal: &SignalMessage) -> CallSignalType {
|
||||
SignalMessage::LossRecoveryUpdate { .. } => CallSignalType::Offer, // reuse (telemetry)
|
||||
SignalMessage::Ping { .. } | SignalMessage::Pong { .. } => CallSignalType::Offer,
|
||||
SignalMessage::AuthToken { .. } => CallSignalType::Offer,
|
||||
SignalMessage::Hold => CallSignalType::Hold,
|
||||
SignalMessage::Unhold => CallSignalType::Unhold,
|
||||
SignalMessage::Mute => CallSignalType::Mute,
|
||||
SignalMessage::Unmute => CallSignalType::Unmute,
|
||||
SignalMessage::Hold { .. } => CallSignalType::Hold,
|
||||
SignalMessage::Unhold { .. } => CallSignalType::Unhold,
|
||||
SignalMessage::Mute { .. } => CallSignalType::Mute,
|
||||
SignalMessage::Unmute { .. } => CallSignalType::Unmute,
|
||||
SignalMessage::Transfer { .. } => CallSignalType::Transfer,
|
||||
SignalMessage::TransferAck => CallSignalType::Offer, // reuse
|
||||
SignalMessage::TransferAck { .. } => CallSignalType::Offer, // reuse
|
||||
SignalMessage::PresenceUpdate { .. } => CallSignalType::Offer, // reuse
|
||||
SignalMessage::RouteQuery { .. } => CallSignalType::Offer, // reuse
|
||||
SignalMessage::TransportFeedback { .. } => CallSignalType::Offer, // reuse (BWE)
|
||||
SignalMessage::RouteResponse { .. } => CallSignalType::Offer, // reuse
|
||||
SignalMessage::SessionForward { .. } => CallSignalType::Offer, // reuse
|
||||
SignalMessage::SessionForwardAck { .. } => CallSignalType::Offer, // reuse
|
||||
@@ -118,20 +119,31 @@ pub fn signal_to_call_type(signal: &SignalMessage) -> CallSignalType {
|
||||
SignalMessage::DirectCallAnswer { .. } => CallSignalType::Answer,
|
||||
SignalMessage::CallSetup { .. } => CallSignalType::Offer, // relay-only
|
||||
SignalMessage::CallRinging { .. } => CallSignalType::Ringing,
|
||||
SignalMessage::RegisterPresence { .. }
|
||||
| SignalMessage::RegisterPresenceAck { .. } => CallSignalType::Offer, // relay-only
|
||||
SignalMessage::RegisterPresence { .. } | SignalMessage::RegisterPresenceAck { .. } => {
|
||||
CallSignalType::Offer
|
||||
} // relay-only
|
||||
// NAT reflection is a client↔relay control exchange that
|
||||
// never crosses the featherChat bridge — if it ever reaches
|
||||
// this mapper something is wrong, but we still have to give
|
||||
// an answer. "Offer" is the generic catch-all.
|
||||
SignalMessage::Reflect
|
||||
| SignalMessage::ReflectResponse { .. } => CallSignalType::Offer, // control-plane
|
||||
SignalMessage::Reflect | SignalMessage::ReflectResponse { .. } => CallSignalType::Offer, // control-plane
|
||||
// Phase 4 cross-relay forwarding envelope — strictly a
|
||||
// relay-to-relay message, never rides the featherChat
|
||||
// bridge. Catch-all mapping for completeness.
|
||||
SignalMessage::FederatedSignalForward { .. } => CallSignalType::Offer,
|
||||
SignalMessage::MediaPathReport { .. } => CallSignalType::Offer, // control-plane
|
||||
SignalMessage::CandidateUpdate { .. } => CallSignalType::IceCandidate, // mid-call re-gather
|
||||
SignalMessage::HardNatProbe { .. } => CallSignalType::IceCandidate, // hard NAT coordination
|
||||
SignalMessage::HardNatBirthdayStart { .. } => CallSignalType::IceCandidate, // birthday attack
|
||||
SignalMessage::UpgradeProposal { .. }
|
||||
| SignalMessage::UpgradeResponse { .. }
|
||||
| SignalMessage::UpgradeConfirm { .. }
|
||||
| SignalMessage::QualityCapability { .. } => CallSignalType::Offer, // quality negotiation
|
||||
SignalMessage::PresenceList { .. } => CallSignalType::Offer, // lobby presence
|
||||
SignalMessage::QualityDirective { .. } => CallSignalType::Offer, // relay-initiated
|
||||
SignalMessage::Nack { .. }
|
||||
| SignalMessage::PictureLossIndication { .. }
|
||||
| SignalMessage::SetPriorityMode { .. } => CallSignalType::Offer, // relay-initiated (video loss recovery)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -139,15 +151,20 @@ pub fn signal_to_call_type(signal: &SignalMessage) -> CallSignalType {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use wzp_proto::QualityProfile;
|
||||
use wzp_proto::default_signal_version;
|
||||
|
||||
#[test]
|
||||
fn payload_roundtrip() {
|
||||
let signal = SignalMessage::CallOffer {
|
||||
version: default_signal_version(),
|
||||
identity_pub: [1u8; 32],
|
||||
ephemeral_pub: [2u8; 32],
|
||||
signature: vec![3u8; 64],
|
||||
supported_profiles: vec![QualityProfile::GOOD],
|
||||
alias: None,
|
||||
protocol_version: 2,
|
||||
supported_versions: vec![2],
|
||||
video_codecs: vec![],
|
||||
};
|
||||
|
||||
let encoded = encode_call_payload(&signal, Some("relay.example.com:4433"), Some("myroom"));
|
||||
@@ -161,29 +178,53 @@ mod tests {
|
||||
#[test]
|
||||
fn signal_type_mapping() {
|
||||
let offer = SignalMessage::CallOffer {
|
||||
version: default_signal_version(),
|
||||
identity_pub: [0; 32],
|
||||
ephemeral_pub: [0; 32],
|
||||
signature: vec![],
|
||||
supported_profiles: vec![],
|
||||
alias: None,
|
||||
protocol_version: 2,
|
||||
supported_versions: vec![2],
|
||||
video_codecs: vec![],
|
||||
};
|
||||
assert!(matches!(signal_to_call_type(&offer), CallSignalType::Offer));
|
||||
|
||||
let hangup = SignalMessage::Hangup {
|
||||
version: default_signal_version(),
|
||||
reason: wzp_proto::HangupReason::Normal,
|
||||
call_id: None,
|
||||
};
|
||||
assert!(matches!(signal_to_call_type(&hangup), CallSignalType::Hangup));
|
||||
assert!(matches!(
|
||||
signal_to_call_type(&hangup),
|
||||
CallSignalType::Hangup
|
||||
));
|
||||
|
||||
assert!(matches!(signal_to_call_type(&SignalMessage::Hold), CallSignalType::Hold));
|
||||
assert!(matches!(signal_to_call_type(&SignalMessage::Unhold), CallSignalType::Unhold));
|
||||
assert!(matches!(signal_to_call_type(&SignalMessage::Mute), CallSignalType::Mute));
|
||||
assert!(matches!(signal_to_call_type(&SignalMessage::Unmute), CallSignalType::Unmute));
|
||||
assert!(matches!(
|
||||
signal_to_call_type(&SignalMessage::Hold { version: default_signal_version() }),
|
||||
CallSignalType::Hold
|
||||
));
|
||||
assert!(matches!(
|
||||
signal_to_call_type(&SignalMessage::Unhold { version: default_signal_version() }),
|
||||
CallSignalType::Unhold
|
||||
));
|
||||
assert!(matches!(
|
||||
signal_to_call_type(&SignalMessage::Mute { version: default_signal_version() }),
|
||||
CallSignalType::Mute
|
||||
));
|
||||
assert!(matches!(
|
||||
signal_to_call_type(&SignalMessage::Unmute { version: default_signal_version() }),
|
||||
CallSignalType::Unmute
|
||||
));
|
||||
|
||||
let transfer = SignalMessage::Transfer {
|
||||
version: default_signal_version(),
|
||||
target_fingerprint: "abc".to_string(),
|
||||
relay_addr: None,
|
||||
};
|
||||
assert!(matches!(signal_to_call_type(&transfer), CallSignalType::Transfer));
|
||||
assert!(matches!(
|
||||
signal_to_call_type(&transfer),
|
||||
CallSignalType::Transfer
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,62 @@
|
||||
//! send `CallOffer` → recv `CallAnswer` → derive shared `CryptoSession`.
|
||||
|
||||
use wzp_crypto::{CryptoSession, KeyExchange, WarzoneKeyExchange};
|
||||
use wzp_proto::{MediaTransport, QualityProfile, SignalMessage};
|
||||
use wzp_proto::{
|
||||
CodecId, HangupReason, MediaTransport, QualityProfile, SignalMessage, default_signal_version,
|
||||
};
|
||||
|
||||
const SUPPORTED_VIDEO_CODECS: &[CodecId] = &[CodecId::H264Baseline];
|
||||
|
||||
/// Result of a successful client-side handshake.
|
||||
pub struct HandshakeResult {
|
||||
pub session: Box<dyn CryptoSession>,
|
||||
/// Video codec agreed with the relay. `None` if peer is audio-only.
|
||||
pub video_codec: Option<CodecId>,
|
||||
}
|
||||
|
||||
/// Errors that can occur during the client-side cryptographic handshake.
|
||||
#[derive(Debug)]
|
||||
pub enum HandshakeError {
|
||||
ConnectionClosed,
|
||||
ProtocolVersionMismatch { server_supported: Vec<u8> },
|
||||
UnexpectedSignal(&'static str),
|
||||
SignatureVerificationFailed,
|
||||
KeyDerivation(String),
|
||||
Transport(wzp_proto::TransportError),
|
||||
}
|
||||
|
||||
impl std::fmt::Display for HandshakeError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::ConnectionClosed => write!(f, "connection closed before receiving CallAnswer"),
|
||||
Self::ProtocolVersionMismatch { server_supported } => {
|
||||
write!(
|
||||
f,
|
||||
"protocol version mismatch: server supports {server_supported:?}"
|
||||
)
|
||||
}
|
||||
Self::UnexpectedSignal(expected) => write!(f, "expected CallAnswer, got {expected}"),
|
||||
Self::SignatureVerificationFailed => write!(f, "callee signature verification failed"),
|
||||
Self::KeyDerivation(msg) => write!(f, "key derivation failed: {msg}"),
|
||||
Self::Transport(e) => write!(f, "transport error: {e}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for HandshakeError {
|
||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||
match self {
|
||||
Self::Transport(e) => Some(e),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<wzp_proto::TransportError> for HandshakeError {
|
||||
fn from(e: wzp_proto::TransportError) -> Self {
|
||||
Self::Transport(e)
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform the client (caller) side of the cryptographic handshake.
|
||||
///
|
||||
@@ -18,7 +73,17 @@ pub async fn perform_handshake(
|
||||
transport: &dyn MediaTransport,
|
||||
seed: &[u8; 32],
|
||||
alias: Option<&str>,
|
||||
) -> Result<Box<dyn CryptoSession>, anyhow::Error> {
|
||||
) -> Result<HandshakeResult, HandshakeError> {
|
||||
perform_handshake_with_video_codecs(transport, seed, alias, SUPPORTED_VIDEO_CODECS.to_vec())
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn perform_handshake_with_video_codecs(
|
||||
transport: &dyn MediaTransport,
|
||||
seed: &[u8; 32],
|
||||
alias: Option<&str>,
|
||||
video_codecs: Vec<CodecId>,
|
||||
) -> Result<HandshakeResult, HandshakeError> {
|
||||
// 1. Create key exchange from identity seed
|
||||
let mut kx = WarzoneKeyExchange::from_identity_seed(seed);
|
||||
let identity_pub = kx.identity_public_key();
|
||||
@@ -34,6 +99,7 @@ pub async fn perform_handshake(
|
||||
|
||||
// 4. Send CallOffer
|
||||
let offer = SignalMessage::CallOffer {
|
||||
version: default_signal_version(),
|
||||
identity_pub,
|
||||
ephemeral_pub,
|
||||
signature,
|
||||
@@ -46,43 +112,66 @@ pub async fn perform_handshake(
|
||||
QualityProfile::CATASTROPHIC,
|
||||
],
|
||||
alias: alias.map(|s| s.to_string()),
|
||||
protocol_version: 2,
|
||||
supported_versions: vec![2],
|
||||
video_codecs,
|
||||
};
|
||||
transport.send_signal(&offer).await?;
|
||||
transport
|
||||
.send_signal(&offer)
|
||||
.await
|
||||
.map_err(HandshakeError::Transport)?;
|
||||
|
||||
// 5. Wait for CallAnswer
|
||||
let answer = transport
|
||||
.recv_signal()
|
||||
.await?
|
||||
.ok_or_else(|| anyhow::anyhow!("connection closed before receiving CallAnswer"))?;
|
||||
// 5. Wait for CallAnswer — 10s timeout guards against relay not responding.
|
||||
let answer = tokio::time::timeout(std::time::Duration::from_secs(10), transport.recv_signal())
|
||||
.await
|
||||
.map_err(|_| HandshakeError::Transport(wzp_proto::TransportError::Timeout { ms: 10_000 }))?
|
||||
.map_err(HandshakeError::Transport)?
|
||||
.ok_or(HandshakeError::ConnectionClosed)?;
|
||||
|
||||
let (callee_identity_pub, callee_ephemeral_pub, callee_signature, _chosen_profile) = match answer
|
||||
{
|
||||
SignalMessage::CallAnswer {
|
||||
identity_pub,
|
||||
ephemeral_pub,
|
||||
signature,
|
||||
chosen_profile,
|
||||
} => (identity_pub, ephemeral_pub, signature, chosen_profile),
|
||||
other => {
|
||||
return Err(anyhow::anyhow!(
|
||||
"expected CallAnswer, got {:?}",
|
||||
std::mem::discriminant(&other)
|
||||
))
|
||||
}
|
||||
};
|
||||
let (callee_identity_pub, callee_ephemeral_pub, callee_signature, _chosen_profile, video_codec) =
|
||||
match answer {
|
||||
SignalMessage::CallAnswer {
|
||||
identity_pub,
|
||||
ephemeral_pub,
|
||||
signature,
|
||||
chosen_profile,
|
||||
video_codec,
|
||||
..
|
||||
} => (
|
||||
identity_pub,
|
||||
ephemeral_pub,
|
||||
signature,
|
||||
chosen_profile,
|
||||
video_codec,
|
||||
),
|
||||
SignalMessage::Hangup {
|
||||
reason: HangupReason::ProtocolVersionMismatch { server_supported },
|
||||
..
|
||||
} => {
|
||||
return Err(HandshakeError::ProtocolVersionMismatch { server_supported });
|
||||
}
|
||||
_ => {
|
||||
return Err(HandshakeError::UnexpectedSignal("CallAnswer"));
|
||||
}
|
||||
};
|
||||
|
||||
// 6. Verify callee's signature over (ephemeral_pub || "call-answer")
|
||||
let mut verify_data = Vec::with_capacity(32 + 11);
|
||||
verify_data.extend_from_slice(&callee_ephemeral_pub);
|
||||
verify_data.extend_from_slice(b"call-answer");
|
||||
if !WarzoneKeyExchange::verify(&callee_identity_pub, &verify_data, &callee_signature) {
|
||||
return Err(anyhow::anyhow!("callee signature verification failed"));
|
||||
return Err(HandshakeError::SignatureVerificationFailed);
|
||||
}
|
||||
|
||||
// 7. Derive session
|
||||
let session = kx.derive_session(&callee_ephemeral_pub)?;
|
||||
let session = kx
|
||||
.derive_session(&callee_ephemeral_pub)
|
||||
.map_err(|e| HandshakeError::KeyDerivation(e.to_string()))?;
|
||||
|
||||
Ok(session)
|
||||
Ok(HandshakeResult {
|
||||
session,
|
||||
video_codec,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -104,4 +193,34 @@ mod tests {
|
||||
&sig,
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn handshake_result_carries_video_codec() {
|
||||
// Verify that HandshakeResult has both fields accessible and that
|
||||
// None is the correct default for audio-only peers.
|
||||
let mut kx = WarzoneKeyExchange::from_identity_seed(&[0x55; 32]);
|
||||
kx.generate_ephemeral();
|
||||
let session = kx.derive_session(&[0u8; 32]).unwrap();
|
||||
let hs = HandshakeResult {
|
||||
session,
|
||||
video_codec: None,
|
||||
};
|
||||
assert!(hs.video_codec.is_none());
|
||||
|
||||
let mut kx2 = WarzoneKeyExchange::from_identity_seed(&[0x66; 32]);
|
||||
kx2.generate_ephemeral();
|
||||
let session2 = kx2.derive_session(&[0u8; 32]).unwrap();
|
||||
let hs2 = HandshakeResult {
|
||||
session: session2,
|
||||
video_codec: Some(CodecId::H264Baseline),
|
||||
};
|
||||
assert_eq!(hs2.video_codec, Some(CodecId::H264Baseline));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn offer_contains_h264_only() {
|
||||
// Keep room video on the common denominator until Android AV1/HEVC
|
||||
// send paths are proven in-device.
|
||||
assert_eq!(SUPPORTED_VIDEO_CODECS, &[CodecId::H264Baseline]);
|
||||
}
|
||||
}
|
||||
|
||||
440
crates/wzp-client/src/ice_agent.rs
Normal file
440
crates/wzp-client/src/ice_agent.rs
Normal file
@@ -0,0 +1,440 @@
|
||||
//! Phase 8 (Tailscale-inspired): ICE agent for candidate lifecycle
|
||||
//! management and mid-call re-gathering.
|
||||
//!
|
||||
//! The `IceAgent` owns the state of all candidate discovery
|
||||
//! mechanisms (STUN, port mapping, host candidates) and provides:
|
||||
//!
|
||||
//! - `gather()`: initial candidate gathering during call setup
|
||||
//! - `re_gather()`: triggered on network change, produces a
|
||||
//! `CandidateUpdate` to send to the peer
|
||||
//! - `apply_peer_update()`: processes peer's candidate updates
|
||||
//!
|
||||
//! This is NOT a full ICE agent (RFC 8445). It's the Tailscale-style
|
||||
//! "gather all candidates, race them all in parallel, pick the
|
||||
//! winner" approach, adapted for QUIC transport.
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::atomic::{AtomicU32, Ordering};
|
||||
use std::time::Duration;
|
||||
|
||||
use wzp_proto::{SignalMessage, default_signal_version};
|
||||
|
||||
use crate::dual_path::PeerCandidates;
|
||||
use crate::portmap;
|
||||
use crate::reflect;
|
||||
use crate::stun;
|
||||
|
||||
/// All candidates gathered for the local side.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CandidateSet {
|
||||
/// STUN-discovered server-reflexive address.
|
||||
pub reflexive: Option<SocketAddr>,
|
||||
/// LAN host candidates from local interfaces.
|
||||
pub local: Vec<SocketAddr>,
|
||||
/// Port-mapped address from NAT-PMP/PCP/UPnP.
|
||||
pub mapped: Option<SocketAddr>,
|
||||
/// Generation counter (monotonically increasing per call).
|
||||
pub generation: u32,
|
||||
}
|
||||
|
||||
/// Configuration for the ICE agent.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct IceAgentConfig {
|
||||
/// STUN servers to use for reflexive discovery.
|
||||
pub stun_config: stun::StunConfig,
|
||||
/// Whether to attempt port mapping.
|
||||
pub enable_portmap: bool,
|
||||
/// Timeout for each discovery mechanism.
|
||||
pub gather_timeout: Duration,
|
||||
/// The QUIC endpoint's local port (for host candidate pairing).
|
||||
pub local_v4_port: u16,
|
||||
/// Optional IPv6 port.
|
||||
pub local_v6_port: Option<u16>,
|
||||
}
|
||||
|
||||
impl Default for IceAgentConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
stun_config: stun::StunConfig::default(),
|
||||
enable_portmap: true,
|
||||
gather_timeout: Duration::from_secs(3),
|
||||
local_v4_port: 0,
|
||||
local_v6_port: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// ICE agent managing candidate lifecycle.
|
||||
pub struct IceAgent {
|
||||
config: IceAgentConfig,
|
||||
generation: AtomicU32,
|
||||
call_id: String,
|
||||
/// Last-seen peer generation (to filter stale updates).
|
||||
peer_generation: AtomicU32,
|
||||
}
|
||||
|
||||
impl IceAgent {
|
||||
pub fn new(call_id: String, config: IceAgentConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
generation: AtomicU32::new(0),
|
||||
call_id,
|
||||
peer_generation: AtomicU32::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
/// Initial candidate gathering. Runs all discovery mechanisms
|
||||
/// in parallel and returns the full candidate set.
|
||||
pub async fn gather(&self) -> CandidateSet {
|
||||
let generation = self.generation.fetch_add(1, Ordering::Relaxed);
|
||||
|
||||
// Run STUN + port mapping + host candidates in parallel.
|
||||
let stun_fut = stun::discover_reflexive(&self.config.stun_config);
|
||||
let portmap_fut = async {
|
||||
if self.config.enable_portmap && self.config.local_v4_port > 0 {
|
||||
portmap::acquire_port_mapping(self.config.local_v4_port, None)
|
||||
.await
|
||||
.ok()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
let (stun_result, portmap_result) = tokio::join!(
|
||||
tokio::time::timeout(self.config.gather_timeout, stun_fut),
|
||||
tokio::time::timeout(self.config.gather_timeout, portmap_fut),
|
||||
);
|
||||
|
||||
let reflexive = stun_result.ok().and_then(|r| r.ok());
|
||||
let mapped = portmap_result.ok().flatten().map(|m| m.external_addr);
|
||||
let local =
|
||||
reflect::local_host_candidates(self.config.local_v4_port, self.config.local_v6_port);
|
||||
|
||||
tracing::info!(
|
||||
generation,
|
||||
reflexive = ?reflexive,
|
||||
mapped = ?mapped,
|
||||
local_count = local.len(),
|
||||
"ice_agent: gathered candidates"
|
||||
);
|
||||
|
||||
CandidateSet {
|
||||
reflexive,
|
||||
local,
|
||||
mapped,
|
||||
generation,
|
||||
}
|
||||
}
|
||||
|
||||
/// Re-gather candidates after a network change. Increments the
|
||||
/// generation counter and returns a `CandidateUpdate` signal
|
||||
/// message to send to the peer.
|
||||
pub async fn re_gather(&self) -> (CandidateSet, SignalMessage) {
|
||||
let candidates = self.gather().await;
|
||||
|
||||
let update = SignalMessage::CandidateUpdate {
|
||||
version: default_signal_version(),
|
||||
call_id: self.call_id.clone(),
|
||||
reflexive_addr: candidates.reflexive.map(|a| a.to_string()),
|
||||
local_addrs: candidates.local.iter().map(|a| a.to_string()).collect(),
|
||||
mapped_addr: candidates.mapped.map(|a| a.to_string()),
|
||||
generation: candidates.generation,
|
||||
};
|
||||
|
||||
(candidates, update)
|
||||
}
|
||||
|
||||
/// Process a peer's candidate update. Returns `Some(PeerCandidates)`
|
||||
/// if the update is newer than the last-seen generation, `None`
|
||||
/// if it's stale.
|
||||
pub fn apply_peer_update(&self, update: &SignalMessage) -> Option<PeerCandidates> {
|
||||
let (reflexive_addr, local_addrs, mapped_addr, generation) = match update {
|
||||
SignalMessage::CandidateUpdate {
|
||||
reflexive_addr,
|
||||
local_addrs,
|
||||
mapped_addr,
|
||||
generation,
|
||||
..
|
||||
} => (reflexive_addr, local_addrs, mapped_addr, *generation),
|
||||
_ => return None,
|
||||
};
|
||||
|
||||
// Only accept if newer than last-seen generation.
|
||||
let prev = self.peer_generation.fetch_max(generation, Ordering::AcqRel);
|
||||
if generation <= prev {
|
||||
tracing::debug!(
|
||||
generation,
|
||||
prev,
|
||||
"ice_agent: ignoring stale CandidateUpdate"
|
||||
);
|
||||
return None;
|
||||
}
|
||||
|
||||
let reflexive = reflexive_addr.as_deref().and_then(|s| s.parse().ok());
|
||||
let local: Vec<SocketAddr> = local_addrs.iter().filter_map(|s| s.parse().ok()).collect();
|
||||
let mapped = mapped_addr.as_deref().and_then(|s| s.parse().ok());
|
||||
|
||||
tracing::info!(
|
||||
generation,
|
||||
reflexive = ?reflexive,
|
||||
mapped = ?mapped,
|
||||
local_count = local.len(),
|
||||
"ice_agent: applied peer candidate update"
|
||||
);
|
||||
|
||||
Some(PeerCandidates {
|
||||
reflexive,
|
||||
local,
|
||||
mapped,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get the current generation counter.
|
||||
pub fn generation(&self) -> u32 {
|
||||
self.generation.load(Ordering::Relaxed)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Tests ──────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn apply_peer_update_rejects_stale() {
|
||||
let agent = IceAgent::new("test-call".into(), IceAgentConfig::default());
|
||||
|
||||
// First update (gen=1) should succeed.
|
||||
let update1 = SignalMessage::CandidateUpdate {
|
||||
version: default_signal_version(),
|
||||
call_id: "test-call".into(),
|
||||
reflexive_addr: Some("203.0.113.5:4433".into()),
|
||||
local_addrs: vec!["192.168.1.10:4433".into()],
|
||||
mapped_addr: None,
|
||||
generation: 1,
|
||||
};
|
||||
let result = agent.apply_peer_update(&update1);
|
||||
assert!(result.is_some());
|
||||
let candidates = result.unwrap();
|
||||
assert_eq!(
|
||||
candidates.reflexive,
|
||||
Some("203.0.113.5:4433".parse().unwrap())
|
||||
);
|
||||
assert_eq!(candidates.local.len(), 1);
|
||||
|
||||
// Same generation (gen=1) should be rejected.
|
||||
let update1b = SignalMessage::CandidateUpdate {
|
||||
version: default_signal_version(),
|
||||
call_id: "test-call".into(),
|
||||
reflexive_addr: Some("198.51.100.9:4433".into()),
|
||||
local_addrs: vec![],
|
||||
mapped_addr: None,
|
||||
generation: 1,
|
||||
};
|
||||
assert!(agent.apply_peer_update(&update1b).is_none());
|
||||
|
||||
// Older generation (gen=0) should be rejected.
|
||||
let update0 = SignalMessage::CandidateUpdate {
|
||||
version: default_signal_version(),
|
||||
call_id: "test-call".into(),
|
||||
reflexive_addr: Some("10.0.0.1:4433".into()),
|
||||
local_addrs: vec![],
|
||||
mapped_addr: None,
|
||||
generation: 0,
|
||||
};
|
||||
assert!(agent.apply_peer_update(&update0).is_none());
|
||||
|
||||
// Newer generation (gen=2) should succeed.
|
||||
let update2 = SignalMessage::CandidateUpdate {
|
||||
version: default_signal_version(),
|
||||
call_id: "test-call".into(),
|
||||
reflexive_addr: Some("198.51.100.9:5555".into()),
|
||||
local_addrs: vec![],
|
||||
mapped_addr: Some("203.0.113.5:12345".into()),
|
||||
generation: 2,
|
||||
};
|
||||
let result = agent.apply_peer_update(&update2);
|
||||
assert!(result.is_some());
|
||||
let candidates = result.unwrap();
|
||||
assert_eq!(
|
||||
candidates.reflexive,
|
||||
Some("198.51.100.9:5555".parse().unwrap())
|
||||
);
|
||||
assert_eq!(
|
||||
candidates.mapped,
|
||||
Some("203.0.113.5:12345".parse().unwrap())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn apply_wrong_signal_returns_none() {
|
||||
let agent = IceAgent::new("test-call".into(), IceAgentConfig::default());
|
||||
let wrong = SignalMessage::Reflect;
|
||||
assert!(agent.apply_peer_update(&wrong).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generation_increments() {
|
||||
let agent = IceAgent::new("test".into(), IceAgentConfig::default());
|
||||
assert_eq!(agent.generation(), 0);
|
||||
// Simulate what gather() does internally
|
||||
let g1 = agent.generation.fetch_add(1, Ordering::Relaxed);
|
||||
assert_eq!(g1, 0);
|
||||
assert_eq!(agent.generation(), 1);
|
||||
let g2 = agent.generation.fetch_add(1, Ordering::Relaxed);
|
||||
assert_eq!(g2, 1);
|
||||
assert_eq!(agent.generation(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn apply_peer_update_parses_all_fields() {
|
||||
let agent = IceAgent::new("test-call".into(), IceAgentConfig::default());
|
||||
|
||||
let update = SignalMessage::CandidateUpdate {
|
||||
version: default_signal_version(),
|
||||
call_id: "test-call".into(),
|
||||
reflexive_addr: Some("203.0.113.5:4433".into()),
|
||||
local_addrs: vec!["192.168.1.10:4433".into(), "10.0.0.5:4433".into()],
|
||||
mapped_addr: Some("198.51.100.42:12345".into()),
|
||||
generation: 1,
|
||||
};
|
||||
|
||||
let candidates = agent.apply_peer_update(&update).unwrap();
|
||||
assert_eq!(
|
||||
candidates.reflexive,
|
||||
Some("203.0.113.5:4433".parse().unwrap())
|
||||
);
|
||||
assert_eq!(candidates.local.len(), 2);
|
||||
assert_eq!(
|
||||
candidates.local[0],
|
||||
"192.168.1.10:4433".parse::<SocketAddr>().unwrap()
|
||||
);
|
||||
assert_eq!(
|
||||
candidates.mapped,
|
||||
Some("198.51.100.42:12345".parse().unwrap())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn apply_peer_update_handles_empty_fields() {
|
||||
let agent = IceAgent::new("test".into(), IceAgentConfig::default());
|
||||
|
||||
let update = SignalMessage::CandidateUpdate {
|
||||
version: default_signal_version(),
|
||||
call_id: "test".into(),
|
||||
reflexive_addr: None,
|
||||
local_addrs: vec![],
|
||||
mapped_addr: None,
|
||||
generation: 1,
|
||||
};
|
||||
|
||||
let candidates = agent.apply_peer_update(&update).unwrap();
|
||||
assert!(candidates.reflexive.is_none());
|
||||
assert!(candidates.local.is_empty());
|
||||
assert!(candidates.mapped.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn apply_peer_update_skips_unparseable_addrs() {
|
||||
let agent = IceAgent::new("test".into(), IceAgentConfig::default());
|
||||
|
||||
let update = SignalMessage::CandidateUpdate {
|
||||
version: default_signal_version(),
|
||||
call_id: "test".into(),
|
||||
reflexive_addr: Some("not-an-addr".into()),
|
||||
local_addrs: vec![
|
||||
"192.168.1.10:4433".into(),
|
||||
"garbage".into(),
|
||||
"10.0.0.5:4433".into(),
|
||||
],
|
||||
mapped_addr: Some("also-bad".into()),
|
||||
generation: 1,
|
||||
};
|
||||
|
||||
let candidates = agent.apply_peer_update(&update).unwrap();
|
||||
assert!(candidates.reflexive.is_none()); // unparseable
|
||||
assert_eq!(candidates.local.len(), 2); // garbage filtered
|
||||
assert!(candidates.mapped.is_none()); // unparseable
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_config_values() {
|
||||
let cfg = IceAgentConfig::default();
|
||||
assert!(cfg.enable_portmap);
|
||||
assert!(cfg.gather_timeout.as_secs() > 0);
|
||||
assert!(!cfg.stun_config.servers.is_empty());
|
||||
assert_eq!(cfg.local_v4_port, 0);
|
||||
assert!(cfg.local_v6_port.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn gather_returns_candidates_even_with_no_stun() {
|
||||
// With default config (port 0 = no portmap, STUN will timeout
|
||||
// quickly on loopback), gather should still return host candidates.
|
||||
let agent = IceAgent::new(
|
||||
"test".into(),
|
||||
IceAgentConfig {
|
||||
stun_config: stun::StunConfig {
|
||||
servers: vec![], // no servers = quick failure
|
||||
timeout: Duration::from_millis(100),
|
||||
},
|
||||
enable_portmap: false,
|
||||
gather_timeout: Duration::from_millis(200),
|
||||
local_v4_port: 12345,
|
||||
local_v6_port: None,
|
||||
},
|
||||
);
|
||||
|
||||
let candidates = agent.gather().await;
|
||||
assert_eq!(candidates.generation, 0);
|
||||
// Reflexive should be None (no STUN servers)
|
||||
assert!(candidates.reflexive.is_none());
|
||||
// Mapped should be None (portmap disabled)
|
||||
assert!(candidates.mapped.is_none());
|
||||
// Local candidates depend on the machine's interfaces
|
||||
// but gather() should not panic.
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn re_gather_produces_signal_message() {
|
||||
let agent = IceAgent::new(
|
||||
"call-42".into(),
|
||||
IceAgentConfig {
|
||||
stun_config: stun::StunConfig {
|
||||
servers: vec![],
|
||||
timeout: Duration::from_millis(50),
|
||||
},
|
||||
enable_portmap: false,
|
||||
gather_timeout: Duration::from_millis(100),
|
||||
local_v4_port: 4433,
|
||||
local_v6_port: None,
|
||||
},
|
||||
);
|
||||
|
||||
let (candidates, signal) = agent.re_gather().await;
|
||||
assert_eq!(candidates.generation, 0);
|
||||
|
||||
match signal {
|
||||
SignalMessage::CandidateUpdate {
|
||||
call_id,
|
||||
generation,
|
||||
..
|
||||
} => {
|
||||
assert_eq!(call_id, "call-42");
|
||||
assert_eq!(generation, 0);
|
||||
}
|
||||
_ => panic!("expected CandidateUpdate"),
|
||||
}
|
||||
|
||||
// Second re_gather increments generation
|
||||
let (candidates2, signal2) = agent.re_gather().await;
|
||||
assert_eq!(candidates2.generation, 1);
|
||||
match signal2 {
|
||||
SignalMessage::CandidateUpdate { generation, .. } => {
|
||||
assert_eq!(generation, 1);
|
||||
}
|
||||
_ => panic!("expected CandidateUpdate"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -27,14 +27,21 @@ pub mod audio_wasapi;
|
||||
#[cfg(all(feature = "linux-aec", target_os = "linux"))]
|
||||
pub mod audio_linux_aec;
|
||||
pub mod bench;
|
||||
pub mod birthday;
|
||||
pub mod call;
|
||||
pub mod encrypted_transport;
|
||||
pub mod drift_test;
|
||||
pub mod dual_path;
|
||||
pub mod echo_test;
|
||||
pub mod featherchat;
|
||||
pub mod handshake;
|
||||
pub mod dual_path;
|
||||
pub mod ice_agent;
|
||||
pub mod metrics;
|
||||
pub mod netcheck;
|
||||
pub mod portmap;
|
||||
pub mod reflect;
|
||||
pub mod relay_map;
|
||||
pub mod stun;
|
||||
pub mod sweep;
|
||||
|
||||
// AudioPlayback: three possible backends depending on feature flags.
|
||||
|
||||
@@ -178,7 +178,10 @@ mod tests {
|
||||
|
||||
// Immediate second write should be skipped (60s interval).
|
||||
let second = writer.maybe_write(&snap).unwrap();
|
||||
assert!(!second, "second write should be skipped — interval not elapsed");
|
||||
assert!(
|
||||
!second,
|
||||
"second write should be skipped — interval not elapsed"
|
||||
);
|
||||
|
||||
// Clean up.
|
||||
let _ = std::fs::remove_file(&path);
|
||||
|
||||
537
crates/wzp-client/src/netcheck.rs
Normal file
537
crates/wzp-client/src/netcheck.rs
Normal file
@@ -0,0 +1,537 @@
|
||||
//! Phase 8 (Tailscale-inspired): Comprehensive network diagnostic.
|
||||
//!
|
||||
//! Probes STUN servers, relay infrastructure, port mapping
|
||||
//! capabilities, IPv6 reachability, and NAT hairpinning in parallel
|
||||
//! to produce a `NetcheckReport` that captures the client's network
|
||||
//! environment at a point in time.
|
||||
//!
|
||||
//! Used for:
|
||||
//! - Troubleshooting connectivity issues
|
||||
//! - Automatic relay selection (Phase 5)
|
||||
//! - Pre-call NAT assessment
|
||||
//! - Quality prediction
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use serde::Serialize;
|
||||
|
||||
use crate::portmap::{self, PortMapProtocol};
|
||||
use crate::reflect::{self, NatType};
|
||||
use crate::stun::{self, StunConfig};
|
||||
|
||||
/// Complete network diagnostic report.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct NetcheckReport {
|
||||
/// NAT type classification (from combined STUN + relay probes).
|
||||
pub nat_type: NatType,
|
||||
/// Server-reflexive address (consensus from probes).
|
||||
pub reflexive_addr: Option<String>,
|
||||
/// Whether IPv4 connectivity is available.
|
||||
pub ipv4_reachable: bool,
|
||||
/// Whether IPv6 connectivity is available.
|
||||
pub ipv6_reachable: bool,
|
||||
/// Whether the NAT supports hairpinning (loopback to own
|
||||
/// reflexive address).
|
||||
pub hairpin_works: Option<bool>,
|
||||
/// Which port mapping protocol is available (if any).
|
||||
pub port_mapping: Option<PortMapProtocol>,
|
||||
/// Per-relay latency measurements.
|
||||
pub relay_latencies: Vec<RelayLatency>,
|
||||
/// Preferred relay (lowest latency).
|
||||
pub preferred_relay: Option<String>,
|
||||
/// STUN latency to first responding server (ms).
|
||||
pub stun_latency_ms: Option<u32>,
|
||||
/// Whether UPnP is available on the gateway.
|
||||
pub upnp_available: bool,
|
||||
/// Whether PCP is available on the gateway.
|
||||
pub pcp_available: bool,
|
||||
/// Whether NAT-PMP is available on the gateway.
|
||||
pub nat_pmp_available: bool,
|
||||
/// Default gateway address.
|
||||
pub gateway: Option<String>,
|
||||
/// Total time taken for the diagnostic (ms).
|
||||
pub duration_ms: u32,
|
||||
/// Individual STUN probe results.
|
||||
pub stun_probes: Vec<reflect::NatProbeResult>,
|
||||
/// NAT port allocation pattern (sequential vs random).
|
||||
pub port_allocation: Option<stun::PortAllocation>,
|
||||
}
|
||||
|
||||
/// Latency to a specific relay.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct RelayLatency {
|
||||
pub name: String,
|
||||
pub addr: String,
|
||||
pub rtt_ms: Option<u32>,
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
/// Configuration for the netcheck run.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct NetcheckConfig {
|
||||
/// STUN servers to probe.
|
||||
pub stun_config: StunConfig,
|
||||
/// Relay servers to probe (name, address pairs).
|
||||
pub relays: Vec<(String, SocketAddr)>,
|
||||
/// Per-probe timeout.
|
||||
pub timeout: Duration,
|
||||
/// Whether to test port mapping.
|
||||
pub test_portmap: bool,
|
||||
/// Whether to test IPv6.
|
||||
pub test_ipv6: bool,
|
||||
/// Local port for port mapping test (0 = skip).
|
||||
pub local_port: u16,
|
||||
}
|
||||
|
||||
impl Default for NetcheckConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
stun_config: StunConfig::default(),
|
||||
relays: Vec::new(),
|
||||
timeout: Duration::from_secs(5),
|
||||
test_portmap: true,
|
||||
test_ipv6: true,
|
||||
local_port: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Run a comprehensive network diagnostic.
|
||||
///
|
||||
/// Probes run in parallel for speed — the total time is bounded
|
||||
/// by the slowest individual probe, not the sum.
|
||||
pub async fn run_netcheck(config: &NetcheckConfig) -> NetcheckReport {
|
||||
let start = Instant::now();
|
||||
|
||||
// Run all probes in parallel.
|
||||
let stun_fut = stun::probe_stun_servers(&config.stun_config);
|
||||
let relay_fut = probe_relays(&config.relays, config.timeout);
|
||||
let portmap_fut = probe_portmap(config.test_portmap, config.local_port);
|
||||
let gateway_fut = portmap::default_gateway();
|
||||
let ipv6_fut = test_ipv6(config.test_ipv6, config.timeout);
|
||||
let port_alloc_fut = stun::detect_port_allocation(&config.stun_config);
|
||||
|
||||
let (
|
||||
stun_probes,
|
||||
relay_latencies,
|
||||
portmap_result,
|
||||
gateway_result,
|
||||
ipv6_reachable,
|
||||
port_alloc_result,
|
||||
) = tokio::join!(
|
||||
stun_fut,
|
||||
relay_fut,
|
||||
portmap_fut,
|
||||
gateway_result_fut(gateway_fut),
|
||||
ipv6_fut,
|
||||
port_alloc_fut
|
||||
);
|
||||
|
||||
// Classify NAT from STUN probes.
|
||||
let (nat_type, consensus_addr) = reflect::classify_nat(&stun_probes);
|
||||
|
||||
// Determine STUN latency (first successful probe).
|
||||
let stun_latency_ms = stun_probes.iter().filter_map(|p| p.latency_ms).min();
|
||||
|
||||
// IPv4 reachable if any STUN probe succeeded.
|
||||
let ipv4_reachable = stun_probes.iter().any(|p| p.observed_addr.is_some());
|
||||
|
||||
// Preferred relay = lowest RTT.
|
||||
let preferred_relay = relay_latencies
|
||||
.iter()
|
||||
.filter_map(|r| r.rtt_ms.map(|rtt| (r.name.clone(), rtt)))
|
||||
.min_by_key(|(_, rtt)| *rtt)
|
||||
.map(|(name, _)| name);
|
||||
|
||||
// Port mapping availability.
|
||||
let (port_mapping, nat_pmp_available, pcp_available, upnp_available) = match portmap_result {
|
||||
Some(mapping) => {
|
||||
let proto = mapping.protocol;
|
||||
(
|
||||
Some(proto),
|
||||
proto == PortMapProtocol::NatPmp,
|
||||
proto == PortMapProtocol::Pcp,
|
||||
proto == PortMapProtocol::UPnP,
|
||||
)
|
||||
}
|
||||
None => (None, false, false, false),
|
||||
};
|
||||
|
||||
let gateway = match gateway_result {
|
||||
Ok(gw) => Some(gw.to_string()),
|
||||
Err(_) => None,
|
||||
};
|
||||
|
||||
NetcheckReport {
|
||||
nat_type,
|
||||
reflexive_addr: consensus_addr,
|
||||
ipv4_reachable,
|
||||
ipv6_reachable,
|
||||
hairpin_works: None, // TODO: implement hairpin test
|
||||
port_mapping,
|
||||
relay_latencies,
|
||||
preferred_relay,
|
||||
stun_latency_ms,
|
||||
upnp_available,
|
||||
pcp_available,
|
||||
nat_pmp_available,
|
||||
gateway,
|
||||
duration_ms: start.elapsed().as_millis() as u32,
|
||||
stun_probes,
|
||||
port_allocation: Some(port_alloc_result.allocation),
|
||||
}
|
||||
}
|
||||
|
||||
/// Probe relay latencies via reflect.
|
||||
async fn probe_relays(relays: &[(String, SocketAddr)], timeout: Duration) -> Vec<RelayLatency> {
|
||||
if relays.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let timeout_ms = timeout.as_millis() as u64;
|
||||
let mut set = tokio::task::JoinSet::new();
|
||||
|
||||
for (name, addr) in relays {
|
||||
let name = name.clone();
|
||||
let addr = *addr;
|
||||
set.spawn(async move {
|
||||
let start = Instant::now();
|
||||
match reflect::probe_reflect_addr(addr, timeout_ms, None).await {
|
||||
Ok((_observed, _latency)) => RelayLatency {
|
||||
name,
|
||||
addr: addr.to_string(),
|
||||
rtt_ms: Some(start.elapsed().as_millis() as u32),
|
||||
error: None,
|
||||
},
|
||||
Err(e) => RelayLatency {
|
||||
name,
|
||||
addr: addr.to_string(),
|
||||
rtt_ms: None,
|
||||
error: Some(e),
|
||||
},
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let mut results = Vec::with_capacity(relays.len());
|
||||
while let Some(join_result) = set.join_next().await {
|
||||
match join_result {
|
||||
Ok(r) => results.push(r),
|
||||
Err(_) => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by RTT (lowest first).
|
||||
results.sort_by_key(|r| r.rtt_ms.unwrap_or(u32::MAX));
|
||||
results
|
||||
}
|
||||
|
||||
/// Attempt port mapping and return the mapping if successful.
|
||||
async fn probe_portmap(enabled: bool, local_port: u16) -> Option<portmap::PortMapping> {
|
||||
if !enabled || local_port == 0 {
|
||||
return None;
|
||||
}
|
||||
portmap::acquire_port_mapping(local_port, None).await.ok()
|
||||
}
|
||||
|
||||
/// Wrap the gateway future to handle the Result.
|
||||
async fn gateway_result_fut(
|
||||
fut: impl std::future::Future<Output = Result<std::net::Ipv4Addr, portmap::PortMapError>>,
|
||||
) -> Result<std::net::Ipv4Addr, portmap::PortMapError> {
|
||||
fut.await
|
||||
}
|
||||
|
||||
/// Test IPv6 connectivity by attempting to bind and send on an IPv6 socket.
|
||||
async fn test_ipv6(enabled: bool, timeout: Duration) -> bool {
|
||||
if !enabled {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Try to resolve and connect to an IPv6 STUN server.
|
||||
let result = tokio::time::timeout(timeout, async {
|
||||
let sock = tokio::net::UdpSocket::bind("[::]:0").await.ok()?;
|
||||
// Try Google's IPv6 STUN — if DNS resolves to an AAAA record
|
||||
// and we can send a packet, IPv6 is working.
|
||||
let addr = stun::resolve_stun_server("stun.l.google.com:19302")
|
||||
.await
|
||||
.ok()?;
|
||||
if addr.is_ipv6() {
|
||||
sock.send_to(&[0u8; 1], addr).await.ok()?;
|
||||
Some(true)
|
||||
} else {
|
||||
// Server resolved to IPv4 — try binding to [::] at least
|
||||
Some(false)
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(Some(true)) => true,
|
||||
_ => {
|
||||
// Fallback: can we at least bind an IPv6 socket?
|
||||
tokio::net::UdpSocket::bind("[::]:0").await.is_ok()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Format a netcheck report as a human-readable string.
|
||||
pub fn format_report(report: &NetcheckReport) -> String {
|
||||
let mut out = String::new();
|
||||
|
||||
out.push_str(&format!("=== WarzonePhone Netcheck ===\n\n"));
|
||||
out.push_str(&format!("NAT Type: {:?}\n", report.nat_type));
|
||||
out.push_str(&format!(
|
||||
"Reflexive Addr: {}\n",
|
||||
report.reflexive_addr.as_deref().unwrap_or("(unknown)")
|
||||
));
|
||||
out.push_str(&format!(
|
||||
"IPv4: {}\n",
|
||||
if report.ipv4_reachable { "yes" } else { "no" }
|
||||
));
|
||||
out.push_str(&format!(
|
||||
"IPv6: {}\n",
|
||||
if report.ipv6_reachable { "yes" } else { "no" }
|
||||
));
|
||||
out.push_str(&format!(
|
||||
"Gateway: {}\n",
|
||||
report.gateway.as_deref().unwrap_or("(unknown)")
|
||||
));
|
||||
|
||||
if let Some(ref alloc) = report.port_allocation {
|
||||
out.push_str(&format!("Port Alloc: {alloc}\n"));
|
||||
}
|
||||
|
||||
out.push_str(&format!("\n--- Port Mapping ---\n"));
|
||||
out.push_str(&format!(
|
||||
"NAT-PMP: {} PCP: {} UPnP: {}\n",
|
||||
if report.nat_pmp_available {
|
||||
"yes"
|
||||
} else {
|
||||
"no"
|
||||
},
|
||||
if report.pcp_available { "yes" } else { "no" },
|
||||
if report.upnp_available { "yes" } else { "no" },
|
||||
));
|
||||
if let Some(proto) = &report.port_mapping {
|
||||
out.push_str(&format!("Active mapping: {:?}\n", proto));
|
||||
}
|
||||
|
||||
if !report.stun_probes.is_empty() {
|
||||
out.push_str(&format!("\n--- STUN Probes ---\n"));
|
||||
for p in &report.stun_probes {
|
||||
out.push_str(&format!(
|
||||
" {} → {} ({}ms){}\n",
|
||||
p.relay_name,
|
||||
p.observed_addr.as_deref().unwrap_or("failed"),
|
||||
p.latency_ms
|
||||
.map(|ms| ms.to_string())
|
||||
.unwrap_or_else(|| "-".into()),
|
||||
p.error
|
||||
.as_ref()
|
||||
.map(|e| format!(" [{e}]"))
|
||||
.unwrap_or_default(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
if !report.relay_latencies.is_empty() {
|
||||
out.push_str(&format!("\n--- Relay Latencies ---\n"));
|
||||
for r in &report.relay_latencies {
|
||||
out.push_str(&format!(
|
||||
" {} ({}) → {}ms{}\n",
|
||||
r.name,
|
||||
r.addr,
|
||||
r.rtt_ms
|
||||
.map(|ms| ms.to_string())
|
||||
.unwrap_or_else(|| "-".into()),
|
||||
r.error
|
||||
.as_ref()
|
||||
.map(|e| format!(" [{e}]"))
|
||||
.unwrap_or_default(),
|
||||
));
|
||||
}
|
||||
if let Some(ref pref) = report.preferred_relay {
|
||||
out.push_str(&format!(" Preferred: {pref}\n"));
|
||||
}
|
||||
}
|
||||
|
||||
out.push_str(&format!("\nCompleted in {}ms\n", report.duration_ms));
|
||||
out
|
||||
}
|
||||
|
||||
// ── Tests ──────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn default_config_has_stun_servers() {
|
||||
let config = NetcheckConfig::default();
|
||||
assert!(!config.stun_config.servers.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn format_report_produces_output() {
|
||||
let report = NetcheckReport {
|
||||
nat_type: NatType::Cone,
|
||||
reflexive_addr: Some("203.0.113.5:4433".into()),
|
||||
ipv4_reachable: true,
|
||||
ipv6_reachable: false,
|
||||
hairpin_works: None,
|
||||
port_mapping: None,
|
||||
relay_latencies: vec![RelayLatency {
|
||||
name: "relay-1".into(),
|
||||
addr: "10.0.0.1:4433".into(),
|
||||
rtt_ms: Some(25),
|
||||
error: None,
|
||||
}],
|
||||
preferred_relay: Some("relay-1".into()),
|
||||
stun_latency_ms: Some(15),
|
||||
upnp_available: false,
|
||||
pcp_available: false,
|
||||
nat_pmp_available: false,
|
||||
gateway: Some("192.168.1.1".into()),
|
||||
duration_ms: 1500,
|
||||
stun_probes: vec![],
|
||||
port_allocation: None,
|
||||
};
|
||||
|
||||
let text = format_report(&report);
|
||||
assert!(text.contains("Cone"));
|
||||
assert!(text.contains("203.0.113.5:4433"));
|
||||
assert!(text.contains("relay-1"));
|
||||
assert!(text.contains("1500ms"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn report_serializes_to_json() {
|
||||
let report = NetcheckReport {
|
||||
nat_type: NatType::Cone,
|
||||
reflexive_addr: Some("203.0.113.5:4433".into()),
|
||||
ipv4_reachable: true,
|
||||
ipv6_reachable: false,
|
||||
hairpin_works: None,
|
||||
port_mapping: Some(PortMapProtocol::NatPmp),
|
||||
relay_latencies: vec![],
|
||||
preferred_relay: None,
|
||||
stun_latency_ms: Some(25),
|
||||
upnp_available: false,
|
||||
pcp_available: false,
|
||||
nat_pmp_available: true,
|
||||
gateway: Some("192.168.1.1".into()),
|
||||
duration_ms: 500,
|
||||
stun_probes: vec![],
|
||||
port_allocation: Some(stun::PortAllocation::Sequential { delta: 1 }),
|
||||
};
|
||||
let json = serde_json::to_string(&report).unwrap();
|
||||
assert!(json.contains("Cone"));
|
||||
assert!(json.contains("203.0.113.5:4433"));
|
||||
assert!(json.contains("NatPmp"));
|
||||
|
||||
// Roundtrip
|
||||
let decoded: serde_json::Value = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(decoded["ipv4_reachable"], true);
|
||||
assert_eq!(decoded["ipv6_reachable"], false);
|
||||
assert_eq!(decoded["stun_latency_ms"], 25);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn relay_latency_serializes() {
|
||||
let lat = RelayLatency {
|
||||
name: "eu-west".into(),
|
||||
addr: "10.0.0.1:4433".into(),
|
||||
rtt_ms: Some(42),
|
||||
error: None,
|
||||
};
|
||||
let json = serde_json::to_string(&lat).unwrap();
|
||||
assert!(json.contains("eu-west"));
|
||||
assert!(json.contains("42"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn format_report_empty_relays() {
|
||||
let report = NetcheckReport {
|
||||
nat_type: NatType::Unknown,
|
||||
reflexive_addr: None,
|
||||
ipv4_reachable: false,
|
||||
ipv6_reachable: false,
|
||||
hairpin_works: None,
|
||||
port_mapping: None,
|
||||
relay_latencies: vec![],
|
||||
preferred_relay: None,
|
||||
stun_latency_ms: None,
|
||||
upnp_available: false,
|
||||
pcp_available: false,
|
||||
nat_pmp_available: false,
|
||||
gateway: None,
|
||||
duration_ms: 100,
|
||||
stun_probes: vec![],
|
||||
port_allocation: None,
|
||||
};
|
||||
let text = format_report(&report);
|
||||
assert!(text.contains("Unknown"));
|
||||
assert!(text.contains("(unknown)")); // reflexive addr
|
||||
assert!(text.contains("100ms"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn format_report_with_stun_probes() {
|
||||
let report = NetcheckReport {
|
||||
nat_type: NatType::SymmetricPort,
|
||||
reflexive_addr: None,
|
||||
ipv4_reachable: true,
|
||||
ipv6_reachable: true,
|
||||
hairpin_works: Some(false),
|
||||
port_mapping: Some(PortMapProtocol::UPnP),
|
||||
relay_latencies: vec![
|
||||
RelayLatency {
|
||||
name: "us-east".into(),
|
||||
addr: "10.0.0.1:4433".into(),
|
||||
rtt_ms: Some(15),
|
||||
error: None,
|
||||
},
|
||||
RelayLatency {
|
||||
name: "eu-west".into(),
|
||||
addr: "10.0.0.2:4433".into(),
|
||||
rtt_ms: None,
|
||||
error: Some("timeout".into()),
|
||||
},
|
||||
],
|
||||
preferred_relay: Some("us-east".into()),
|
||||
stun_latency_ms: Some(20),
|
||||
upnp_available: true,
|
||||
pcp_available: false,
|
||||
nat_pmp_available: false,
|
||||
gateway: Some("192.168.0.1".into()),
|
||||
duration_ms: 3000,
|
||||
stun_probes: vec![reflect::NatProbeResult {
|
||||
relay_name: "stun:google".into(),
|
||||
relay_addr: "74.125.250.129:19302".into(),
|
||||
observed_addr: Some("203.0.113.5:12345".into()),
|
||||
latency_ms: Some(20),
|
||||
error: None,
|
||||
}],
|
||||
port_allocation: Some(stun::PortAllocation::Random),
|
||||
};
|
||||
let text = format_report(&report);
|
||||
assert!(text.contains("SymmetricPort"));
|
||||
assert!(text.contains("us-east"));
|
||||
assert!(text.contains("eu-west"));
|
||||
assert!(text.contains("Preferred: us-east"));
|
||||
assert!(text.contains("UPnP: yes"));
|
||||
assert!(text.contains("stun:google"));
|
||||
assert!(text.contains("3000ms"));
|
||||
}
|
||||
|
||||
/// Integration test: run actual netcheck (requires network).
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn integration_netcheck() {
|
||||
let config = NetcheckConfig::default();
|
||||
let report = run_netcheck(&config).await;
|
||||
println!("{}", format_report(&report));
|
||||
assert!(report.duration_ms > 0);
|
||||
}
|
||||
}
|
||||
1164
crates/wzp-client/src/portmap.rs
Normal file
1164
crates/wzp-client/src/portmap.rs
Normal file
File diff suppressed because it is too large
Load Diff
@@ -30,8 +30,8 @@ use std::net::SocketAddr;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use serde::Serialize;
|
||||
use wzp_proto::{MediaTransport, SignalMessage};
|
||||
use wzp_transport::{client_config, create_endpoint, QuinnTransport};
|
||||
use wzp_proto::{MediaTransport, SignalMessage, default_signal_version};
|
||||
use wzp_transport::{QuinnTransport, client_config, create_endpoint};
|
||||
|
||||
/// Result of one probe against one relay. Always returned so the
|
||||
/// UI can render per-relay status even when some fail.
|
||||
@@ -110,10 +110,9 @@ pub async fn probe_reflect_addr(
|
||||
let start = Instant::now();
|
||||
let probe = async {
|
||||
// Open the signal connection.
|
||||
let conn =
|
||||
wzp_transport::connect(&endpoint, relay, "_signal", client_config())
|
||||
.await
|
||||
.map_err(|e| format!("connect: {e}"))?;
|
||||
let conn = wzp_transport::connect(&endpoint, relay, "_signal", client_config())
|
||||
.await
|
||||
.map_err(|e| format!("connect: {e}"))?;
|
||||
let transport = QuinnTransport::new(conn);
|
||||
|
||||
// The relay signal handler waits for a RegisterPresence
|
||||
@@ -124,6 +123,7 @@ pub async fn probe_reflect_addr(
|
||||
// path does in desktop/src-tauri/src/lib.rs register_signal.
|
||||
transport
|
||||
.send_signal(&SignalMessage::RegisterPresence {
|
||||
version: default_signal_version(),
|
||||
identity_pub: [0u8; 32],
|
||||
signature: vec![],
|
||||
alias: None,
|
||||
@@ -151,7 +151,7 @@ pub async fn probe_reflect_addr(
|
||||
.map_err(|e| format!("send Reflect: {e}"))?;
|
||||
|
||||
match transport.recv_signal().await {
|
||||
Ok(Some(SignalMessage::ReflectResponse { observed_addr })) => {
|
||||
Ok(Some(SignalMessage::ReflectResponse { observed_addr, .. })) => {
|
||||
let parsed: SocketAddr = observed_addr
|
||||
.parse()
|
||||
.map_err(|e| format!("parse observed_addr {observed_addr:?}: {e}"))?;
|
||||
@@ -473,6 +473,40 @@ pub fn classify_nat(probes: &[NatProbeResult]) -> (NatType, Option<String>) {
|
||||
}
|
||||
}
|
||||
|
||||
/// Enhanced NAT detection that combines relay-based reflection with
|
||||
/// public STUN server probes for more robust classification.
|
||||
///
|
||||
/// Runs both probe sets concurrently:
|
||||
/// 1. Relay probes via `detect_nat_type` (existing behavior)
|
||||
/// 2. Public STUN probes via `probe_stun_servers`
|
||||
///
|
||||
/// Merges all results and classifies. More probes = higher confidence
|
||||
/// in the NAT type classification. Falls back gracefully: if STUN
|
||||
/// servers are unreachable, relay probes still work (and vice versa).
|
||||
pub async fn detect_nat_type_with_stun(
|
||||
relays: Vec<(String, SocketAddr)>,
|
||||
timeout_ms: u64,
|
||||
shared_endpoint: Option<wzp_transport::Endpoint>,
|
||||
stun_config: &crate::stun::StunConfig,
|
||||
) -> NatDetection {
|
||||
// Run relay probes and STUN probes concurrently.
|
||||
let relay_fut = detect_nat_type(relays, timeout_ms, shared_endpoint);
|
||||
let stun_fut = crate::stun::probe_stun_servers(stun_config);
|
||||
|
||||
let (relay_detection, stun_probes) = tokio::join!(relay_fut, stun_fut);
|
||||
|
||||
// Merge all probes and re-classify.
|
||||
let mut all_probes = relay_detection.probes;
|
||||
all_probes.extend(stun_probes);
|
||||
|
||||
let (nat_type, consensus_addr) = classify_nat(&all_probes);
|
||||
NatDetection {
|
||||
probes: all_probes,
|
||||
nat_type,
|
||||
consensus_addr,
|
||||
}
|
||||
}
|
||||
|
||||
// ── Unit tests for the pure classifier ───────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -506,10 +540,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn classify_two_identical_is_cone() {
|
||||
let probes = vec![
|
||||
mk(Some("192.0.2.1:4433")),
|
||||
mk(Some("192.0.2.1:4433")),
|
||||
];
|
||||
let probes = vec![mk(Some("192.0.2.1:4433")), mk(Some("192.0.2.1:4433"))];
|
||||
let (nt, addr) = classify_nat(&probes);
|
||||
assert_eq!(nt, NatType::Cone);
|
||||
assert_eq!(addr.as_deref(), Some("192.0.2.1:4433"));
|
||||
@@ -517,10 +548,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn classify_same_ip_different_ports_is_symmetric() {
|
||||
let probes = vec![
|
||||
mk(Some("192.0.2.1:4433")),
|
||||
mk(Some("192.0.2.1:51234")),
|
||||
];
|
||||
let probes = vec![mk(Some("192.0.2.1:4433")), mk(Some("192.0.2.1:51234"))];
|
||||
let (nt, addr) = classify_nat(&probes);
|
||||
assert_eq!(nt, NatType::SymmetricPort);
|
||||
assert!(addr.is_none());
|
||||
@@ -528,10 +556,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn classify_different_ips_is_multiple() {
|
||||
let probes = vec![
|
||||
mk(Some("192.0.2.1:4433")),
|
||||
mk(Some("198.51.100.9:4433")),
|
||||
];
|
||||
let probes = vec![mk(Some("192.0.2.1:4433")), mk(Some("198.51.100.9:4433"))];
|
||||
let (nt, addr) = classify_nat(&probes);
|
||||
assert_eq!(nt, NatType::Multiple);
|
||||
assert!(addr.is_none());
|
||||
@@ -557,9 +582,9 @@ mod tests {
|
||||
#[test]
|
||||
fn classify_drops_loopback_probes() {
|
||||
let probes = vec![
|
||||
mk(Some("127.0.0.1:4433")), // loopback — must be dropped
|
||||
mk(Some("203.0.113.5:4433")), // public
|
||||
mk(Some("203.0.113.5:4433")), // public, same addr
|
||||
mk(Some("127.0.0.1:4433")), // loopback — must be dropped
|
||||
mk(Some("203.0.113.5:4433")), // public
|
||||
mk(Some("203.0.113.5:4433")), // public, same addr
|
||||
];
|
||||
let (nt, addr) = classify_nat(&probes);
|
||||
// Two public probes with identical addrs → Cone.
|
||||
@@ -574,9 +599,9 @@ mod tests {
|
||||
// client with a 100.64/10 addr is on the same CGNAT
|
||||
// network and can't contribute to public NAT classification.
|
||||
let probes = vec![
|
||||
mk(Some("100.64.0.42:4433")), // CGNAT — dropped
|
||||
mk(Some("203.0.113.5:4433")), // public
|
||||
mk(Some("203.0.113.5:12345")), // public, different port
|
||||
mk(Some("100.64.0.42:4433")), // CGNAT — dropped
|
||||
mk(Some("203.0.113.5:4433")), // public
|
||||
mk(Some("203.0.113.5:12345")), // public, different port
|
||||
];
|
||||
let (nt, _) = classify_nat(&probes);
|
||||
// Two public probes same IP different port → SymmetricPort.
|
||||
|
||||
337
crates/wzp-client/src/relay_map.rs
Normal file
337
crates/wzp-client/src/relay_map.rs
Normal file
@@ -0,0 +1,337 @@
|
||||
//! Phase 8 (Tailscale-inspired): Relay map for automatic relay
|
||||
//! selection based on latency.
|
||||
//!
|
||||
//! Maintains a sorted list of known relays with their measured
|
||||
//! latencies. Used during call setup to pick the lowest-latency
|
||||
//! relay, and by netcheck to report relay health.
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use serde::Serialize;
|
||||
|
||||
/// A known relay endpoint with measured latency.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct RelayEntry {
|
||||
/// Human-readable name (e.g., "us-east", "eu-west").
|
||||
pub name: String,
|
||||
/// Relay address.
|
||||
pub addr: SocketAddr,
|
||||
/// Geographic region (from RegisterPresenceAck).
|
||||
pub region: Option<String>,
|
||||
/// Last measured RTT (ms).
|
||||
pub rtt_ms: Option<u32>,
|
||||
/// When the RTT was last measured.
|
||||
#[serde(skip)]
|
||||
pub last_probed: Option<Instant>,
|
||||
/// Whether this relay is currently reachable.
|
||||
pub reachable: bool,
|
||||
}
|
||||
|
||||
/// Sorted relay map. Entries are ordered by RTT (lowest first).
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct RelayMap {
|
||||
entries: Vec<RelayEntry>,
|
||||
}
|
||||
|
||||
impl RelayMap {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
entries: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add or update a relay entry.
|
||||
pub fn upsert(&mut self, name: &str, addr: SocketAddr, region: Option<String>) {
|
||||
if let Some(entry) = self.entries.iter_mut().find(|e| e.addr == addr) {
|
||||
entry.name = name.to_string();
|
||||
if region.is_some() {
|
||||
entry.region = region;
|
||||
}
|
||||
} else {
|
||||
self.entries.push(RelayEntry {
|
||||
name: name.to_string(),
|
||||
addr,
|
||||
region,
|
||||
rtt_ms: None,
|
||||
last_probed: None,
|
||||
reachable: false,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// Update RTT measurement for a relay.
|
||||
pub fn update_rtt(&mut self, addr: SocketAddr, rtt_ms: u32) {
|
||||
if let Some(entry) = self.entries.iter_mut().find(|e| e.addr == addr) {
|
||||
entry.rtt_ms = Some(rtt_ms);
|
||||
entry.last_probed = Some(Instant::now());
|
||||
entry.reachable = true;
|
||||
}
|
||||
self.sort();
|
||||
}
|
||||
|
||||
/// Mark a relay as unreachable.
|
||||
pub fn mark_unreachable(&mut self, addr: SocketAddr) {
|
||||
if let Some(entry) = self.entries.iter_mut().find(|e| e.addr == addr) {
|
||||
entry.reachable = false;
|
||||
entry.last_probed = Some(Instant::now());
|
||||
}
|
||||
self.sort();
|
||||
}
|
||||
|
||||
/// Get the preferred (lowest-latency, reachable) relay.
|
||||
pub fn preferred(&self) -> Option<&RelayEntry> {
|
||||
self.entries
|
||||
.iter()
|
||||
.find(|e| e.reachable && e.rtt_ms.is_some())
|
||||
}
|
||||
|
||||
/// Get all entries, sorted by RTT.
|
||||
pub fn entries(&self) -> &[RelayEntry] {
|
||||
&self.entries
|
||||
}
|
||||
|
||||
/// Populate from a `RegisterPresenceAck.available_relays` list.
|
||||
/// Each entry is "name|addr" format.
|
||||
pub fn populate_from_ack(&mut self, relays: &[String], relay_region: Option<&str>) {
|
||||
for entry_str in relays {
|
||||
if let Some((name, addr_str)) = entry_str.split_once('|') {
|
||||
if let Ok(addr) = addr_str.parse::<SocketAddr>() {
|
||||
self.upsert(name, addr, None);
|
||||
}
|
||||
}
|
||||
}
|
||||
// If the ack included a region for the current relay, we
|
||||
// could tag it — but we'd need to know which relay we're
|
||||
// connected to. Left for the caller to handle.
|
||||
let _ = relay_region;
|
||||
}
|
||||
|
||||
/// Check if any entry has a stale probe (older than `max_age`).
|
||||
pub fn needs_reprobe(&self, max_age: Duration) -> bool {
|
||||
self.entries.iter().any(|e| match e.last_probed {
|
||||
None => true,
|
||||
Some(t) => t.elapsed() > max_age,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get entries that need reprobing.
|
||||
pub fn stale_entries(&self, max_age: Duration) -> Vec<(String, SocketAddr)> {
|
||||
self.entries
|
||||
.iter()
|
||||
.filter(|e| match e.last_probed {
|
||||
None => true,
|
||||
Some(t) => t.elapsed() > max_age,
|
||||
})
|
||||
.map(|e| (e.name.clone(), e.addr))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn sort(&mut self) {
|
||||
self.entries.sort_by_key(|e| {
|
||||
if e.reachable {
|
||||
e.rtt_ms.unwrap_or(u32::MAX)
|
||||
} else {
|
||||
u32::MAX
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// ── Tests ──────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn preferred_returns_lowest_rtt() {
|
||||
let mut map = RelayMap::new();
|
||||
let a1: SocketAddr = "10.0.0.1:4433".parse().unwrap();
|
||||
let a2: SocketAddr = "10.0.0.2:4433".parse().unwrap();
|
||||
let a3: SocketAddr = "10.0.0.3:4433".parse().unwrap();
|
||||
|
||||
map.upsert("slow", a1, None);
|
||||
map.upsert("fast", a2, None);
|
||||
map.upsert("mid", a3, None);
|
||||
|
||||
map.update_rtt(a1, 200);
|
||||
map.update_rtt(a2, 15);
|
||||
map.update_rtt(a3, 80);
|
||||
|
||||
let pref = map.preferred().unwrap();
|
||||
assert_eq!(pref.addr, a2);
|
||||
assert_eq!(pref.rtt_ms, Some(15));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unreachable_not_preferred() {
|
||||
let mut map = RelayMap::new();
|
||||
let a1: SocketAddr = "10.0.0.1:4433".parse().unwrap();
|
||||
let a2: SocketAddr = "10.0.0.2:4433".parse().unwrap();
|
||||
|
||||
map.upsert("fast-dead", a1, None);
|
||||
map.upsert("slow-alive", a2, None);
|
||||
|
||||
map.update_rtt(a1, 5);
|
||||
map.update_rtt(a2, 200);
|
||||
map.mark_unreachable(a1);
|
||||
|
||||
let pref = map.preferred().unwrap();
|
||||
assert_eq!(pref.addr, a2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn populate_from_ack() {
|
||||
let mut map = RelayMap::new();
|
||||
map.populate_from_ack(
|
||||
&[
|
||||
"us-east|203.0.113.5:4433".into(),
|
||||
"eu-west|198.51.100.9:4433".into(),
|
||||
],
|
||||
Some("us-east"),
|
||||
);
|
||||
assert_eq!(map.entries().len(), 2);
|
||||
assert_eq!(map.entries()[0].name, "us-east");
|
||||
assert_eq!(map.entries()[1].name, "eu-west");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn upsert_updates_existing() {
|
||||
let mut map = RelayMap::new();
|
||||
let addr: SocketAddr = "10.0.0.1:4433".parse().unwrap();
|
||||
map.upsert("old-name", addr, None);
|
||||
map.upsert("new-name", addr, Some("us-west".into()));
|
||||
assert_eq!(map.entries().len(), 1);
|
||||
assert_eq!(map.entries()[0].name, "new-name");
|
||||
assert_eq!(map.entries()[0].region, Some("us-west".into()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn upsert_preserves_region_when_none() {
|
||||
let mut map = RelayMap::new();
|
||||
let addr: SocketAddr = "10.0.0.1:4433".parse().unwrap();
|
||||
map.upsert("relay", addr, Some("eu-west".into()));
|
||||
map.upsert("relay", addr, None); // region is None
|
||||
// Should keep the original region
|
||||
assert_eq!(map.entries()[0].region, Some("eu-west".into()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn preferred_returns_none_on_empty() {
|
||||
let map = RelayMap::new();
|
||||
assert!(map.preferred().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn preferred_returns_none_when_all_unreachable() {
|
||||
let mut map = RelayMap::new();
|
||||
let addr: SocketAddr = "10.0.0.1:4433".parse().unwrap();
|
||||
map.upsert("relay", addr, None);
|
||||
// Not update_rtt'd, so reachable=false
|
||||
assert!(map.preferred().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn needs_reprobe_empty_is_false() {
|
||||
let map = RelayMap::new();
|
||||
// No entries → nothing to reprobe
|
||||
assert!(!map.needs_reprobe(Duration::from_secs(60)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn needs_reprobe_never_probed() {
|
||||
let mut map = RelayMap::new();
|
||||
map.upsert("relay", "10.0.0.1:4433".parse().unwrap(), None);
|
||||
assert!(map.needs_reprobe(Duration::from_secs(60)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn needs_reprobe_fresh_is_false() {
|
||||
let mut map = RelayMap::new();
|
||||
let addr: SocketAddr = "10.0.0.1:4433".parse().unwrap();
|
||||
map.upsert("relay", addr, None);
|
||||
map.update_rtt(addr, 50);
|
||||
// Just probed, so 60s max_age should not trigger
|
||||
assert!(!map.needs_reprobe(Duration::from_secs(60)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stale_entries_returns_unprobed() {
|
||||
let mut map = RelayMap::new();
|
||||
let a1: SocketAddr = "10.0.0.1:4433".parse().unwrap();
|
||||
let a2: SocketAddr = "10.0.0.2:4433".parse().unwrap();
|
||||
map.upsert("probed", a1, None);
|
||||
map.upsert("stale", a2, None);
|
||||
map.update_rtt(a1, 50);
|
||||
|
||||
let stale = map.stale_entries(Duration::from_secs(60));
|
||||
assert_eq!(stale.len(), 1);
|
||||
assert_eq!(stale[0].1, a2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sort_stability_with_equal_rtt() {
|
||||
let mut map = RelayMap::new();
|
||||
let a1: SocketAddr = "10.0.0.1:4433".parse().unwrap();
|
||||
let a2: SocketAddr = "10.0.0.2:4433".parse().unwrap();
|
||||
map.upsert("first", a1, None);
|
||||
map.upsert("second", a2, None);
|
||||
map.update_rtt(a1, 50);
|
||||
map.update_rtt(a2, 50);
|
||||
|
||||
// Both have same RTT — sort should be stable (insertion order)
|
||||
assert_eq!(map.entries().len(), 2);
|
||||
// Both are valid preferred relays
|
||||
assert!(map.preferred().is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn populate_from_ack_skips_malformed() {
|
||||
let mut map = RelayMap::new();
|
||||
map.populate_from_ack(
|
||||
&[
|
||||
"good|10.0.0.1:4433".into(),
|
||||
"no-pipe-separator".into(),
|
||||
"bad-addr|not-a-socket-addr".into(),
|
||||
"also-good|10.0.0.2:4433".into(),
|
||||
],
|
||||
None,
|
||||
);
|
||||
assert_eq!(map.entries().len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mark_unreachable_sorts_to_end() {
|
||||
let mut map = RelayMap::new();
|
||||
let a1: SocketAddr = "10.0.0.1:4433".parse().unwrap();
|
||||
let a2: SocketAddr = "10.0.0.2:4433".parse().unwrap();
|
||||
map.upsert("fast", a1, None);
|
||||
map.upsert("slow", a2, None);
|
||||
map.update_rtt(a1, 10);
|
||||
map.update_rtt(a2, 200);
|
||||
|
||||
assert_eq!(map.preferred().unwrap().addr, a1);
|
||||
|
||||
map.mark_unreachable(a1);
|
||||
assert_eq!(map.preferred().unwrap().addr, a2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn relay_entry_serializes() {
|
||||
let entry = RelayEntry {
|
||||
name: "test".into(),
|
||||
addr: "10.0.0.1:4433".parse().unwrap(),
|
||||
region: Some("us-east".into()),
|
||||
rtt_ms: Some(42),
|
||||
last_probed: Some(Instant::now()),
|
||||
reachable: true,
|
||||
};
|
||||
let json = serde_json::to_string(&entry).unwrap();
|
||||
assert!(json.contains("test"));
|
||||
assert!(json.contains("us-east"));
|
||||
assert!(json.contains("42"));
|
||||
// last_probed is #[serde(skip)]
|
||||
assert!(!json.contains("last_probed"));
|
||||
}
|
||||
}
|
||||
1445
crates/wzp-client/src/stun.rs
Normal file
1445
crates/wzp-client/src/stun.rs
Normal file
File diff suppressed because it is too large
Load Diff
@@ -72,8 +72,7 @@ fn sine_frame(freq_hz: f32, frame_offset: u64) -> Vec<i16> {
|
||||
/// decoder, pushes frames through the pipeline, and collects statistics.
|
||||
/// Combinations where `target_depth > max_depth` are skipped.
|
||||
pub fn run_local_sweep(config: &SweepConfig) -> Vec<SweepResult> {
|
||||
let frames_per_config =
|
||||
(config.test_duration_secs as u64) * (1000 / FRAME_DURATION_MS as u64);
|
||||
let frames_per_config = (config.test_duration_secs as u64) * (1000 / FRAME_DURATION_MS as u64);
|
||||
|
||||
let mut results = Vec::new();
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
use std::net::{Ipv4Addr, SocketAddr};
|
||||
use std::time::Duration;
|
||||
|
||||
use wzp_client::dual_path::{race, PeerCandidates, WinningPath};
|
||||
use wzp_client::dual_path::{PeerCandidates, WinningPath, race};
|
||||
use wzp_client::reflect::Role;
|
||||
use wzp_transport::{create_endpoint, server_config};
|
||||
|
||||
@@ -113,17 +113,27 @@ async fn dual_path_direct_wins_on_loopback() {
|
||||
PeerCandidates {
|
||||
reflexive: Some(acceptor_listen_addr),
|
||||
local: Vec::new(),
|
||||
mapped: None,
|
||||
},
|
||||
relay_addr,
|
||||
"test-room".into(),
|
||||
"call-test".into(),
|
||||
None, // own_reflexive: not needed in tests
|
||||
None, // Phase 5: tests use fresh endpoints (no shared signal)
|
||||
None, // Phase 7: no IPv6 endpoint in tests
|
||||
)
|
||||
.await
|
||||
.expect("race must succeed");
|
||||
|
||||
assert!(result.direct_transport.is_some(), "direct transport should be available");
|
||||
assert_eq!(result.local_winner, WinningPath::Direct, "direct should win on loopback");
|
||||
assert!(
|
||||
result.direct_transport.is_some(),
|
||||
"direct transport should be available"
|
||||
);
|
||||
assert_eq!(
|
||||
result.local_winner,
|
||||
WinningPath::Direct,
|
||||
"direct should win on loopback"
|
||||
);
|
||||
|
||||
// Cancel the acceptor accept task so the test finishes.
|
||||
acceptor_accept_task.abort();
|
||||
@@ -155,16 +165,22 @@ async fn dual_path_relay_wins_when_direct_is_dead() {
|
||||
PeerCandidates {
|
||||
reflexive: Some(dead_peer),
|
||||
local: Vec::new(),
|
||||
mapped: None,
|
||||
},
|
||||
relay_addr,
|
||||
"test-room".into(),
|
||||
"call-test".into(),
|
||||
None, // own_reflexive: not needed in tests
|
||||
None, // Phase 5: tests use fresh endpoints (no shared signal)
|
||||
None, // Phase 7: no IPv6 endpoint in tests
|
||||
)
|
||||
.await
|
||||
.expect("race must succeed via relay fallback");
|
||||
|
||||
assert!(result.relay_transport.is_some(), "relay transport should be available");
|
||||
assert!(
|
||||
result.relay_transport.is_some(),
|
||||
"relay transport should be available"
|
||||
);
|
||||
assert_eq!(
|
||||
result.local_winner,
|
||||
WinningPath::Relay,
|
||||
@@ -193,11 +209,14 @@ async fn dual_path_errors_cleanly_when_both_paths_dead() {
|
||||
PeerCandidates {
|
||||
reflexive: Some(dead_peer),
|
||||
local: Vec::new(),
|
||||
mapped: None,
|
||||
},
|
||||
dead_relay,
|
||||
"test-room".into(),
|
||||
"call-test".into(),
|
||||
None, // own_reflexive: not needed in tests
|
||||
None, // Phase 5: tests use fresh endpoints (no shared signal)
|
||||
None, // Phase 7: no IPv6 endpoint in tests
|
||||
)
|
||||
.await;
|
||||
let elapsed = start.elapsed();
|
||||
|
||||
@@ -6,12 +6,12 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use wzp_proto::packet::MediaPacket;
|
||||
use wzp_proto::traits::{MediaTransport, PathQuality};
|
||||
use wzp_proto::{SignalMessage, TransportError};
|
||||
use wzp_proto::{SignalMessage, TransportError, default_signal_version};
|
||||
|
||||
/// A mock transport backed by two mpsc channels (one per direction).
|
||||
///
|
||||
@@ -83,11 +83,15 @@ async fn full_handshake_both_sides_derive_same_session() {
|
||||
|
||||
// Run client and relay handshakes concurrently.
|
||||
let (client_result, relay_result) = tokio::join!(
|
||||
wzp_client::handshake::perform_handshake(client_transport_clone.as_ref(), &client_seed, None),
|
||||
wzp_client::handshake::perform_handshake(
|
||||
client_transport_clone.as_ref(),
|
||||
&client_seed,
|
||||
None
|
||||
),
|
||||
wzp_relay::handshake::accept_handshake(relay_transport_clone.as_ref(), &relay_seed),
|
||||
);
|
||||
|
||||
let mut client_session = client_result.expect("client handshake should succeed");
|
||||
let client_hs = client_result.expect("client handshake should succeed");
|
||||
let (mut relay_session, chosen_profile, _caller_fp, _caller_alias) =
|
||||
relay_result.expect("relay handshake should succeed");
|
||||
|
||||
@@ -95,31 +99,53 @@ async fn full_handshake_both_sides_derive_same_session() {
|
||||
assert_eq!(chosen_profile, wzp_proto::QualityProfile::GOOD);
|
||||
|
||||
// Verify both sides can communicate: client encrypts, relay decrypts.
|
||||
let header = b"test-header";
|
||||
// encrypt/decrypt derive nonces from MediaHeader.seq, so we need valid headers.
|
||||
use wzp_proto::packet::MediaHeader;
|
||||
use wzp_proto::{CodecId, MediaType};
|
||||
let make_hdr = |seq: u32| {
|
||||
let h = MediaHeader {
|
||||
version: 2,
|
||||
flags: 0,
|
||||
media_type: MediaType::Audio,
|
||||
codec_id: CodecId::Opus24k,
|
||||
stream_id: 0,
|
||||
fec_ratio: 0,
|
||||
seq,
|
||||
timestamp: seq.wrapping_mul(20),
|
||||
fec_block: 0,
|
||||
};
|
||||
let mut b = Vec::new();
|
||||
h.write_to(&mut b);
|
||||
b
|
||||
};
|
||||
|
||||
let header = make_hdr(0);
|
||||
let plaintext = b"hello from client to relay";
|
||||
|
||||
let mut client_session = client_hs.session;
|
||||
let mut ciphertext = Vec::new();
|
||||
client_session
|
||||
.encrypt(header, plaintext, &mut ciphertext)
|
||||
.encrypt(&header, plaintext, &mut ciphertext)
|
||||
.expect("client encrypt should succeed");
|
||||
|
||||
let mut decrypted = Vec::new();
|
||||
relay_session
|
||||
.decrypt(header, &ciphertext, &mut decrypted)
|
||||
.decrypt(&header, &ciphertext, &mut decrypted)
|
||||
.expect("relay decrypt should succeed");
|
||||
|
||||
assert_eq!(&decrypted[..], plaintext);
|
||||
|
||||
// Verify reverse direction: relay encrypts, client decrypts.
|
||||
let header2 = make_hdr(0); // relay's send_seq starts at 0
|
||||
let plaintext2 = b"hello from relay to client";
|
||||
let mut ciphertext2 = Vec::new();
|
||||
relay_session
|
||||
.encrypt(header, plaintext2, &mut ciphertext2)
|
||||
.encrypt(&header2, plaintext2, &mut ciphertext2)
|
||||
.expect("relay encrypt should succeed");
|
||||
|
||||
let mut decrypted2 = Vec::new();
|
||||
client_session
|
||||
.decrypt(header, &ciphertext2, &mut decrypted2)
|
||||
.decrypt(&header2, &ciphertext2, &mut decrypted2)
|
||||
.expect("client decrypt should succeed");
|
||||
|
||||
assert_eq!(&decrypted2[..], plaintext2);
|
||||
@@ -147,11 +173,15 @@ async fn handshake_rejects_tampered_signature() {
|
||||
let bad_signature = kx.sign(b"wrong-data-intentionally");
|
||||
|
||||
let offer = SignalMessage::CallOffer {
|
||||
version: default_signal_version(),
|
||||
identity_pub,
|
||||
ephemeral_pub,
|
||||
signature: bad_signature,
|
||||
supported_profiles: vec![wzp_proto::QualityProfile::GOOD],
|
||||
alias: None,
|
||||
protocol_version: 2,
|
||||
supported_versions: vec![2],
|
||||
video_codecs: vec![],
|
||||
};
|
||||
client_transport_clone
|
||||
.send_signal(&offer)
|
||||
@@ -175,3 +205,42 @@ async fn handshake_rejects_tampered_signature() {
|
||||
Ok(_) => panic!("relay should reject tampered signature"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn client_receives_protocol_version_mismatch() {
|
||||
let (client_transport, relay_transport) = MockTransport::pair();
|
||||
|
||||
let client_seed = [0xAA_u8; 32];
|
||||
|
||||
// Spawn a fake relay that sends ProtocolVersionMismatch.
|
||||
let relay_clone = Arc::clone(&relay_transport);
|
||||
tokio::spawn(async move {
|
||||
// Wait for the client's CallOffer.
|
||||
let offer = relay_clone.recv_signal().await.unwrap().unwrap();
|
||||
assert!(matches!(offer, SignalMessage::CallOffer { .. }));
|
||||
|
||||
// Respond with ProtocolVersionMismatch.
|
||||
let mismatch = SignalMessage::Hangup {
|
||||
version: default_signal_version(),
|
||||
reason: wzp_proto::HangupReason::ProtocolVersionMismatch {
|
||||
server_supported: vec![3],
|
||||
},
|
||||
call_id: None,
|
||||
};
|
||||
relay_clone.send_signal(&mismatch).await.unwrap();
|
||||
});
|
||||
|
||||
let result =
|
||||
wzp_client::handshake::perform_handshake(client_transport.as_ref(), &client_seed, None)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Err(wzp_client::handshake::HandshakeError::ProtocolVersionMismatch {
|
||||
server_supported,
|
||||
}) => {
|
||||
assert_eq!(server_supported, vec![3]);
|
||||
}
|
||||
Err(other) => panic!("expected ProtocolVersionMismatch, got: {other:?}"),
|
||||
Ok(_) => panic!("expected handshake to fail with ProtocolVersionMismatch"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -83,8 +83,12 @@ fn long_session_no_drift() {
|
||||
println!(
|
||||
"long_session_no_drift: decoded={frames_decoded}/{TOTAL_FRAMES}, \
|
||||
underruns={}, overruns={}, depth={}, max_depth={}, late={}, lost={}",
|
||||
stats.underruns, stats.overruns, stats.current_depth, stats.max_depth_seen,
|
||||
stats.packets_late, stats.packets_lost,
|
||||
stats.underruns,
|
||||
stats.overruns,
|
||||
stats.current_depth,
|
||||
stats.max_depth_seen,
|
||||
stats.packets_late,
|
||||
stats.packets_lost,
|
||||
);
|
||||
|
||||
// With 1 decode per tick over 3000 ticks, we expect ~3000 decoded frames
|
||||
@@ -123,7 +127,7 @@ fn long_session_with_simulated_loss() {
|
||||
|
||||
for (j, pkt) in batch.into_iter().enumerate() {
|
||||
// Drop every 20th *source* (non-repair) packet to simulate ~5% loss.
|
||||
if !pkt.header.is_repair && i % 20 == 0 && j == 0 {
|
||||
if !pkt.header.is_repair() && i % 20 == 0 && j == 0 {
|
||||
continue; // drop this packet
|
||||
}
|
||||
decoder.ingest(pkt);
|
||||
@@ -139,8 +143,12 @@ fn long_session_with_simulated_loss() {
|
||||
println!(
|
||||
"long_session_with_simulated_loss: decoded={frames_decoded}/{TOTAL_FRAMES}, \
|
||||
underruns={}, overruns={}, depth={}, max_depth={}, late={}, lost={}",
|
||||
stats.underruns, stats.overruns, stats.current_depth, stats.max_depth_seen,
|
||||
stats.packets_late, stats.packets_lost,
|
||||
stats.underruns,
|
||||
stats.overruns,
|
||||
stats.current_depth,
|
||||
stats.max_depth_seen,
|
||||
stats.packets_late,
|
||||
stats.packets_lost,
|
||||
);
|
||||
|
||||
// With 5% artificial loss + FEC recovery + PLC, we should still get >90% decoded.
|
||||
@@ -150,6 +158,65 @@ fn long_session_with_simulated_loss() {
|
||||
);
|
||||
}
|
||||
|
||||
/// Verify that `MediaHeader::timestamp` continues monotonically across
|
||||
/// rekey boundaries. Rekey is a crypto-layer operation (key material
|
||||
/// rotation) and must not reset or interfere with framing state.
|
||||
///
|
||||
/// We simulate a 3000-frame session with two conceptual rekeys at frames
|
||||
/// 1000 and 2000. The encoder's timestamp counter must advance
|
||||
/// monotonically throughout.
|
||||
#[test]
|
||||
fn rekey_timestamp_monotonic() {
|
||||
let config = test_config();
|
||||
let mut encoder = CallEncoder::new(&config);
|
||||
|
||||
let mut timestamps = Vec::new();
|
||||
|
||||
// Phase 1: before first rekey
|
||||
for i in 0..1000 {
|
||||
let pcm = sine_frame(i);
|
||||
let packets = encoder.encode_frame(&pcm).expect("encode");
|
||||
for pkt in packets {
|
||||
timestamps.push(pkt.header.timestamp);
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 2: between first and second rekey
|
||||
for i in 1000..2000 {
|
||||
let pcm = sine_frame(i);
|
||||
let packets = encoder.encode_frame(&pcm).expect("encode");
|
||||
for pkt in packets {
|
||||
timestamps.push(pkt.header.timestamp);
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 3: after second rekey
|
||||
for i in 2000..3000 {
|
||||
let pcm = sine_frame(i);
|
||||
let packets = encoder.encode_frame(&pcm).expect("encode");
|
||||
for pkt in packets {
|
||||
timestamps.push(pkt.header.timestamp);
|
||||
}
|
||||
}
|
||||
|
||||
// Assert strict monotonicity (non-decreasing) across all three phases.
|
||||
for window in timestamps.windows(2) {
|
||||
assert!(
|
||||
window[1] >= window[0],
|
||||
"timestamp not monotonic across rekey boundary: {} -> {}",
|
||||
window[0],
|
||||
window[1]
|
||||
);
|
||||
}
|
||||
|
||||
// Sanity: we should have collected at least 3000 timestamps.
|
||||
assert!(
|
||||
timestamps.len() >= 3000,
|
||||
"expected >= 3000 timestamps, got {}",
|
||||
timestamps.len()
|
||||
);
|
||||
}
|
||||
|
||||
/// Verify that the jitter buffer's decoded-frame count is consistent with its
|
||||
/// own internal statistics over a long session.
|
||||
#[test]
|
||||
|
||||
@@ -114,11 +114,7 @@ impl EchoCanceller {
|
||||
/// Number of delayed samples available to release.
|
||||
fn delay_available(&self) -> usize {
|
||||
let buffered = self.delay_write - self.delay_read;
|
||||
if buffered > self.delay_samples {
|
||||
buffered - self.delay_samples
|
||||
} else {
|
||||
0
|
||||
}
|
||||
buffered.saturating_sub(self.delay_samples)
|
||||
}
|
||||
|
||||
/// Process a near-end (microphone) frame, removing the estimated echo.
|
||||
@@ -161,8 +157,8 @@ impl EchoCanceller {
|
||||
let mut sum_near_sq: f64 = 0.0;
|
||||
let mut sum_err_sq: f64 = 0.0;
|
||||
|
||||
for i in 0..n {
|
||||
let near_f = nearend[i] as f32;
|
||||
for (i, sample) in nearend.iter_mut().enumerate() {
|
||||
let near_f = *sample as f32;
|
||||
|
||||
// Position of far-end "now" for this near-end sample.
|
||||
let base = (self.far_pos + fl * ((n / fl) + 2) + i - n) % fl;
|
||||
@@ -190,7 +186,7 @@ impl EchoCanceller {
|
||||
}
|
||||
|
||||
let out = error.clamp(-32768.0, 32767.0);
|
||||
nearend[i] = out as i16;
|
||||
*sample = out as i16;
|
||||
|
||||
sum_near_sq += (near_f as f64).powi(2);
|
||||
sum_err_sq += (out as f64).powi(2);
|
||||
@@ -325,7 +321,10 @@ mod tests {
|
||||
// Feed 960 samples (= delay amount). No samples released yet.
|
||||
aec.feed_farend(&vec![1i16; 960]);
|
||||
// far_buf should still be all zeros (nothing released).
|
||||
assert!(aec.far_buf.iter().all(|&s| s == 0.0), "nothing should be released yet");
|
||||
assert!(
|
||||
aec.far_buf.iter().all(|&s| s == 0.0),
|
||||
"nothing should be released yet"
|
||||
);
|
||||
|
||||
// Feed 480 more. 480 should be released to far_buf.
|
||||
aec.feed_farend(&vec![2i16; 480]);
|
||||
|
||||
@@ -24,12 +24,12 @@ impl AutoGainControl {
|
||||
/// Create a new AGC with sensible VoIP defaults.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
target_rms: 3000.0, // ~-20 dBFS for i16
|
||||
target_rms: 3000.0, // ~-20 dBFS for i16
|
||||
current_gain: 1.0,
|
||||
min_gain: 0.5,
|
||||
max_gain: 32.0,
|
||||
attack_alpha: 0.3, // fast attack
|
||||
release_alpha: 0.02, // slow release
|
||||
attack_alpha: 0.3, // fast attack
|
||||
release_alpha: 0.02, // slow release
|
||||
enabled: true,
|
||||
}
|
||||
}
|
||||
@@ -211,9 +211,6 @@ mod tests {
|
||||
fn agc_gain_db_at_unity() {
|
||||
let agc = AutoGainControl::new();
|
||||
let db = agc.current_gain_db();
|
||||
assert!(
|
||||
db.abs() < 0.01,
|
||||
"expected ~0 dB at unity gain, got {db}"
|
||||
);
|
||||
assert!(db.abs() < 0.01, "expected ~0 dB at unity gain, got {db}");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,7 +45,7 @@ impl Codec2Decoder {
|
||||
|
||||
/// Number of compressed bytes per frame.
|
||||
fn bytes_per_frame(&self) -> usize {
|
||||
(self.inner.bits_per_frame() + 7) / 8
|
||||
self.inner.bits_per_frame().div_ceil(8)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ impl Codec2Encoder {
|
||||
|
||||
/// Number of compressed bytes per frame.
|
||||
fn bytes_per_frame(&self) -> usize {
|
||||
(self.inner.bits_per_frame() + 7) / 8
|
||||
self.inner.bits_per_frame().div_ceil(8)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ impl NoiseSupressor {
|
||||
|
||||
// f32 → i16 with clamping
|
||||
for (i, &val) in output.iter().enumerate() {
|
||||
let clamped = val.max(-32768.0).min(32767.0);
|
||||
let clamped = val.clamp(-32768.0, 32767.0);
|
||||
pcm[offset + i] = clamped as i16;
|
||||
}
|
||||
}
|
||||
@@ -99,7 +99,11 @@ mod tests {
|
||||
}
|
||||
let original_len = pcm.len();
|
||||
ns.process(&mut pcm);
|
||||
assert_eq!(pcm.len(), original_len, "output length must match input length");
|
||||
assert_eq!(
|
||||
pcm.len(),
|
||||
original_len,
|
||||
"output length must match input length"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -71,9 +71,8 @@ impl DecoderHandle {
|
||||
"opus_decoder_create failed: err={error}"
|
||||
)));
|
||||
}
|
||||
let inner = NonNull::new(ptr).ok_or_else(|| {
|
||||
CodecError::DecodeFailed("opus_decoder_create returned null".into())
|
||||
})?;
|
||||
let inner = NonNull::new(ptr)
|
||||
.ok_or_else(|| CodecError::DecodeFailed("opus_decoder_create returned null".into()))?;
|
||||
Ok(Self { inner })
|
||||
}
|
||||
|
||||
@@ -257,11 +256,7 @@ impl DredDecoderHandle {
|
||||
/// The `dred_end` output is the silence gap at the tail of the DRED
|
||||
/// window; we subtract it from the total offset to give callers the
|
||||
/// truly usable sample count.
|
||||
pub fn parse_into(
|
||||
&mut self,
|
||||
state: &mut DredState,
|
||||
packet: &[u8],
|
||||
) -> Result<i32, CodecError> {
|
||||
pub fn parse_into(&mut self, state: &mut DredState, packet: &[u8]) -> Result<i32, CodecError> {
|
||||
if packet.is_empty() {
|
||||
state.samples_available = 0;
|
||||
return Ok(0);
|
||||
@@ -545,7 +540,10 @@ mod tests {
|
||||
// to our sine wave because we fed a cold decoder only one warmup
|
||||
// frame, but it should still produce non-silent speech-like output
|
||||
// since the DRED state was parsed from real speech content.
|
||||
let energy: u64 = recon_pcm.iter().map(|&s| (s as i32).unsigned_abs() as u64).sum();
|
||||
let energy: u64 = recon_pcm
|
||||
.iter()
|
||||
.map(|&s| (s as i32).unsigned_abs() as u64)
|
||||
.sum();
|
||||
assert!(
|
||||
energy > 0,
|
||||
"reconstructed audio has zero total energy — DRED reconstruction produced silence"
|
||||
|
||||
@@ -53,10 +53,7 @@ pub fn set_dred_verbose_logs(enabled: bool) {
|
||||
/// The returned encoder accepts 48 kHz mono PCM regardless of the active
|
||||
/// codec; resampling is handled internally when Codec2 is selected.
|
||||
pub fn create_encoder(profile: QualityProfile) -> Box<dyn AudioEncoder> {
|
||||
Box::new(
|
||||
AdaptiveEncoder::new(profile)
|
||||
.expect("failed to create adaptive encoder"),
|
||||
)
|
||||
Box::new(AdaptiveEncoder::new(profile).expect("failed to create adaptive encoder"))
|
||||
}
|
||||
|
||||
/// Create an adaptive decoder starting at the given quality profile.
|
||||
@@ -64,10 +61,7 @@ pub fn create_encoder(profile: QualityProfile) -> Box<dyn AudioEncoder> {
|
||||
/// The returned decoder always produces 48 kHz mono PCM; upsampling from
|
||||
/// Codec2's native 8 kHz is handled internally.
|
||||
pub fn create_decoder(profile: QualityProfile) -> Box<dyn AudioDecoder> {
|
||||
Box::new(
|
||||
AdaptiveDecoder::new(profile)
|
||||
.expect("failed to create adaptive decoder"),
|
||||
)
|
||||
Box::new(AdaptiveDecoder::new(profile).expect("failed to create adaptive decoder"))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -82,6 +76,10 @@ mod codec2_tests {
|
||||
fec_ratio: 0.5,
|
||||
frame_duration_ms: 20,
|
||||
frames_per_block: 5,
|
||||
priority_mode: wzp_proto::PriorityMode::AudioFirst,
|
||||
video_bitrate_kbps: None,
|
||||
video_resolution: None,
|
||||
video_fps: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -210,7 +208,10 @@ mod codec2_tests {
|
||||
|
||||
let mut pcm_out_c2 = vec![0i16; 1920];
|
||||
let samples_c2 = dec.decode(&encoded_c2[..n_c2], &mut pcm_out_c2).unwrap();
|
||||
assert_eq!(samples_c2, 1920, "should get 1920 samples at 48kHz after upsample");
|
||||
assert_eq!(
|
||||
samples_c2, 1920,
|
||||
"should get 1920 samples at 48kHz after upsample"
|
||||
);
|
||||
|
||||
// Step 3: Switch back to Opus.
|
||||
enc.set_profile(QualityProfile::GOOD).unwrap();
|
||||
|
||||
@@ -85,8 +85,13 @@ pub fn dred_duration_for(codec: CodecId) -> u8 {
|
||||
// offsets, so the extra window costs only ~1-2 kbps additional overhead
|
||||
// while buying substantially better burst resilience (up from 500 ms).
|
||||
CodecId::Opus6k => 104,
|
||||
// Non-Opus (Codec2 / CN): DRED is N/A.
|
||||
CodecId::Codec2_1200 | CodecId::Codec2_3200 | CodecId::ComfortNoise => 0,
|
||||
// Non-Opus (Codec2 / CN / video): DRED is N/A.
|
||||
CodecId::Codec2_1200
|
||||
| CodecId::Codec2_3200
|
||||
| CodecId::ComfortNoise
|
||||
| CodecId::H264Baseline
|
||||
| CodecId::H265Main
|
||||
| CodecId::Av1Main => 0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -96,7 +101,7 @@ pub fn dred_duration_for(codec: CodecId) -> u8 {
|
||||
/// mode; unset or empty leaves DRED enabled.
|
||||
fn read_legacy_fec_env() -> bool {
|
||||
match std::env::var(LEGACY_FEC_ENV) {
|
||||
Ok(v) => !v.is_empty() && v != "0" && v.to_ascii_lowercase() != "false",
|
||||
Ok(v) => !v.is_empty() && v != "0" && !v.eq_ignore_ascii_case("false"),
|
||||
Err(_) => false,
|
||||
}
|
||||
}
|
||||
@@ -247,7 +252,7 @@ impl OpusEncoder {
|
||||
let clamped = if self.legacy_fec_mode {
|
||||
loss_pct.min(100)
|
||||
} else {
|
||||
loss_pct.max(DRED_LOSS_FLOOR_PCT).min(100)
|
||||
loss_pct.clamp(DRED_LOSS_FLOOR_PCT, 100)
|
||||
};
|
||||
let _ = self.inner.set_packet_loss(clamped);
|
||||
}
|
||||
@@ -332,7 +337,11 @@ impl AudioEncoder for OpusEncoder {
|
||||
);
|
||||
return;
|
||||
}
|
||||
let mode = if enabled { InbandFec::Mode1 } else { InbandFec::Off };
|
||||
let mode = if enabled {
|
||||
InbandFec::Mode1
|
||||
} else {
|
||||
InbandFec::Off
|
||||
};
|
||||
let _ = self.inner.set_inband_fec(mode);
|
||||
}
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ fn build_fir_kernel() -> [f64; FIR_TAPS] {
|
||||
let fc = CUTOFF_HZ / SAMPLE_RATE; // normalised cutoff (0..0.5)
|
||||
let beta_denom = bessel_i0(KAISER_BETA);
|
||||
|
||||
for i in 0..FIR_TAPS {
|
||||
for (i, slot) in kernel.iter_mut().enumerate() {
|
||||
// Sinc
|
||||
let n = i as f64 - m / 2.0;
|
||||
let sinc = if n.abs() < 1e-12 {
|
||||
@@ -61,7 +61,7 @@ fn build_fir_kernel() -> [f64; FIR_TAPS] {
|
||||
let t = 2.0 * i as f64 / m - 1.0; // range [-1, 1]
|
||||
let kaiser = bessel_i0(KAISER_BETA * (1.0 - t * t).max(0.0).sqrt()) / beta_denom;
|
||||
|
||||
kernel[i] = sinc * kaiser;
|
||||
*slot = sinc * kaiser;
|
||||
}
|
||||
|
||||
// Normalise to unity DC gain.
|
||||
@@ -129,8 +129,7 @@ impl Downsampler48to8 {
|
||||
|
||||
// Update history: keep the last (FIR_TAPS - 1) samples from work.
|
||||
if work.len() >= hist_len {
|
||||
self.history
|
||||
.copy_from_slice(&work[work.len() - hist_len..]);
|
||||
self.history.copy_from_slice(&work[work.len() - hist_len..]);
|
||||
} else {
|
||||
// Input was shorter than history — shift.
|
||||
let shift = hist_len - work.len();
|
||||
@@ -181,9 +180,7 @@ impl Upsampler8to48 {
|
||||
work.extend_from_slice(&self.history);
|
||||
for &s in input {
|
||||
work.push(s as f64);
|
||||
for _ in 1..RATIO {
|
||||
work.push(0.0);
|
||||
}
|
||||
work.resize(work.len() + (RATIO - 1), 0.0f64);
|
||||
}
|
||||
|
||||
let out_len = stuffed_len;
|
||||
@@ -209,8 +206,7 @@ impl Upsampler8to48 {
|
||||
|
||||
// Update history.
|
||||
if work.len() >= hist_len {
|
||||
self.history
|
||||
.copy_from_slice(&work[work.len() - hist_len..]);
|
||||
self.history.copy_from_slice(&work[work.len() - hist_len..]);
|
||||
} else {
|
||||
let shift = hist_len - work.len();
|
||||
self.history.copy_within(shift.., 0);
|
||||
|
||||
@@ -151,7 +151,10 @@ mod tests {
|
||||
for _ in 0..4 {
|
||||
det.is_silent(&silence);
|
||||
}
|
||||
assert!(det.is_silent(&silence), "should be suppressing after hangover");
|
||||
assert!(
|
||||
det.is_silent(&silence),
|
||||
"should be suppressing after hangover"
|
||||
);
|
||||
|
||||
// Speech arrives — should immediately stop suppressing.
|
||||
assert!(!det.is_silent(&speech));
|
||||
@@ -165,10 +168,16 @@ mod tests {
|
||||
cn.generate(&mut pcm);
|
||||
|
||||
// At least some samples should be non-zero.
|
||||
assert!(pcm.iter().any(|&s| s != 0), "CN output should not be all zeros");
|
||||
assert!(
|
||||
pcm.iter().any(|&s| s != 0),
|
||||
"CN output should not be all zeros"
|
||||
);
|
||||
|
||||
// All samples should be within [-50, 50].
|
||||
assert!(pcm.iter().all(|&s| s.abs() <= 50), "CN samples out of range");
|
||||
assert!(
|
||||
pcm.iter().all(|&s| s.abs() <= 50),
|
||||
"CN samples out of range"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -179,11 +188,17 @@ mod tests {
|
||||
// Constant value: RMS of [v, v, v, ...] = |v|.
|
||||
let pcm = vec![100i16; 100];
|
||||
let rms = SilenceDetector::rms(&pcm);
|
||||
assert!((rms - 100.0).abs() < 0.01, "RMS of constant 100 should be 100, got {rms}");
|
||||
assert!(
|
||||
(rms - 100.0).abs() < 0.01,
|
||||
"RMS of constant 100 should be 100, got {rms}"
|
||||
);
|
||||
|
||||
// Known pattern: [3, 4] → sqrt((9+16)/2) = sqrt(12.5) ≈ 3.5355
|
||||
let rms2 = SilenceDetector::rms(&[3, 4]);
|
||||
assert!((rms2 - 3.5355).abs() < 0.01, "RMS of [3,4] should be ~3.5355, got {rms2}");
|
||||
assert!(
|
||||
(rms2 - 3.5355).abs() < 0.01,
|
||||
"RMS of [3,4] should be ~3.5355, got {rms2}"
|
||||
);
|
||||
|
||||
// Empty buffer → 0.
|
||||
assert_eq!(SilenceDetector::rms(&[]), 0.0);
|
||||
|
||||
@@ -1,21 +1,20 @@
|
||||
//! Sliding window replay protection.
|
||||
//!
|
||||
//! Tracks seen sequence numbers using a bitmap. Window size is 1024 packets.
|
||||
//! Sequence numbers that are too old (more than WINDOW_SIZE behind the highest
|
||||
//! seen) are rejected.
|
||||
//! Tracks seen sequence numbers using a bitmap. Window size is configurable
|
||||
//! at construction time. Sequence numbers that are too old (more than
|
||||
//! `window_size` behind the highest seen) are rejected.
|
||||
|
||||
use wzp_proto::CryptoError;
|
||||
|
||||
/// Window size in packets.
|
||||
const WINDOW_SIZE: u16 = 1024;
|
||||
|
||||
/// Sliding window anti-replay detector.
|
||||
///
|
||||
/// Uses a bitmap to track which sequence numbers have been seen within
|
||||
/// the current window. Handles u16 wrapping correctly.
|
||||
/// the current window. Handles `u32` wrapping correctly.
|
||||
pub struct AntiReplayWindow {
|
||||
/// Window size in packets.
|
||||
window_size: u32,
|
||||
/// Highest sequence number seen so far.
|
||||
highest: u16,
|
||||
highest: u32,
|
||||
/// Bitmap of seen packets. Bit i corresponds to (highest - i).
|
||||
bitmap: Vec<u64>,
|
||||
/// Whether any packet has been received yet.
|
||||
@@ -23,21 +22,26 @@ pub struct AntiReplayWindow {
|
||||
}
|
||||
|
||||
impl AntiReplayWindow {
|
||||
/// Number of u64 words needed for the bitmap.
|
||||
const BITMAP_WORDS: usize = (WINDOW_SIZE as usize + 63) / 64;
|
||||
|
||||
/// Create a new anti-replay window.
|
||||
/// Create a new anti-replay window with the default size of 1024 packets.
|
||||
pub fn new() -> Self {
|
||||
Self::with_window(1024)
|
||||
}
|
||||
|
||||
/// Create a new anti-replay window with a custom size.
|
||||
pub fn with_window(size: usize) -> Self {
|
||||
let window_size = size as u32;
|
||||
let bitmap_words = (size + 63) / 64;
|
||||
Self {
|
||||
window_size,
|
||||
highest: 0,
|
||||
bitmap: vec![0u64; Self::BITMAP_WORDS],
|
||||
bitmap: vec![0u64; bitmap_words],
|
||||
initialized: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a sequence number is valid (not a replay, not too old).
|
||||
/// If valid, marks it as seen.
|
||||
pub fn check_and_update(&mut self, seq: u16) -> Result<(), CryptoError> {
|
||||
pub fn check_and_update(&mut self, seq: u32) -> Result<(), CryptoError> {
|
||||
if !self.initialized {
|
||||
self.initialized = true;
|
||||
self.highest = seq;
|
||||
@@ -52,17 +56,17 @@ impl AntiReplayWindow {
|
||||
return Err(CryptoError::ReplayDetected { seq });
|
||||
}
|
||||
|
||||
if diff < 0x8000 {
|
||||
// seq is ahead of highest (wrapping-aware: diff in [1, 0x7FFF])
|
||||
if diff < 0x8000_0000 {
|
||||
// seq is ahead of highest (wrapping-aware: diff in [1, 0x7FFF_FFFF])
|
||||
let shift = diff as usize;
|
||||
self.advance_window(shift);
|
||||
self.highest = seq;
|
||||
self.set_bit(0);
|
||||
Ok(())
|
||||
} else {
|
||||
// seq is behind highest (wrapping-aware: diff in [0x8000, 0xFFFF])
|
||||
// seq is behind highest (wrapping-aware: diff in [0x8000_0000, 0xFFFF_FFFF])
|
||||
let behind = self.highest.wrapping_sub(seq) as usize;
|
||||
if behind >= WINDOW_SIZE as usize {
|
||||
if behind >= self.window_size as usize {
|
||||
return Err(CryptoError::ReplayDetected { seq });
|
||||
}
|
||||
if self.get_bit(behind) {
|
||||
@@ -75,7 +79,8 @@ impl AntiReplayWindow {
|
||||
|
||||
/// Advance the window by `shift` positions (shift left = new bits at position 0).
|
||||
fn advance_window(&mut self, shift: usize) {
|
||||
if shift >= WINDOW_SIZE as usize {
|
||||
let window_size = self.window_size as usize;
|
||||
if shift >= window_size {
|
||||
for word in &mut self.bitmap {
|
||||
*word = 0;
|
||||
}
|
||||
@@ -156,7 +161,11 @@ mod tests {
|
||||
fn sequential_accepted() {
|
||||
let mut w = AntiReplayWindow::new();
|
||||
for i in 0..200 {
|
||||
assert!(w.check_and_update(i).is_ok(), "seq {} should be accepted", i);
|
||||
assert!(
|
||||
w.check_and_update(i).is_ok(),
|
||||
"seq {} should be accepted",
|
||||
i
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -183,11 +192,11 @@ mod tests {
|
||||
#[test]
|
||||
fn wrapping_works() {
|
||||
let mut w = AntiReplayWindow::new();
|
||||
assert!(w.check_and_update(65530).is_ok());
|
||||
assert!(w.check_and_update(65535).is_ok());
|
||||
assert!(w.check_and_update(0xFFFF_FFF0).is_ok());
|
||||
assert!(w.check_and_update(0xFFFF_FFFF).is_ok());
|
||||
assert!(w.check_and_update(0).is_ok()); // wrapped
|
||||
assert!(w.check_and_update(1).is_ok());
|
||||
assert!(w.check_and_update(65535).is_err()); // duplicate
|
||||
assert!(w.check_and_update(0xFFFF_FFFF).is_err()); // duplicate
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -201,4 +210,53 @@ mod tests {
|
||||
// Now 0 is 1024 behind 1024, which is at the boundary limit
|
||||
assert!(w.check_and_update(0).is_err()); // already seen or too old
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn custom_window_size() {
|
||||
let mut w = AntiReplayWindow::with_window(64);
|
||||
for i in 0..64 {
|
||||
assert!(w.check_and_update(i).is_ok());
|
||||
}
|
||||
// seq 0 is now exactly at the boundary (64 behind 64)
|
||||
assert!(w.check_and_update(0).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn video_burst_200_with_one_reorder() {
|
||||
let mut w = AntiReplayWindow::with_window(1024);
|
||||
// Simulate a 200-packet burst
|
||||
for i in 0..200 {
|
||||
assert!(
|
||||
w.check_and_update(i).is_ok(),
|
||||
"seq {} should be accepted",
|
||||
i
|
||||
);
|
||||
}
|
||||
// One packet reordered (arrives late)
|
||||
assert!(w.check_and_update(50).is_err(), "seq 50 is a duplicate");
|
||||
// But a packet just behind the window should still be ok
|
||||
assert!(w.check_and_update(199).is_err(), "seq 199 is a duplicate");
|
||||
// Continue the burst
|
||||
for i in 200..400 {
|
||||
assert!(
|
||||
w.check_and_update(i).is_ok(),
|
||||
"seq {} should be accepted",
|
||||
i
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn u32_high_range_works() {
|
||||
let mut w = AntiReplayWindow::with_window(64);
|
||||
let base = 1000u32;
|
||||
assert!(w.check_and_update(base).is_ok());
|
||||
assert!(w.check_and_update(base + 1).is_ok());
|
||||
// 65 behind highest (base+1) is outside the 64-packet window
|
||||
assert!(w.check_and_update(base.wrapping_sub(64)).is_err());
|
||||
// 63 behind is inside
|
||||
assert!(w.check_and_update(base.wrapping_sub(62)).is_ok());
|
||||
// base itself is now a duplicate
|
||||
assert!(w.check_and_update(base).is_err());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,8 +9,8 @@ use ed25519_dalek::{Signer, SigningKey, Verifier, VerifyingKey};
|
||||
use hkdf::Hkdf;
|
||||
use rand::rngs::OsRng;
|
||||
use sha2::{Digest, Sha256};
|
||||
use x25519_dalek::{PublicKey as X25519PublicKey, StaticSecret};
|
||||
use wzp_proto::{CryptoError, CryptoSession, KeyExchange};
|
||||
use x25519_dalek::{PublicKey as X25519PublicKey, StaticSecret};
|
||||
|
||||
use crate::session::ChaChaSession;
|
||||
|
||||
@@ -18,10 +18,14 @@ use crate::session::ChaChaSession;
|
||||
pub struct WarzoneKeyExchange {
|
||||
/// Ed25519 signing key (identity).
|
||||
signing_key: SigningKey,
|
||||
/// X25519 static secret (derived from seed, used for identity encryption).
|
||||
/// X25519 static secret derived from identity seed. Reserved for future
|
||||
/// use in static-key federation authentication (not used in current
|
||||
/// ephemeral-only handshake protocol).
|
||||
#[allow(dead_code)]
|
||||
x25519_static_secret: StaticSecret,
|
||||
/// X25519 static public key.
|
||||
/// X25519 static public key derived from identity seed. Reserved for
|
||||
/// future use in static-key federation authentication (not used in
|
||||
/// current ephemeral-only handshake protocol).
|
||||
#[allow(dead_code)]
|
||||
x25519_static_public: X25519PublicKey,
|
||||
/// Ephemeral X25519 secret for the current call (set by generate_ephemeral).
|
||||
@@ -91,12 +95,11 @@ impl KeyExchange for WarzoneKeyExchange {
|
||||
&self,
|
||||
peer_ephemeral_pub: &[u8; 32],
|
||||
) -> Result<Box<dyn CryptoSession>, CryptoError> {
|
||||
let secret = self
|
||||
.ephemeral_secret
|
||||
.as_ref()
|
||||
.ok_or_else(|| {
|
||||
CryptoError::Internal("no ephemeral key generated; call generate_ephemeral first".into())
|
||||
})?;
|
||||
let secret = self.ephemeral_secret.as_ref().ok_or_else(|| {
|
||||
CryptoError::Internal(
|
||||
"no ephemeral key generated; call generate_ephemeral first".into(),
|
||||
)
|
||||
})?;
|
||||
|
||||
let peer_public = X25519PublicKey::from(*peer_ephemeral_pub);
|
||||
// Use diffie_hellman with a clone of the StaticSecret
|
||||
@@ -206,18 +209,34 @@ mod tests {
|
||||
let mut alice_session = alice.derive_session(&bob_eph_pub).unwrap();
|
||||
let mut bob_session = bob.derive_session(&alice_eph_pub).unwrap();
|
||||
|
||||
// Verify they can communicate: Alice encrypts, Bob decrypts
|
||||
let header = b"call-header";
|
||||
// Verify they can communicate: Alice encrypts, Bob decrypts.
|
||||
// Use a valid v2 MediaHeader — encrypt/decrypt now derive the nonce from
|
||||
// header.seq and will reject raw byte slices shorter than WIRE_SIZE.
|
||||
use wzp_proto::{CodecId, MediaHeader, MediaType};
|
||||
let header = MediaHeader {
|
||||
version: 2,
|
||||
flags: 0,
|
||||
media_type: MediaType::Audio,
|
||||
codec_id: CodecId::Opus24k,
|
||||
stream_id: 0,
|
||||
fec_ratio: 0,
|
||||
seq: 0,
|
||||
timestamp: 0,
|
||||
fec_block: 0,
|
||||
};
|
||||
let mut header_bytes = Vec::new();
|
||||
header.write_to(&mut header_bytes);
|
||||
|
||||
let plaintext = b"hello from alice";
|
||||
|
||||
let mut ciphertext = Vec::new();
|
||||
alice_session
|
||||
.encrypt(header, plaintext, &mut ciphertext)
|
||||
.encrypt(&header_bytes, plaintext, &mut ciphertext)
|
||||
.unwrap();
|
||||
|
||||
let mut decrypted = Vec::new();
|
||||
bob_session
|
||||
.decrypt(header, &ciphertext, &mut decrypted)
|
||||
.decrypt(&header_bytes, &ciphertext, &mut decrypted)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(&decrypted, plaintext);
|
||||
|
||||
@@ -79,7 +79,9 @@ impl Seed {
|
||||
///
|
||||
/// Mirrors: `warzone-protocol::mnemonic::mnemonic_to_seed`
|
||||
pub fn from_mnemonic(words: &str) -> Result<Self, String> {
|
||||
let mnemonic: bip39::Mnemonic = words.parse().map_err(|e| format!("invalid mnemonic: {e}"))?;
|
||||
let mnemonic: bip39::Mnemonic = words
|
||||
.parse()
|
||||
.map_err(|e| format!("invalid mnemonic: {e}"))?;
|
||||
let entropy = mnemonic.to_entropy();
|
||||
if entropy.len() != 32 {
|
||||
return Err(format!("expected 32 bytes entropy, got {}", entropy.len()));
|
||||
|
||||
@@ -16,8 +16,8 @@ pub mod session;
|
||||
|
||||
pub use anti_replay::AntiReplayWindow;
|
||||
pub use handshake::WarzoneKeyExchange;
|
||||
pub use identity::{hash_room_name, Fingerprint, IdentityKeyPair, PublicIdentity, Seed};
|
||||
pub use nonce::{build_nonce, Direction};
|
||||
pub use identity::{Fingerprint, IdentityKeyPair, PublicIdentity, Seed, hash_room_name};
|
||||
pub use nonce::{Direction, build_nonce};
|
||||
pub use rekey::RekeyManager;
|
||||
pub use session::ChaChaSession;
|
||||
|
||||
|
||||
@@ -36,6 +36,10 @@ impl RekeyManager {
|
||||
///
|
||||
/// The old key is zeroized after the new key is derived.
|
||||
/// Returns the new 32-byte symmetric key.
|
||||
///
|
||||
/// NOTE: Rekeying changes **only** the symmetric key material. Sequence
|
||||
/// numbers and timestamps in the media framing layer (e.g. `MediaHeader`)
|
||||
/// are untouched — they continue monotonically across the rekey boundary.
|
||||
pub fn perform_rekey(
|
||||
&mut self,
|
||||
new_peer_pub: &[u8; 32],
|
||||
|
||||
@@ -3,12 +3,15 @@
|
||||
//! Implements the `CryptoSession` trait for per-call media encryption.
|
||||
//! Nonces are derived deterministically from session_id + sequence counter + direction.
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use chacha20poly1305::aead::Aead;
|
||||
use chacha20poly1305::{ChaCha20Poly1305, KeyInit, Nonce};
|
||||
use x25519_dalek::{PublicKey, StaticSecret};
|
||||
use rand::rngs::OsRng;
|
||||
use wzp_proto::{CryptoError, CryptoSession};
|
||||
use wzp_proto::{CryptoError, CryptoSession, MediaHeader, MediaType};
|
||||
use x25519_dalek::{PublicKey, StaticSecret};
|
||||
|
||||
use crate::anti_replay::AntiReplayWindow;
|
||||
use crate::nonce::{self, Direction};
|
||||
use crate::rekey::RekeyManager;
|
||||
|
||||
@@ -28,6 +31,10 @@ pub struct ChaChaSession {
|
||||
pending_rekey_secret: Option<StaticSecret>,
|
||||
/// Short Authentication String (4-digit code for verbal verification).
|
||||
sas_code: Option<u32>,
|
||||
/// Per-stream anti-replay windows, keyed by (stream_id, media_type).
|
||||
anti_replay: HashMap<(u8, MediaType), AntiReplayWindow>,
|
||||
/// Last timestamp seen in encrypt() — used to assert monotonicity across rekeys.
|
||||
last_encrypt_timestamp: Option<u32>,
|
||||
}
|
||||
|
||||
impl ChaChaSession {
|
||||
@@ -49,6 +56,8 @@ impl ChaChaSession {
|
||||
rekey_mgr: RekeyManager::new(shared_secret),
|
||||
pending_rekey_secret: None,
|
||||
sas_code: None,
|
||||
anti_replay: HashMap::new(),
|
||||
last_encrypt_timestamp: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,6 +76,27 @@ impl ChaChaSession {
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a v2 `MediaHeader` from raw bytes.
|
||||
/// Returns `None` if the buffer is too short or not a valid v2 header.
|
||||
fn parse_header(header_bytes: &[u8]) -> Option<MediaHeader> {
|
||||
if header_bytes.len() < MediaHeader::WIRE_SIZE {
|
||||
return None;
|
||||
}
|
||||
let mut cursor = std::io::Cursor::new(header_bytes);
|
||||
MediaHeader::read_from(&mut cursor)
|
||||
}
|
||||
|
||||
/// Return the default anti-replay window size for a given media type.
|
||||
fn default_window_for_media_type(media_type: MediaType) -> AntiReplayWindow {
|
||||
let size = match media_type {
|
||||
MediaType::Audio => 64,
|
||||
MediaType::Video => 1024,
|
||||
MediaType::Data => 256,
|
||||
MediaType::Control => 32,
|
||||
};
|
||||
AntiReplayWindow::with_window(size)
|
||||
}
|
||||
|
||||
impl CryptoSession for ChaChaSession {
|
||||
fn encrypt(
|
||||
&mut self,
|
||||
@@ -74,10 +104,14 @@ impl CryptoSession for ChaChaSession {
|
||||
plaintext: &[u8],
|
||||
out: &mut Vec<u8>,
|
||||
) -> Result<(), CryptoError> {
|
||||
let nonce_bytes = nonce::build_nonce(&self.session_id, self.send_seq, Direction::Send);
|
||||
// Derive nonce from the wire-level seq in the header, not from an
|
||||
// internal counter. This ensures the receiver can reconstruct the
|
||||
// same nonce using the header it receives, regardless of delivery order.
|
||||
let header = parse_header(header_bytes)
|
||||
.ok_or_else(|| CryptoError::Internal("header too short to derive nonce".into()))?;
|
||||
let nonce_bytes = nonce::build_nonce(&self.session_id, header.seq, Direction::Send);
|
||||
let nonce = Nonce::from_slice(&nonce_bytes);
|
||||
|
||||
// Encrypt with AAD
|
||||
use chacha20poly1305::aead::Payload;
|
||||
let payload = Payload {
|
||||
msg: plaintext,
|
||||
@@ -90,7 +124,19 @@ impl CryptoSession for ChaChaSession {
|
||||
.map_err(|_| CryptoError::Internal("encryption failed".into()))?;
|
||||
|
||||
out.extend_from_slice(&ciphertext);
|
||||
self.send_seq = self.send_seq.wrapping_add(1);
|
||||
self.send_seq = self.send_seq.wrapping_add(1); // packet counter for rekey trigger only
|
||||
|
||||
// M5: assert timestamp_ms is non-decreasing across calls (including post-rekey).
|
||||
// Timestamps are u32 and wrap at 2^32 ms (~49 days); allow wrapping.
|
||||
debug_assert!(
|
||||
self.last_encrypt_timestamp
|
||||
.map_or(true, |last| header.timestamp.wrapping_sub(last) < u32::MAX / 2),
|
||||
"encrypt: timestamp must not decrease (last={:?}, now={})",
|
||||
self.last_encrypt_timestamp,
|
||||
header.timestamp,
|
||||
);
|
||||
self.last_encrypt_timestamp = Some(header.timestamp);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -100,9 +146,14 @@ impl CryptoSession for ChaChaSession {
|
||||
ciphertext: &[u8],
|
||||
out: &mut Vec<u8>,
|
||||
) -> Result<(), CryptoError> {
|
||||
// Use Direction::Send to match the sender's nonce construction.
|
||||
// The recv_seq counter tracks which packet from the peer we're decrypting.
|
||||
let nonce_bytes = nonce::build_nonce(&self.session_id, self.recv_seq, Direction::Send);
|
||||
// Parse header before decryption — needed for nonce derivation.
|
||||
// Using header.seq (not recv_seq) means the nonce is always derived
|
||||
// from the same wire field as the sender, surviving out-of-order delivery.
|
||||
// A recv_seq counter diverges from the sender's send_seq on any reorder,
|
||||
// causing every subsequent decryption to fail for the rest of the session.
|
||||
let header = parse_header(header_bytes)
|
||||
.ok_or_else(|| CryptoError::Internal("header too short to derive nonce".into()))?;
|
||||
let nonce_bytes = nonce::build_nonce(&self.session_id, header.seq, Direction::Send);
|
||||
let nonce = Nonce::from_slice(&nonce_bytes);
|
||||
|
||||
use chacha20poly1305::aead::Payload;
|
||||
@@ -116,8 +167,21 @@ impl CryptoSession for ChaChaSession {
|
||||
.decrypt(nonce, payload)
|
||||
.map_err(|_| CryptoError::DecryptionFailed)?;
|
||||
|
||||
let plaintext_len = plaintext.len();
|
||||
out.extend_from_slice(&plaintext);
|
||||
self.recv_seq = self.recv_seq.wrapping_add(1);
|
||||
self.recv_seq = self.recv_seq.wrapping_add(1); // packet counter for rekey trigger only
|
||||
|
||||
// Anti-replay check: header already parsed above.
|
||||
let window = self
|
||||
.anti_replay
|
||||
.entry((header.stream_id, header.media_type))
|
||||
.or_insert_with(|| default_window_for_media_type(header.media_type));
|
||||
if let Err(e) = window.check_and_update(header.seq) {
|
||||
// Roll back the plaintext we just appended.
|
||||
out.truncate(out.len() - plaintext_len);
|
||||
return Err(e);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -135,10 +199,14 @@ impl CryptoSession for ChaChaSession {
|
||||
.ok_or_else(|| CryptoError::RekeyFailed("no pending rekey".into()))?;
|
||||
|
||||
let total_packets = self.send_seq as u64 + self.recv_seq as u64;
|
||||
let new_key = self.rekey_mgr.perform_rekey(peer_ephemeral_pub, secret, total_packets);
|
||||
let new_key = self
|
||||
.rekey_mgr
|
||||
.perform_rekey(peer_ephemeral_pub, secret, total_packets);
|
||||
self.install_key(new_key);
|
||||
|
||||
// Reset sequence counters after rekey for nonce uniqueness
|
||||
// Reset sequence counters after rekey for nonce uniqueness.
|
||||
// last_encrypt_timestamp is intentionally NOT reset — spec requires
|
||||
// timestamp_ms to be monotonic across rekeys.
|
||||
self.send_seq = 0;
|
||||
self.recv_seq = 0;
|
||||
|
||||
@@ -153,24 +221,42 @@ impl CryptoSession for ChaChaSession {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use wzp_proto::{CodecId, MediaType};
|
||||
|
||||
fn make_session_pair() -> (ChaChaSession, ChaChaSession) {
|
||||
let key = [0x42u8; 32];
|
||||
(ChaChaSession::new(key), ChaChaSession::new(key))
|
||||
}
|
||||
|
||||
/// Build a minimal valid v2 MediaHeader serialised to bytes.
|
||||
fn make_header_bytes(seq: u32) -> Vec<u8> {
|
||||
let header = MediaHeader {
|
||||
version: 2,
|
||||
flags: 0,
|
||||
media_type: MediaType::Audio,
|
||||
codec_id: CodecId::Opus24k,
|
||||
stream_id: 0,
|
||||
fec_ratio: 0,
|
||||
seq,
|
||||
timestamp: seq.wrapping_mul(20),
|
||||
fec_block: 0,
|
||||
};
|
||||
let mut bytes = Vec::new();
|
||||
header.write_to(&mut bytes);
|
||||
bytes
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encrypt_decrypt_roundtrip() {
|
||||
let (mut alice, mut bob) = make_session_pair();
|
||||
let header = b"test-header";
|
||||
let header = make_header_bytes(0);
|
||||
let plaintext = b"hello warzone";
|
||||
|
||||
let mut ciphertext = Vec::new();
|
||||
alice.encrypt(header, plaintext, &mut ciphertext).unwrap();
|
||||
alice.encrypt(&header, plaintext, &mut ciphertext).unwrap();
|
||||
|
||||
// Bob decrypts (his recv matches Alice's send)
|
||||
let mut decrypted = Vec::new();
|
||||
bob.decrypt(header, &ciphertext, &mut decrypted).unwrap();
|
||||
bob.decrypt(&header, &ciphertext, &mut decrypted).unwrap();
|
||||
|
||||
assert_eq!(&decrypted, plaintext);
|
||||
}
|
||||
@@ -178,14 +264,18 @@ mod tests {
|
||||
#[test]
|
||||
fn decrypt_wrong_aad_fails() {
|
||||
let (mut alice, mut bob) = make_session_pair();
|
||||
let header = b"correct-header";
|
||||
let correct_header = make_header_bytes(0);
|
||||
// Different seq → different nonce AND different AAD bytes: decryption must fail.
|
||||
let wrong_header = make_header_bytes(1);
|
||||
let plaintext = b"secret data";
|
||||
|
||||
let mut ciphertext = Vec::new();
|
||||
alice.encrypt(header, plaintext, &mut ciphertext).unwrap();
|
||||
alice
|
||||
.encrypt(&correct_header, plaintext, &mut ciphertext)
|
||||
.unwrap();
|
||||
|
||||
let mut decrypted = Vec::new();
|
||||
let result = bob.decrypt(b"wrong-header", &ciphertext, &mut decrypted);
|
||||
let result = bob.decrypt(&wrong_header, &ciphertext, &mut decrypted);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
@@ -194,29 +284,29 @@ mod tests {
|
||||
let mut alice = ChaChaSession::new([0xAA; 32]);
|
||||
let mut eve = ChaChaSession::new([0xBB; 32]);
|
||||
|
||||
let header = b"hdr";
|
||||
let header = make_header_bytes(0);
|
||||
let plaintext = b"secret";
|
||||
|
||||
let mut ciphertext = Vec::new();
|
||||
alice.encrypt(header, plaintext, &mut ciphertext).unwrap();
|
||||
alice.encrypt(&header, plaintext, &mut ciphertext).unwrap();
|
||||
|
||||
let mut decrypted = Vec::new();
|
||||
let result = eve.decrypt(header, &ciphertext, &mut decrypted);
|
||||
let result = eve.decrypt(&header, &ciphertext, &mut decrypted);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn multiple_packets_roundtrip() {
|
||||
let (mut alice, mut bob) = make_session_pair();
|
||||
let header = b"hdr";
|
||||
|
||||
for i in 0..100 {
|
||||
for i in 0..100u32 {
|
||||
let header = make_header_bytes(i);
|
||||
let msg = format!("message {}", i);
|
||||
let mut ct = Vec::new();
|
||||
alice.encrypt(header, msg.as_bytes(), &mut ct).unwrap();
|
||||
alice.encrypt(&header, msg.as_bytes(), &mut ct).unwrap();
|
||||
|
||||
let mut pt = Vec::new();
|
||||
bob.decrypt(header, &ct, &mut pt).unwrap();
|
||||
bob.decrypt(&header, &ct, &mut pt).unwrap();
|
||||
assert_eq!(pt, msg.as_bytes());
|
||||
}
|
||||
}
|
||||
@@ -235,4 +325,140 @@ mod tests {
|
||||
// Session is now rekeyed - counters reset
|
||||
assert_eq!(alice.send_seq, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decrypt_survives_out_of_order_delivery() {
|
||||
// Regression test for nonce derivation using recv_seq instead of
|
||||
// MediaHeader.seq. If nonces are tied to a local counter, any reorder
|
||||
// causes the counter to diverge from the sender's seq and every
|
||||
// subsequent packet fails decryption permanently.
|
||||
use wzp_proto::{CodecId, MediaType};
|
||||
|
||||
let key = [0x55u8; 32];
|
||||
let mut alice = ChaChaSession::new(key);
|
||||
let mut bob = ChaChaSession::new(key);
|
||||
|
||||
let plaintext = b"audio payload";
|
||||
|
||||
// Encrypt 5 packets in order (seqs 10, 11, 12, 13, 14).
|
||||
let seqs = [10u32, 11, 12, 13, 14];
|
||||
let mut ciphertexts: Vec<(Vec<u8>, Vec<u8>)> = Vec::new();
|
||||
for &seq in &seqs {
|
||||
let header = MediaHeader {
|
||||
version: 2,
|
||||
flags: 0,
|
||||
media_type: MediaType::Audio,
|
||||
codec_id: CodecId::Opus24k,
|
||||
stream_id: 0,
|
||||
fec_ratio: 0,
|
||||
seq,
|
||||
timestamp: seq * 20,
|
||||
fec_block: 0,
|
||||
};
|
||||
let mut header_bytes = Vec::new();
|
||||
header.write_to(&mut header_bytes);
|
||||
let mut ct = Vec::new();
|
||||
alice.encrypt(&header_bytes, plaintext, &mut ct).unwrap();
|
||||
ciphertexts.push((header_bytes, ct));
|
||||
}
|
||||
|
||||
// Bob receives them out of order: 0, 2, 1, 4, 3
|
||||
let delivery_order = [0usize, 2, 1, 4, 3];
|
||||
for &idx in &delivery_order {
|
||||
let (ref hdr, ref ct) = ciphertexts[idx];
|
||||
let mut pt = Vec::new();
|
||||
let result = bob.decrypt(hdr, ct, &mut pt);
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
"out-of-order packet (original idx={idx}, seq={}) must decrypt successfully",
|
||||
seqs[idx]
|
||||
);
|
||||
assert_eq!(&pt, plaintext);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn per_stream_anti_replay_rejects_duplicate() {
|
||||
use wzp_proto::{CodecId, MediaType};
|
||||
|
||||
let (mut alice, mut bob) = make_session_pair();
|
||||
let header = MediaHeader {
|
||||
version: 2,
|
||||
flags: 0,
|
||||
media_type: MediaType::Audio,
|
||||
codec_id: CodecId::Opus24k,
|
||||
stream_id: 0,
|
||||
fec_ratio: 10,
|
||||
seq: 42,
|
||||
timestamp: 1000,
|
||||
fec_block: 0,
|
||||
};
|
||||
let mut header_bytes = Vec::new();
|
||||
header.write_to(&mut header_bytes);
|
||||
|
||||
let plaintext = b"audio frame";
|
||||
|
||||
// First packet decrypts successfully
|
||||
let mut ct = Vec::new();
|
||||
alice.encrypt(&header_bytes, plaintext, &mut ct).unwrap();
|
||||
let mut pt = Vec::new();
|
||||
bob.decrypt(&header_bytes, &ct, &mut pt).unwrap();
|
||||
assert_eq!(&pt, plaintext);
|
||||
|
||||
// Exact duplicate is rejected by anti-replay
|
||||
let mut pt2 = Vec::new();
|
||||
let result = bob.decrypt(&header_bytes, &ct, &mut pt2);
|
||||
assert!(
|
||||
result.is_err(),
|
||||
"duplicate packet with same seq must be rejected"
|
||||
);
|
||||
assert!(pt2.is_empty(), "plaintext must be rolled back on replay");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn per_stream_anti_replay_video_burst_200_with_reorder() {
|
||||
use wzp_proto::{CodecId, MediaType};
|
||||
|
||||
let (mut alice, mut bob) = make_session_pair();
|
||||
let header = MediaHeader {
|
||||
version: 2,
|
||||
flags: 0,
|
||||
media_type: MediaType::Video,
|
||||
codec_id: CodecId::Opus24k,
|
||||
stream_id: 1,
|
||||
fec_ratio: 10,
|
||||
seq: 0,
|
||||
timestamp: 0,
|
||||
fec_block: 0,
|
||||
};
|
||||
|
||||
let plaintext = b"video frame";
|
||||
|
||||
// Send 200 packets in order
|
||||
for i in 0..200 {
|
||||
let mut h = header;
|
||||
h.seq = i;
|
||||
let mut header_bytes = Vec::new();
|
||||
h.write_to(&mut header_bytes);
|
||||
|
||||
let mut ct = Vec::new();
|
||||
alice.encrypt(&header_bytes, plaintext, &mut ct).unwrap();
|
||||
|
||||
let mut pt = Vec::new();
|
||||
bob.decrypt(&header_bytes, &ct, &mut pt).unwrap();
|
||||
}
|
||||
|
||||
// Re-send packet 50 — should be rejected as replay
|
||||
let mut h = header;
|
||||
h.seq = 50;
|
||||
let mut header_bytes = Vec::new();
|
||||
h.write_to(&mut header_bytes);
|
||||
|
||||
let mut ct = Vec::new();
|
||||
alice.encrypt(&header_bytes, plaintext, &mut ct).unwrap();
|
||||
|
||||
let mut pt = Vec::new();
|
||||
let result = bob.decrypt(&header_bytes, &ct, &mut pt);
|
||||
assert!(result.is_err(), "reordered duplicate must be rejected");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
//! 3. Auth: WZP auth module request/response matches FC's /v1/auth/validate contract
|
||||
//! 4. Mnemonic: BIP39 interop between both implementations
|
||||
|
||||
use wzp_proto::KeyExchange;
|
||||
use wzp_proto::{KeyExchange, default_signal_version};
|
||||
|
||||
// ─── Identity Compatibility (WZP-FC-8) ──────────────────────────────────────
|
||||
|
||||
@@ -52,7 +52,10 @@ fn wzp_identity_module_matches_featherchat() {
|
||||
assert_eq!(wzp_pub.signing.as_bytes(), fc_pub.signing.as_bytes());
|
||||
assert_eq!(wzp_pub.encryption.as_bytes(), fc_pub.encryption.as_bytes());
|
||||
assert_eq!(wzp_pub.fingerprint.0, fc_pub.fingerprint.0);
|
||||
assert_eq!(wzp_pub.fingerprint.to_string(), fc_pub.fingerprint.to_string());
|
||||
assert_eq!(
|
||||
wzp_pub.fingerprint.to_string(),
|
||||
fc_pub.fingerprint.to_string()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -111,11 +114,15 @@ fn mnemonic_strings_identical() {
|
||||
fn wzp_signal_serializes_into_fc_callsignal_payload() {
|
||||
// WZP creates a CallOffer SignalMessage
|
||||
let offer = wzp_proto::SignalMessage::CallOffer {
|
||||
version: default_signal_version(),
|
||||
identity_pub: [1u8; 32],
|
||||
ephemeral_pub: [2u8; 32],
|
||||
signature: vec![3u8; 64],
|
||||
supported_profiles: vec![wzp_proto::QualityProfile::GOOD],
|
||||
alias: None,
|
||||
protocol_version: 2,
|
||||
supported_versions: vec![2],
|
||||
video_codecs: vec![],
|
||||
};
|
||||
|
||||
// Encode as featherChat CallSignal payload
|
||||
@@ -148,16 +155,25 @@ fn wzp_signal_serializes_into_fc_callsignal_payload() {
|
||||
// And deserializes back
|
||||
let decoded: warzone_protocol::message::WireMessage = bincode::deserialize(&encoded).unwrap();
|
||||
if let warzone_protocol::message::WireMessage::CallSignal {
|
||||
id, payload: p, signal_type, ..
|
||||
id,
|
||||
payload: p,
|
||||
signal_type,
|
||||
..
|
||||
} = decoded
|
||||
{
|
||||
assert_eq!(id, "call-123");
|
||||
assert!(matches!(signal_type, warzone_protocol::message::CallSignalType::Offer));
|
||||
assert!(matches!(
|
||||
signal_type,
|
||||
warzone_protocol::message::CallSignalType::Offer
|
||||
));
|
||||
|
||||
// Decode the WZP payload back
|
||||
let wzp_payload = wzp_client::featherchat::decode_call_payload(&p).unwrap();
|
||||
assert_eq!(wzp_payload.relay_addr.unwrap(), "relay.example.com:4433");
|
||||
assert!(matches!(wzp_payload.signal, wzp_proto::SignalMessage::CallOffer { .. }));
|
||||
assert!(matches!(
|
||||
wzp_payload.signal,
|
||||
wzp_proto::SignalMessage::CallOffer { .. }
|
||||
));
|
||||
} else {
|
||||
panic!("expected CallSignal");
|
||||
}
|
||||
@@ -166,10 +182,12 @@ fn wzp_signal_serializes_into_fc_callsignal_payload() {
|
||||
#[test]
|
||||
fn wzp_answer_round_trips_through_fc_callsignal() {
|
||||
let answer = wzp_proto::SignalMessage::CallAnswer {
|
||||
version: default_signal_version(),
|
||||
identity_pub: [10u8; 32],
|
||||
ephemeral_pub: [20u8; 32],
|
||||
signature: vec![30u8; 64],
|
||||
chosen_profile: wzp_proto::QualityProfile::DEGRADED,
|
||||
video_codec: None,
|
||||
};
|
||||
|
||||
let payload = wzp_client::featherchat::encode_call_payload(&answer, None, None);
|
||||
@@ -198,13 +216,17 @@ fn wzp_answer_round_trips_through_fc_callsignal() {
|
||||
#[test]
|
||||
fn wzp_hangup_round_trips_through_fc_callsignal() {
|
||||
let hangup = wzp_proto::SignalMessage::Hangup {
|
||||
version: default_signal_version(),
|
||||
reason: wzp_proto::HangupReason::Normal,
|
||||
call_id: None,
|
||||
};
|
||||
|
||||
let payload = wzp_client::featherchat::encode_call_payload(&hangup, None, None);
|
||||
let signal_type = wzp_client::featherchat::signal_to_call_type(&hangup);
|
||||
assert!(matches!(signal_type, wzp_client::featherchat::CallSignalType::Hangup));
|
||||
assert!(matches!(
|
||||
signal_type,
|
||||
wzp_client::featherchat::CallSignalType::Hangup
|
||||
));
|
||||
|
||||
let fc_msg = warzone_protocol::message::WireMessage::CallSignal {
|
||||
id: "call-789".to_string(),
|
||||
@@ -219,7 +241,10 @@ fn wzp_hangup_round_trips_through_fc_callsignal() {
|
||||
|
||||
if let warzone_protocol::message::WireMessage::CallSignal { payload, .. } = decoded {
|
||||
let wzp = wzp_client::featherchat::decode_call_payload(&payload).unwrap();
|
||||
assert!(matches!(wzp.signal, wzp_proto::SignalMessage::Hangup { .. }));
|
||||
assert!(matches!(
|
||||
wzp.signal,
|
||||
wzp_proto::SignalMessage::Hangup { .. }
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -252,8 +277,7 @@ fn auth_validate_response_matches_wzp_expectations() {
|
||||
"eth_address": null
|
||||
});
|
||||
|
||||
let wzp_resp: wzp_relay::auth::ValidateResponse =
|
||||
serde_json::from_value(fc_response).unwrap();
|
||||
let wzp_resp: wzp_relay::auth::ValidateResponse = serde_json::from_value(fc_response).unwrap();
|
||||
assert!(wzp_resp.valid);
|
||||
assert_eq!(
|
||||
wzp_resp.fingerprint.unwrap(),
|
||||
@@ -265,8 +289,7 @@ fn auth_validate_response_matches_wzp_expectations() {
|
||||
#[test]
|
||||
fn auth_invalid_response_matches() {
|
||||
let fc_response = serde_json::json!({ "valid": false });
|
||||
let wzp_resp: wzp_relay::auth::ValidateResponse =
|
||||
serde_json::from_value(fc_response).unwrap();
|
||||
let wzp_resp: wzp_relay::auth::ValidateResponse = serde_json::from_value(fc_response).unwrap();
|
||||
assert!(!wzp_resp.valid);
|
||||
assert!(wzp_resp.fingerprint.is_none());
|
||||
}
|
||||
@@ -280,28 +303,39 @@ fn all_signal_types_map_correctly() {
|
||||
let cases: Vec<(wzp_proto::SignalMessage, &str)> = vec![
|
||||
(
|
||||
wzp_proto::SignalMessage::CallOffer {
|
||||
identity_pub: [0; 32], ephemeral_pub: [0; 32],
|
||||
signature: vec![], supported_profiles: vec![],
|
||||
version: default_signal_version(),
|
||||
identity_pub: [0; 32],
|
||||
ephemeral_pub: [0; 32],
|
||||
signature: vec![],
|
||||
supported_profiles: vec![],
|
||||
alias: None,
|
||||
protocol_version: 2,
|
||||
supported_versions: vec![2],
|
||||
video_codecs: vec![],
|
||||
},
|
||||
"Offer",
|
||||
),
|
||||
(
|
||||
wzp_proto::SignalMessage::CallAnswer {
|
||||
identity_pub: [0; 32], ephemeral_pub: [0; 32],
|
||||
version: default_signal_version(),
|
||||
identity_pub: [0; 32],
|
||||
ephemeral_pub: [0; 32],
|
||||
signature: vec![],
|
||||
chosen_profile: wzp_proto::QualityProfile::GOOD,
|
||||
video_codec: None,
|
||||
},
|
||||
"Answer",
|
||||
),
|
||||
(
|
||||
wzp_proto::SignalMessage::IceCandidate {
|
||||
version: default_signal_version(),
|
||||
candidate: "candidate:1".to_string(),
|
||||
},
|
||||
"IceCandidate",
|
||||
),
|
||||
(
|
||||
wzp_proto::SignalMessage::Hangup {
|
||||
version: default_signal_version(),
|
||||
reason: wzp_proto::HangupReason::Normal,
|
||||
call_id: None,
|
||||
},
|
||||
@@ -312,7 +346,10 @@ fn all_signal_types_map_correctly() {
|
||||
for (signal, expected_name) in cases {
|
||||
let ct = signal_to_call_type(&signal);
|
||||
let name = format!("{ct:?}");
|
||||
assert_eq!(name, expected_name, "signal type mapping for {expected_name}");
|
||||
assert_eq!(
|
||||
name, expected_name,
|
||||
"signal type mapping for {expected_name}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -426,8 +463,7 @@ fn auth_response_with_eth_address() {
|
||||
"alias": "vitalik",
|
||||
"eth_address": "0x1234567890abcdef1234567890abcdef12345678"
|
||||
});
|
||||
let resp: wzp_relay::auth::ValidateResponse =
|
||||
serde_json::from_value(with_eth).unwrap();
|
||||
let resp: wzp_relay::auth::ValidateResponse = serde_json::from_value(with_eth).unwrap();
|
||||
assert!(resp.valid);
|
||||
assert_eq!(
|
||||
resp.fingerprint.unwrap(),
|
||||
@@ -442,8 +478,7 @@ fn auth_response_with_eth_address() {
|
||||
"alias": "anon",
|
||||
"eth_address": null
|
||||
});
|
||||
let resp2: wzp_relay::auth::ValidateResponse =
|
||||
serde_json::from_value(with_null_eth).unwrap();
|
||||
let resp2: wzp_relay::auth::ValidateResponse = serde_json::from_value(with_null_eth).unwrap();
|
||||
assert!(resp2.valid);
|
||||
assert_eq!(
|
||||
resp2.fingerprint.unwrap(),
|
||||
@@ -454,15 +489,15 @@ fn auth_response_with_eth_address() {
|
||||
let without_eth = serde_json::json!({
|
||||
"valid": false
|
||||
});
|
||||
let resp3: wzp_relay::auth::ValidateResponse =
|
||||
serde_json::from_value(without_eth).unwrap();
|
||||
let resp3: wzp_relay::auth::ValidateResponse = serde_json::from_value(without_eth).unwrap();
|
||||
assert!(!resp3.valid);
|
||||
}
|
||||
|
||||
/// WZP-S-7: SignalMessage::AuthToken { token } exists and round-trips via serde.
|
||||
/// WZP-S-7: SignalMessage::AuthToken { version: default_signal_version(), token } exists and round-trips via serde.
|
||||
#[test]
|
||||
fn wzp_proto_has_auth_token_variant() {
|
||||
let msg = wzp_proto::SignalMessage::AuthToken {
|
||||
version: default_signal_version(),
|
||||
token: "fc-bearer-token-xyz".to_string(),
|
||||
};
|
||||
|
||||
@@ -473,7 +508,7 @@ fn wzp_proto_has_auth_token_variant() {
|
||||
|
||||
// Deserialize back
|
||||
let decoded: wzp_proto::SignalMessage = serde_json::from_str(&json).unwrap();
|
||||
if let wzp_proto::SignalMessage::AuthToken { token } = decoded {
|
||||
if let wzp_proto::SignalMessage::AuthToken { token, .. } = decoded {
|
||||
assert_eq!(token, "fc-bearer-token-xyz");
|
||||
} else {
|
||||
panic!("expected AuthToken variant, got: {decoded:?}");
|
||||
@@ -496,7 +531,11 @@ fn all_fc_call_signal_types_representable() {
|
||||
(CallSignalType::Busy, "Busy"),
|
||||
];
|
||||
|
||||
assert_eq!(variants.len(), 7, "featherChat defines exactly 7 call signal types");
|
||||
assert_eq!(
|
||||
variants.len(),
|
||||
7,
|
||||
"featherChat defines exactly 7 call signal types"
|
||||
);
|
||||
|
||||
for (variant, expected_name) in &variants {
|
||||
let name = format!("{variant:?}");
|
||||
@@ -550,10 +589,7 @@ fn hash_room_name_used_as_sni_is_valid() {
|
||||
#[test]
|
||||
fn wzp_proto_cargo_toml_is_standalone() {
|
||||
// Try both paths (run from workspace root or from crate directory)
|
||||
let candidates = [
|
||||
"crates/wzp-proto/Cargo.toml",
|
||||
"../wzp-proto/Cargo.toml",
|
||||
];
|
||||
let candidates = ["crates/wzp-proto/Cargo.toml", "../wzp-proto/Cargo.toml"];
|
||||
|
||||
let contents = candidates
|
||||
.iter()
|
||||
|
||||
@@ -13,11 +13,17 @@ pub struct AdaptiveFec {
|
||||
pub repair_ratio: f32,
|
||||
/// Symbol size in bytes.
|
||||
pub symbol_size: u16,
|
||||
/// Repair ratio to use when the block contains a keyframe.
|
||||
/// Default 0.5 (50% overhead) — keyframes are critical and worth
|
||||
/// the extra bandwidth.
|
||||
pub keyframe_repair_ratio: f32,
|
||||
}
|
||||
|
||||
impl AdaptiveFec {
|
||||
/// Default symbol size for adaptive configuration.
|
||||
const DEFAULT_SYMBOL_SIZE: u16 = 256;
|
||||
/// Default keyframe repair ratio (PRD-video-v1 T4.5).
|
||||
const DEFAULT_KEYFRAME_REPAIR_RATIO: f32 = 0.5;
|
||||
|
||||
/// Create an adaptive FEC configuration from a quality profile.
|
||||
///
|
||||
@@ -30,12 +36,15 @@ impl AdaptiveFec {
|
||||
frames_per_block: profile.frames_per_block as usize,
|
||||
repair_ratio: profile.fec_ratio,
|
||||
symbol_size: Self::DEFAULT_SYMBOL_SIZE,
|
||||
keyframe_repair_ratio: Self::DEFAULT_KEYFRAME_REPAIR_RATIO,
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a configured FEC encoder from this adaptive configuration.
|
||||
pub fn build_encoder(&self) -> RaptorQFecEncoder {
|
||||
RaptorQFecEncoder::new(self.frames_per_block, self.symbol_size)
|
||||
let mut enc = RaptorQFecEncoder::new(self.frames_per_block, self.symbol_size);
|
||||
enc.set_keyframe_ratio(self.keyframe_repair_ratio);
|
||||
enc
|
||||
}
|
||||
|
||||
/// Get the repair ratio for use with `FecEncoder::generate_repair()`.
|
||||
@@ -59,6 +68,7 @@ mod tests {
|
||||
let cfg = AdaptiveFec::from_profile(&QualityProfile::GOOD);
|
||||
assert_eq!(cfg.frames_per_block, 5);
|
||||
assert!((cfg.repair_ratio - 0.2).abs() < f32::EPSILON);
|
||||
assert!((cfg.keyframe_repair_ratio - 0.5).abs() < f32::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -29,9 +29,9 @@ pub enum DecoderBlockState {
|
||||
/// Manages encoder-side block tracking.
|
||||
pub struct EncoderBlockManager {
|
||||
/// Current block ID being built.
|
||||
current_id: u8,
|
||||
current_id: u16,
|
||||
/// State of known blocks.
|
||||
blocks: HashMap<u8, EncoderBlockState>,
|
||||
blocks: HashMap<u16, EncoderBlockState>,
|
||||
}
|
||||
|
||||
impl EncoderBlockManager {
|
||||
@@ -45,7 +45,7 @@ impl EncoderBlockManager {
|
||||
}
|
||||
|
||||
/// Get the next block ID (advances the current building block).
|
||||
pub fn next_block_id(&mut self) -> u8 {
|
||||
pub fn next_block_id(&mut self) -> u16 {
|
||||
let old = self.current_id;
|
||||
// Mark old block as pending.
|
||||
self.blocks.insert(old, EncoderBlockState::Pending);
|
||||
@@ -57,23 +57,23 @@ impl EncoderBlockManager {
|
||||
}
|
||||
|
||||
/// Current block ID being built.
|
||||
pub fn current_id(&self) -> u8 {
|
||||
pub fn current_id(&self) -> u16 {
|
||||
self.current_id
|
||||
}
|
||||
|
||||
/// Mark a block as fully sent.
|
||||
pub fn mark_sent(&mut self, block_id: u8) {
|
||||
pub fn mark_sent(&mut self, block_id: u16) {
|
||||
self.blocks.insert(block_id, EncoderBlockState::Sent);
|
||||
}
|
||||
|
||||
/// Mark a block as acknowledged by the peer.
|
||||
pub fn mark_acknowledged(&mut self, block_id: u8) {
|
||||
pub fn mark_acknowledged(&mut self, block_id: u16) {
|
||||
self.blocks
|
||||
.insert(block_id, EncoderBlockState::Acknowledged);
|
||||
}
|
||||
|
||||
/// Get the state of a block.
|
||||
pub fn state(&self, block_id: u8) -> Option<EncoderBlockState> {
|
||||
pub fn state(&self, block_id: u16) -> Option<EncoderBlockState> {
|
||||
self.blocks.get(&block_id).copied()
|
||||
}
|
||||
|
||||
@@ -93,9 +93,9 @@ impl Default for EncoderBlockManager {
|
||||
/// Manages decoder-side block tracking.
|
||||
pub struct DecoderBlockManager {
|
||||
/// State of known blocks.
|
||||
blocks: HashMap<u8, DecoderBlockState>,
|
||||
blocks: HashMap<u16, DecoderBlockState>,
|
||||
/// Set of completed block IDs.
|
||||
completed: HashSet<u8>,
|
||||
completed: HashSet<u16>,
|
||||
}
|
||||
|
||||
impl DecoderBlockManager {
|
||||
@@ -107,43 +107,43 @@ impl DecoderBlockManager {
|
||||
}
|
||||
|
||||
/// Register that we are receiving symbols for a block.
|
||||
pub fn touch(&mut self, block_id: u8) {
|
||||
pub fn touch(&mut self, block_id: u16) {
|
||||
self.blocks
|
||||
.entry(block_id)
|
||||
.or_insert(DecoderBlockState::Assembling);
|
||||
}
|
||||
|
||||
/// Mark a block as successfully decoded.
|
||||
pub fn mark_complete(&mut self, block_id: u8) {
|
||||
pub fn mark_complete(&mut self, block_id: u16) {
|
||||
self.blocks.insert(block_id, DecoderBlockState::Complete);
|
||||
self.completed.insert(block_id);
|
||||
}
|
||||
|
||||
/// Mark a block as expired.
|
||||
pub fn mark_expired(&mut self, block_id: u8) {
|
||||
pub fn mark_expired(&mut self, block_id: u16) {
|
||||
self.blocks.insert(block_id, DecoderBlockState::Expired);
|
||||
self.completed.remove(&block_id);
|
||||
}
|
||||
|
||||
/// Check if a block has been fully decoded.
|
||||
pub fn is_block_complete(&self, block_id: u8) -> bool {
|
||||
pub fn is_block_complete(&self, block_id: u16) -> bool {
|
||||
self.completed.contains(&block_id)
|
||||
}
|
||||
|
||||
/// Get the state of a block.
|
||||
pub fn state(&self, block_id: u8) -> Option<DecoderBlockState> {
|
||||
pub fn state(&self, block_id: u16) -> Option<DecoderBlockState> {
|
||||
self.blocks.get(&block_id).copied()
|
||||
}
|
||||
|
||||
/// Expire all blocks older than the given block_id (using wrapping distance).
|
||||
pub fn expire_before(&mut self, block_id: u8) {
|
||||
let to_expire: Vec<u8> = self
|
||||
pub fn expire_before(&mut self, block_id: u16) {
|
||||
let to_expire: Vec<u16> = self
|
||||
.blocks
|
||||
.keys()
|
||||
.copied()
|
||||
.filter(|&id| {
|
||||
let distance = block_id.wrapping_sub(id);
|
||||
distance > 0 && distance <= 128
|
||||
distance > 0 && distance <= 32768
|
||||
})
|
||||
.collect();
|
||||
|
||||
@@ -207,7 +207,7 @@ mod tests {
|
||||
#[test]
|
||||
fn decoder_expire_before() {
|
||||
let mut mgr = DecoderBlockManager::new();
|
||||
for i in 0..5u8 {
|
||||
for i in 0..5u16 {
|
||||
mgr.touch(i);
|
||||
}
|
||||
mgr.mark_complete(1);
|
||||
@@ -231,11 +231,11 @@ mod tests {
|
||||
#[test]
|
||||
fn next_block_id_wraps() {
|
||||
let mut mgr = EncoderBlockManager::new();
|
||||
// Start at 0, advance to 255 then wrap
|
||||
for _ in 0..255 {
|
||||
// Start at 0, advance to u16::MAX then wrap
|
||||
for _ in 0..65535 {
|
||||
mgr.next_block_id();
|
||||
}
|
||||
assert_eq!(mgr.current_id(), 255);
|
||||
assert_eq!(mgr.current_id(), u16::MAX);
|
||||
let next = mgr.next_block_id();
|
||||
assert_eq!(next, 0);
|
||||
}
|
||||
|
||||
@@ -4,8 +4,8 @@ use std::collections::HashMap;
|
||||
use std::time::Instant;
|
||||
|
||||
use raptorq::{EncodingPacket, ObjectTransmissionInformation, PayloadId, SourceBlockDecoder};
|
||||
use wzp_proto::error::FecError;
|
||||
use wzp_proto::FecDecoder;
|
||||
use wzp_proto::error::FecError;
|
||||
|
||||
/// Length prefix size (u16 little-endian), must match encoder.
|
||||
const LEN_PREFIX: usize = 2;
|
||||
@@ -32,7 +32,7 @@ struct BlockState {
|
||||
/// RaptorQ-based FEC decoder that handles multiple concurrent blocks.
|
||||
pub struct RaptorQFecDecoder {
|
||||
/// Per-block decoder state, keyed by block_id.
|
||||
blocks: HashMap<u8, BlockState>,
|
||||
blocks: HashMap<u16, BlockState>,
|
||||
/// Symbol size (must match encoder).
|
||||
symbol_size: u16,
|
||||
/// Number of source symbols per block (from encoder config).
|
||||
@@ -57,7 +57,7 @@ impl RaptorQFecDecoder {
|
||||
Self::new(frames_per_block, 256)
|
||||
}
|
||||
|
||||
fn get_or_create_block(&mut self, block_id: u8) -> &mut BlockState {
|
||||
fn get_or_create_block(&mut self, block_id: u16) -> &mut BlockState {
|
||||
self.blocks.entry(block_id).or_insert_with(|| BlockState {
|
||||
num_source_symbols: Some(self.frames_per_block),
|
||||
packets: Vec::new(),
|
||||
@@ -72,8 +72,8 @@ impl RaptorQFecDecoder {
|
||||
impl FecDecoder for RaptorQFecDecoder {
|
||||
fn add_symbol(
|
||||
&mut self,
|
||||
block_id: u8,
|
||||
symbol_index: u8,
|
||||
block_id: u16,
|
||||
symbol_index: u16,
|
||||
_is_repair: bool,
|
||||
data: &[u8],
|
||||
) -> Result<(), FecError> {
|
||||
@@ -104,13 +104,13 @@ impl FecDecoder for RaptorQFecDecoder {
|
||||
padded[..len].copy_from_slice(&data[..len]);
|
||||
|
||||
let esi = symbol_index as u32;
|
||||
let packet = EncodingPacket::new(PayloadId::new(block_id, esi), padded);
|
||||
let packet = EncodingPacket::new(PayloadId::new((block_id & 0xFF) as u8, esi), padded);
|
||||
block.packets.push(packet);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn try_decode(&mut self, block_id: u8) -> Result<Option<Vec<Vec<u8>>>, FecError> {
|
||||
fn try_decode(&mut self, block_id: u16) -> Result<Option<Vec<Vec<u8>>>, FecError> {
|
||||
let frames_per_block = self.frames_per_block;
|
||||
let block = match self.blocks.get_mut(&block_id) {
|
||||
Some(b) => b,
|
||||
@@ -125,7 +125,7 @@ impl FecDecoder for RaptorQFecDecoder {
|
||||
let block_length = (num_source as u64) * (block.symbol_size as u64);
|
||||
|
||||
let config = ObjectTransmissionInformation::with_defaults(block_length, block.symbol_size);
|
||||
let mut decoder = SourceBlockDecoder::new(block_id, &config, block_length);
|
||||
let mut decoder = SourceBlockDecoder::new((block_id & 0xFF) as u8, &config, block_length);
|
||||
|
||||
let decoded = decoder.decode(block.packets.clone());
|
||||
|
||||
@@ -140,10 +140,7 @@ impl FecDecoder for RaptorQFecDecoder {
|
||||
frames.push(Vec::new());
|
||||
continue;
|
||||
}
|
||||
let payload_len = u16::from_le_bytes([
|
||||
data[offset],
|
||||
data[offset + 1],
|
||||
]) as usize;
|
||||
let payload_len = u16::from_le_bytes([data[offset], data[offset + 1]]) as usize;
|
||||
let payload_start = offset + LEN_PREFIX;
|
||||
let payload_end = (payload_start + payload_len).min(data.len());
|
||||
frames.push(data[payload_start..payload_end].to_vec());
|
||||
@@ -159,15 +156,15 @@ impl FecDecoder for RaptorQFecDecoder {
|
||||
}
|
||||
}
|
||||
|
||||
fn expire_before(&mut self, block_id: u8) {
|
||||
fn expire_before(&mut self, block_id: u16) {
|
||||
// Remove blocks with IDs "older" than block_id.
|
||||
// With wrapping u8 IDs, we consider a block old if its distance
|
||||
// (in the forward direction) to block_id is > 128.
|
||||
// With wrapping u16 IDs, we consider a block old if its distance
|
||||
// (in the forward direction) to block_id is > 32768.
|
||||
self.blocks.retain(|&id, _| {
|
||||
let distance = block_id.wrapping_sub(id);
|
||||
// If distance is 0 or > 128, the block is current or "ahead" — keep it.
|
||||
// If distance is 1..=128, the block is behind — remove it.
|
||||
distance == 0 || distance > 128
|
||||
// If distance is 0 or > 32768, the block is current or "ahead" — keep it.
|
||||
// If distance is 1..=32768, the block is behind — remove it.
|
||||
distance == 0 || distance > 32768
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -198,9 +195,7 @@ mod tests {
|
||||
|
||||
// Feed all source symbols (using the length-prefixed padded data).
|
||||
for (i, pkt) in source_pkts.iter().enumerate() {
|
||||
decoder
|
||||
.add_symbol(0, i as u8, false, pkt.data())
|
||||
.unwrap();
|
||||
decoder.add_symbol(0, i as u16, false, pkt.data()).unwrap();
|
||||
}
|
||||
|
||||
let result = decoder.try_decode(0).unwrap();
|
||||
@@ -233,7 +228,11 @@ mod tests {
|
||||
let config = ObjectTransmissionInformation::new(block_len, SYMBOL_SIZE, 1, 1, 1);
|
||||
let mut dec = SourceBlockDecoder::new(0, &config, block_len);
|
||||
let decoded = dec.decode(all);
|
||||
assert!(decoded.is_some(), "Should recover with {:.0}% loss", drop_fraction * 100.0);
|
||||
assert!(
|
||||
decoded.is_some(),
|
||||
"Should recover with {:.0}% loss",
|
||||
drop_fraction * 100.0
|
||||
);
|
||||
|
||||
let data = decoded.unwrap();
|
||||
let ss = SYMBOL_SIZE as usize;
|
||||
@@ -245,22 +244,28 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_with_30pct_loss() { run_loss_test(FRAMES_PER_BLOCK, 0.5, 0.3); }
|
||||
fn decode_with_30pct_loss() {
|
||||
run_loss_test(FRAMES_PER_BLOCK, 0.5, 0.3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_with_50pct_loss() { run_loss_test(FRAMES_PER_BLOCK, 1.0, 0.5); }
|
||||
fn decode_with_50pct_loss() {
|
||||
run_loss_test(FRAMES_PER_BLOCK, 1.0, 0.5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_with_70pct_source_loss_heavy_repair() { run_loss_test(8, 2.0, 0.5); }
|
||||
fn decode_with_70pct_source_loss_heavy_repair() {
|
||||
run_loss_test(8, 2.0, 0.5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn expire_removes_old_blocks() {
|
||||
let mut decoder = RaptorQFecDecoder::new(FRAMES_PER_BLOCK, SYMBOL_SIZE);
|
||||
|
||||
// Add symbols to blocks 0, 1, 2
|
||||
for block_id in 0..3u8 {
|
||||
for block_id in 0..3u16 {
|
||||
decoder
|
||||
.add_symbol(block_id, 0, false, &[block_id; 50])
|
||||
.add_symbol(block_id, 0, false, &[block_id as u8; 50])
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
@@ -288,10 +293,10 @@ mod tests {
|
||||
// Interleave symbols from block 0 and block 1
|
||||
for i in 0..FRAMES_PER_BLOCK {
|
||||
decoder
|
||||
.add_symbol(0, i as u8, false, pkts_a[i].data())
|
||||
.add_symbol(0, i as u16, false, pkts_a[i].data())
|
||||
.unwrap();
|
||||
decoder
|
||||
.add_symbol(1, i as u8, false, pkts_b[i].data())
|
||||
.add_symbol(1, i as u16, false, pkts_b[i].data())
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
//! RaptorQ FEC encoder — accumulates source symbols into blocks and generates repair symbols.
|
||||
|
||||
use raptorq::{EncodingPacket, ObjectTransmissionInformation, PayloadId, SourceBlockEncoder};
|
||||
use wzp_proto::error::FecError;
|
||||
use wzp_proto::FecEncoder;
|
||||
use wzp_proto::error::FecError;
|
||||
|
||||
/// Maximum symbol size in bytes. Audio frames are typically < 200 bytes,
|
||||
/// but we pad to a uniform size within a block.
|
||||
@@ -15,14 +15,19 @@ const LEN_PREFIX: usize = 2;
|
||||
/// RaptorQ-based FEC encoder that groups audio frames into blocks
|
||||
/// and generates fountain-code repair symbols.
|
||||
pub struct RaptorQFecEncoder {
|
||||
/// Current block ID (wraps at u8).
|
||||
block_id: u8,
|
||||
/// Current block ID (wraps at u16).
|
||||
block_id: u16,
|
||||
/// Maximum source symbols per block.
|
||||
frames_per_block: usize,
|
||||
/// Accumulated source symbols for the current block.
|
||||
source_symbols: Vec<Vec<u8>>,
|
||||
/// Symbol size used for encoding (all symbols padded to this size).
|
||||
symbol_size: u16,
|
||||
/// True if at least one source symbol in the current block is a keyframe.
|
||||
has_keyframe: bool,
|
||||
/// Repair ratio to use when the block contains a keyframe.
|
||||
/// If zero, the nominal ratio passed to [`generate_repair`] is used.
|
||||
keyframe_ratio: f32,
|
||||
}
|
||||
|
||||
impl RaptorQFecEncoder {
|
||||
@@ -36,9 +41,26 @@ impl RaptorQFecEncoder {
|
||||
frames_per_block,
|
||||
source_symbols: Vec::with_capacity(frames_per_block),
|
||||
symbol_size,
|
||||
has_keyframe: false,
|
||||
keyframe_ratio: 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the repair ratio to use for blocks that contain at least one
|
||||
/// keyframe source symbol.
|
||||
///
|
||||
/// When `keyframe_ratio > 0.0` and [`has_keyframe`](Self::has_keyframe)
|
||||
/// is true, [`generate_repair`](FecEncoder::generate_repair) uses this
|
||||
/// ratio instead of the nominal ratio passed by the caller.
|
||||
pub fn set_keyframe_ratio(&mut self, ratio: f32) {
|
||||
self.keyframe_ratio = ratio.max(0.0);
|
||||
}
|
||||
|
||||
/// Returns true if the current block contains a keyframe source symbol.
|
||||
pub fn has_keyframe(&self) -> bool {
|
||||
self.has_keyframe
|
||||
}
|
||||
|
||||
/// Create with default symbol size (256 bytes).
|
||||
pub fn with_defaults(frames_per_block: usize) -> Self {
|
||||
Self::new(frames_per_block, DEFAULT_MAX_SYMBOL_SIZE)
|
||||
@@ -54,8 +76,7 @@ impl RaptorQFecEncoder {
|
||||
let payload_len = sym.len().min(max_payload);
|
||||
let offset = i * ss;
|
||||
// Write 2-byte little-endian length prefix.
|
||||
data[offset..offset + LEN_PREFIX]
|
||||
.copy_from_slice(&(payload_len as u16).to_le_bytes());
|
||||
data[offset..offset + LEN_PREFIX].copy_from_slice(&(payload_len as u16).to_le_bytes());
|
||||
// Write payload after prefix.
|
||||
data[offset + LEN_PREFIX..offset + LEN_PREFIX + payload_len]
|
||||
.copy_from_slice(&sym[..payload_len]);
|
||||
@@ -75,17 +96,36 @@ impl FecEncoder for RaptorQFecEncoder {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn generate_repair(&mut self, ratio: f32) -> Result<Vec<(u8, Vec<u8>)>, FecError> {
|
||||
fn add_source_symbol_with_keyframe(
|
||||
&mut self,
|
||||
data: &[u8],
|
||||
is_keyframe: bool,
|
||||
) -> Result<(), FecError> {
|
||||
self.add_source_symbol(data)?;
|
||||
if is_keyframe {
|
||||
self.has_keyframe = true;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn generate_repair(&mut self, ratio: f32) -> Result<Vec<(u16, Vec<u8>)>, FecError> {
|
||||
if self.source_symbols.is_empty() {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
|
||||
let effective_ratio = if self.has_keyframe && self.keyframe_ratio > 0.0 {
|
||||
self.keyframe_ratio
|
||||
} else {
|
||||
ratio
|
||||
};
|
||||
|
||||
let block_data = self.build_block_data();
|
||||
let config = ObjectTransmissionInformation::with_defaults(block_data.len() as u64, self.symbol_size);
|
||||
let encoder = SourceBlockEncoder::new(self.block_id, &config, &block_data);
|
||||
let config =
|
||||
ObjectTransmissionInformation::with_defaults(block_data.len() as u64, self.symbol_size);
|
||||
let encoder = SourceBlockEncoder::new((self.block_id & 0xFF) as u8, &config, &block_data);
|
||||
|
||||
let num_source = self.source_symbols.len() as u32;
|
||||
let num_repair = ((num_source as f32) * ratio).ceil() as u32;
|
||||
let num_repair = ((num_source as f32) * effective_ratio).ceil() as u32;
|
||||
if num_repair == 0 {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
@@ -93,11 +133,11 @@ impl FecEncoder for RaptorQFecEncoder {
|
||||
// Generate repair packets starting from offset 0 (ESIs begin at num_source).
|
||||
let repair_packets: Vec<EncodingPacket> = encoder.repair_packets(0, num_repair);
|
||||
|
||||
let result: Vec<(u8, Vec<u8>)> = repair_packets
|
||||
let result: Vec<(u16, Vec<u8>)> = repair_packets
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(i, pkt): (usize, EncodingPacket)| {
|
||||
let idx = (num_source as u8).wrapping_add(i as u8);
|
||||
let idx = (num_source as u16).wrapping_add(i as u16);
|
||||
(idx, pkt.data().to_vec())
|
||||
})
|
||||
.collect();
|
||||
@@ -105,14 +145,15 @@ impl FecEncoder for RaptorQFecEncoder {
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn finalize_block(&mut self) -> Result<u8, FecError> {
|
||||
fn finalize_block(&mut self) -> Result<u16, FecError> {
|
||||
let completed = self.block_id;
|
||||
self.block_id = self.block_id.wrapping_add(1);
|
||||
self.source_symbols.clear();
|
||||
self.has_keyframe = false;
|
||||
Ok(completed)
|
||||
}
|
||||
|
||||
fn current_block_id(&self) -> u8 {
|
||||
fn current_block_id(&self) -> u16 {
|
||||
self.block_id
|
||||
}
|
||||
|
||||
@@ -130,8 +171,7 @@ fn build_prefixed_block_data(symbols: &[Vec<u8>], symbol_size: u16) -> Vec<u8> {
|
||||
let max_payload = ss - LEN_PREFIX;
|
||||
let payload_len = sym.len().min(max_payload);
|
||||
let offset = i * ss;
|
||||
data[offset..offset + LEN_PREFIX]
|
||||
.copy_from_slice(&(payload_len as u16).to_le_bytes());
|
||||
data[offset..offset + LEN_PREFIX].copy_from_slice(&(payload_len as u16).to_le_bytes());
|
||||
data[offset + LEN_PREFIX..offset + LEN_PREFIX + payload_len]
|
||||
.copy_from_slice(&sym[..payload_len]);
|
||||
}
|
||||
@@ -141,7 +181,7 @@ fn build_prefixed_block_data(symbols: &[Vec<u8>], symbol_size: u16) -> Vec<u8> {
|
||||
/// Helper: build source `EncodingPacket`s for a given block. Useful for
|
||||
/// the decoder tests and interleaving.
|
||||
pub fn source_packets_for_block(
|
||||
block_id: u8,
|
||||
block_id: u16,
|
||||
symbols: &[Vec<u8>],
|
||||
symbol_size: u16,
|
||||
) -> Vec<EncodingPacket> {
|
||||
@@ -151,21 +191,21 @@ pub fn source_packets_for_block(
|
||||
.map(|i| {
|
||||
let offset = i * ss;
|
||||
let sym_data = data[offset..offset + ss].to_vec();
|
||||
EncodingPacket::new(PayloadId::new(block_id, i as u32), sym_data)
|
||||
EncodingPacket::new(PayloadId::new((block_id & 0xFF) as u8, i as u32), sym_data)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Helper: generate repair packets for the given source symbols.
|
||||
pub fn repair_packets_for_block(
|
||||
block_id: u8,
|
||||
block_id: u16,
|
||||
symbols: &[Vec<u8>],
|
||||
symbol_size: u16,
|
||||
ratio: f32,
|
||||
) -> Vec<EncodingPacket> {
|
||||
let data = build_prefixed_block_data(symbols, symbol_size);
|
||||
let config = ObjectTransmissionInformation::with_defaults(data.len() as u64, symbol_size);
|
||||
let encoder = SourceBlockEncoder::new(block_id, &config, &data);
|
||||
let encoder = SourceBlockEncoder::new((block_id & 0xFF) as u8, &config, &data);
|
||||
let num_source = symbols.len() as u32;
|
||||
let num_repair = ((num_source as f32) * ratio).ceil() as u32;
|
||||
encoder.repair_packets(0, num_repair)
|
||||
@@ -201,14 +241,70 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn block_id_wraps() {
|
||||
fn block_id_wraps_u16() {
|
||||
let mut enc = RaptorQFecEncoder::with_defaults(1);
|
||||
for expected in 0..=255u8 {
|
||||
// Advance 300 blocks and verify no panic + monotonic increment.
|
||||
for expected in 0..300u16 {
|
||||
assert_eq!(enc.current_block_id(), expected);
|
||||
enc.add_source_symbol(&[expected; 10]).unwrap();
|
||||
enc.add_source_symbol(&[0u8; 10]).unwrap();
|
||||
enc.finalize_block().unwrap();
|
||||
}
|
||||
// After 256 blocks, wraps back to 0
|
||||
assert_eq!(enc.current_block_id(), 0);
|
||||
// Explicitly test wrap at u16 boundary.
|
||||
let mut enc2 = RaptorQFecEncoder::with_defaults(1);
|
||||
enc2.block_id = u16::MAX;
|
||||
enc2.add_source_symbol(&[0u8; 10]).unwrap();
|
||||
let id = enc2.finalize_block().unwrap();
|
||||
assert_eq!(id, u16::MAX);
|
||||
assert_eq!(enc2.current_block_id(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn keyframe_boost_uses_higher_ratio() {
|
||||
// Non-keyframe block with nominal ratio 0.2 → ceil(5 * 0.2) = 1 repair.
|
||||
let mut enc_normal = RaptorQFecEncoder::with_defaults(5);
|
||||
enc_normal.set_keyframe_ratio(0.8);
|
||||
for i in 0..5 {
|
||||
enc_normal
|
||||
.add_source_symbol_with_keyframe(&[i as u8; 100], false)
|
||||
.unwrap();
|
||||
}
|
||||
let normal_repair = enc_normal.generate_repair(0.2).unwrap();
|
||||
assert_eq!(normal_repair.len(), 1);
|
||||
|
||||
// Keyframe block with same nominal ratio but boost to 0.8 → ceil(5 * 0.8) = 4 repairs.
|
||||
let mut enc_key = RaptorQFecEncoder::with_defaults(5);
|
||||
enc_key.set_keyframe_ratio(0.8);
|
||||
for i in 0..5 {
|
||||
enc_key
|
||||
.add_source_symbol_with_keyframe(&[i as u8; 100], i == 2)
|
||||
.unwrap();
|
||||
}
|
||||
let keyframe_repair = enc_key.generate_repair(0.2).unwrap();
|
||||
assert_eq!(keyframe_repair.len(), 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_keyframe_block_uses_nominal_ratio() {
|
||||
let mut enc = RaptorQFecEncoder::with_defaults(5);
|
||||
enc.set_keyframe_ratio(0.8);
|
||||
|
||||
for i in 0..5 {
|
||||
enc.add_source_symbol_with_keyframe(&[i as u8; 100], false)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let repair = enc.generate_repair(0.2).unwrap();
|
||||
assert_eq!(repair.len(), 1); // ceil(5 * 0.2) = 1
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn finalize_clears_keyframe_flag() {
|
||||
let mut enc = RaptorQFecEncoder::with_defaults(2);
|
||||
enc.add_source_symbol_with_keyframe(&[0u8; 10], true)
|
||||
.unwrap();
|
||||
assert!(enc.has_keyframe());
|
||||
|
||||
enc.finalize_block().unwrap();
|
||||
assert!(!enc.has_keyframe());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
//! rather than one block fatally.
|
||||
|
||||
/// A symbol ready for transmission: (block_id, symbol_index, is_repair, data).
|
||||
pub type Symbol = (u8, u8, bool, Vec<u8>);
|
||||
pub type Symbol = (u16, u16, bool, Vec<u8>);
|
||||
|
||||
/// Temporal interleaver that mixes symbols across multiple FEC blocks.
|
||||
pub struct Interleaver {
|
||||
@@ -64,13 +64,13 @@ mod tests {
|
||||
let interleaver = Interleaver::with_default_depth();
|
||||
|
||||
let block_a: Vec<Symbol> = (0..3)
|
||||
.map(|i| (0u8, i as u8, false, vec![0xA0 + i as u8]))
|
||||
.map(|i| (0u16, i as u16, false, vec![0xA0 + i as u8]))
|
||||
.collect();
|
||||
let block_b: Vec<Symbol> = (0..3)
|
||||
.map(|i| (1u8, i as u8, false, vec![0xB0 + i as u8]))
|
||||
.map(|i| (1u16, i as u16, false, vec![0xB0 + i as u8]))
|
||||
.collect();
|
||||
let block_c: Vec<Symbol> = (0..3)
|
||||
.map(|i| (2u8, i as u8, false, vec![0xC0 + i as u8]))
|
||||
.map(|i| (2u16, i as u16, false, vec![0xC0 + i as u8]))
|
||||
.collect();
|
||||
|
||||
let result = interleaver.interleave(&[block_a, block_b, block_c]);
|
||||
@@ -96,10 +96,10 @@ mod tests {
|
||||
let interleaver = Interleaver::new(2);
|
||||
|
||||
let block_a: Vec<Symbol> = (0..3)
|
||||
.map(|i| (0u8, i as u8, false, vec![0xA0 + i as u8]))
|
||||
.map(|i| (0u16, i as u16, false, vec![0xA0 + i as u8]))
|
||||
.collect();
|
||||
let block_b: Vec<Symbol> = (0..1)
|
||||
.map(|i| (1u8, i as u8, false, vec![0xB0 + i as u8]))
|
||||
.map(|i| (1u16, i as u16, false, vec![0xB0 + i as u8]))
|
||||
.collect();
|
||||
|
||||
let result = interleaver.interleave(&[block_a, block_b]);
|
||||
@@ -128,7 +128,7 @@ mod tests {
|
||||
let blocks: Vec<Vec<Symbol>> = (0..3)
|
||||
.map(|b| {
|
||||
(0..6)
|
||||
.map(|i| (b as u8, i as u8, false, vec![b as u8; 10]))
|
||||
.map(|i| (b as u16, i as u16, false, vec![b as u8; 10]))
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
@@ -146,7 +146,10 @@ mod tests {
|
||||
|
||||
// Each block should lose exactly 2 (6 losses / 3 blocks)
|
||||
for &loss in &losses_per_block {
|
||||
assert_eq!(loss, 2, "Each block should lose at most 2 symbols from a burst of 6");
|
||||
assert_eq!(
|
||||
loss, 2,
|
||||
"Each block should lose at most 2 symbols from a burst of 6"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,7 +16,9 @@ pub mod encoder;
|
||||
pub mod interleave;
|
||||
|
||||
pub use adaptive::AdaptiveFec;
|
||||
pub use block_manager::{DecoderBlockManager, DecoderBlockState, EncoderBlockManager, EncoderBlockState};
|
||||
pub use block_manager::{
|
||||
DecoderBlockManager, DecoderBlockState, EncoderBlockManager, EncoderBlockState,
|
||||
};
|
||||
pub use decoder::RaptorQFecDecoder;
|
||||
pub use encoder::RaptorQFecEncoder;
|
||||
pub use interleave::Interleaver;
|
||||
@@ -24,9 +26,7 @@ pub use interleave::Interleaver;
|
||||
pub use wzp_proto::{FecDecoder, FecEncoder, QualityProfile};
|
||||
|
||||
/// Create an encoder/decoder pair configured for the given quality profile.
|
||||
pub fn create_fec_pair(
|
||||
profile: &QualityProfile,
|
||||
) -> (RaptorQFecEncoder, RaptorQFecDecoder) {
|
||||
pub fn create_fec_pair(profile: &QualityProfile) -> (RaptorQFecEncoder, RaptorQFecDecoder) {
|
||||
let cfg = AdaptiveFec::from_profile(profile);
|
||||
let encoder = cfg.build_encoder();
|
||||
let decoder = RaptorQFecDecoder::new(cfg.frames_per_block, cfg.symbol_size);
|
||||
|
||||
@@ -24,7 +24,10 @@ fn main() {
|
||||
let oboe_dir = fetch_oboe();
|
||||
match oboe_dir {
|
||||
Some(oboe_path) => {
|
||||
println!("cargo:warning=wzp-native: building with Oboe from {:?}", oboe_path);
|
||||
println!(
|
||||
"cargo:warning=wzp-native: building with Oboe from {:?}",
|
||||
oboe_path
|
||||
);
|
||||
let mut build = cc::Build::new();
|
||||
build
|
||||
.cpp(true)
|
||||
@@ -96,7 +99,12 @@ fn fetch_oboe() -> Option<PathBuf> {
|
||||
let out_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap());
|
||||
let oboe_dir = out_dir.join("oboe");
|
||||
|
||||
if oboe_dir.join("include").join("oboe").join("Oboe.h").exists() {
|
||||
if oboe_dir
|
||||
.join("include")
|
||||
.join("oboe")
|
||||
.join("Oboe.h")
|
||||
.exists()
|
||||
{
|
||||
return Some(oboe_dir);
|
||||
}
|
||||
|
||||
@@ -111,7 +119,14 @@ fn fetch_oboe() -> Option<PathBuf> {
|
||||
.status();
|
||||
|
||||
match status {
|
||||
Ok(s) if s.success() && oboe_dir.join("include").join("oboe").join("Oboe.h").exists() => {
|
||||
Ok(s)
|
||||
if s.success()
|
||||
&& oboe_dir
|
||||
.join("include")
|
||||
.join("oboe")
|
||||
.join("Oboe.h")
|
||||
.exists() =>
|
||||
{
|
||||
Some(oboe_dir)
|
||||
}
|
||||
_ => None,
|
||||
|
||||
@@ -404,12 +404,14 @@ int wzp_oboe_start(const WzpOboeConfig* config, const WzpOboeRings* rings) {
|
||||
{
|
||||
auto deadline = std::chrono::steady_clock::now() + std::chrono::milliseconds(2000);
|
||||
int poll_count = 0;
|
||||
bool streams_started = false;
|
||||
while (std::chrono::steady_clock::now() < deadline) {
|
||||
auto cap_state = g_capture_stream->getState();
|
||||
auto play_state = g_playout_stream->getState();
|
||||
if (cap_state == oboe::StreamState::Started &&
|
||||
play_state == oboe::StreamState::Started) {
|
||||
LOGI("both streams Started after %d polls", poll_count);
|
||||
streams_started = true;
|
||||
break;
|
||||
}
|
||||
poll_count++;
|
||||
@@ -420,6 +422,18 @@ int wzp_oboe_start(const WzpOboeConfig* config, const WzpOboeRings* rings) {
|
||||
(int)g_capture_stream->getState(),
|
||||
(int)g_playout_stream->getState(),
|
||||
poll_count);
|
||||
if (!streams_started) {
|
||||
LOGE("Timed out waiting for Oboe streams to reach Started state");
|
||||
g_running.store(false, std::memory_order_release);
|
||||
g_rings_valid.store(false, std::memory_order_release);
|
||||
g_capture_stream->requestStop();
|
||||
g_playout_stream->requestStop();
|
||||
g_capture_stream->close();
|
||||
g_playout_stream->close();
|
||||
g_capture_stream.reset();
|
||||
g_playout_stream.reset();
|
||||
return -6;
|
||||
}
|
||||
}
|
||||
|
||||
LOGI("Oboe started: sr=%d burst=%d ch=%d",
|
||||
|
||||
@@ -26,6 +26,11 @@ pub extern "C" fn wzp_native_version() -> i32 {
|
||||
|
||||
/// Writes a NUL-terminated string into `out` (capped at `cap`) and
|
||||
/// returns bytes written excluding the NUL.
|
||||
///
|
||||
/// # Safety
|
||||
/// `out` must be a valid pointer to at least `cap` contiguous bytes of
|
||||
/// writable memory. Passing a null pointer or zero capacity is safe
|
||||
/// (returns 0), but a dangling non-null pointer is undefined behaviour.
|
||||
#[unsafe(no_mangle)]
|
||||
pub unsafe extern "C" fn wzp_native_hello(out: *mut u8, cap: usize) -> usize {
|
||||
const MSG: &[u8] = b"hello from wzp-native\0";
|
||||
@@ -111,7 +116,11 @@ impl RingBuffer {
|
||||
let w = self.write_idx.load(Ordering::Acquire);
|
||||
let r = self.read_idx.load(Ordering::Relaxed);
|
||||
let avail = w - r;
|
||||
if avail < 0 { (avail + self.capacity as i32) as usize } else { avail as usize }
|
||||
if avail < 0 {
|
||||
(avail + self.capacity as i32) as usize
|
||||
} else {
|
||||
avail as usize
|
||||
}
|
||||
}
|
||||
|
||||
fn available_write(&self) -> usize {
|
||||
@@ -127,9 +136,13 @@ impl RingBuffer {
|
||||
let cap = self.capacity;
|
||||
let buf_ptr = self.buf.as_ptr() as *mut i16;
|
||||
for sample in &data[..count] {
|
||||
unsafe { *buf_ptr.add(w) = *sample; }
|
||||
unsafe {
|
||||
*buf_ptr.add(w) = *sample;
|
||||
}
|
||||
w += 1;
|
||||
if w >= cap { w = 0; }
|
||||
if w >= cap {
|
||||
w = 0;
|
||||
}
|
||||
}
|
||||
self.write_idx.store(w as i32, Ordering::Release);
|
||||
count
|
||||
@@ -144,9 +157,13 @@ impl RingBuffer {
|
||||
let cap = self.capacity;
|
||||
let buf_ptr = self.buf.as_ptr();
|
||||
for slot in &mut out[..count] {
|
||||
unsafe { *slot = *buf_ptr.add(r); }
|
||||
unsafe {
|
||||
*slot = *buf_ptr.add(r);
|
||||
}
|
||||
r += 1;
|
||||
if r >= cap { r = 0; }
|
||||
if r >= cap {
|
||||
r = 0;
|
||||
}
|
||||
}
|
||||
self.read_idx.store(r as i32, Ordering::Release);
|
||||
count
|
||||
@@ -264,9 +281,20 @@ pub extern "C" fn wzp_native_audio_stop() {
|
||||
}
|
||||
}
|
||||
|
||||
/// Number of capture samples available to read without blocking.
|
||||
#[unsafe(no_mangle)]
|
||||
pub extern "C" fn wzp_native_audio_capture_available() -> usize {
|
||||
backend().capture.available_read()
|
||||
}
|
||||
|
||||
/// Read captured PCM samples from the capture ring. Returns the number
|
||||
/// of `i16` samples actually copied into `out` (may be less than
|
||||
/// `out_len` if the ring is empty).
|
||||
///
|
||||
/// # Safety
|
||||
/// `out` must be a valid pointer to `out_len` contiguous `i16` values.
|
||||
/// The caller must ensure no other thread writes to the same buffer
|
||||
/// concurrently. Passing a null pointer or zero length is safe (returns 0).
|
||||
#[unsafe(no_mangle)]
|
||||
pub unsafe extern "C" fn wzp_native_audio_read_capture(out: *mut i16, out_len: usize) -> usize {
|
||||
if out.is_null() || out_len == 0 {
|
||||
@@ -280,6 +308,12 @@ pub unsafe extern "C" fn wzp_native_audio_read_capture(out: *mut i16, out_len: u
|
||||
/// samples actually enqueued (may be less than `in_len` if the ring
|
||||
/// is nearly full — in practice the caller should pace to 20 ms
|
||||
/// frames and spin briefly if the ring is full).
|
||||
///
|
||||
/// # Safety
|
||||
/// `input` must be a valid pointer to `in_len` contiguous `i16` values
|
||||
/// that remain valid for the duration of the call. Passing a null pointer
|
||||
/// or zero length is safe (returns 0). The caller must not free or mutate
|
||||
/// the buffer while this function is executing.
|
||||
#[unsafe(no_mangle)]
|
||||
pub unsafe extern "C" fn wzp_native_audio_write_playout(input: *const i16, in_len: usize) -> usize {
|
||||
if input.is_null() || in_len == 0 {
|
||||
@@ -294,17 +328,27 @@ pub unsafe extern "C" fn wzp_native_audio_write_playout(input: *const i16, in_le
|
||||
// has stopped firing → restart the streams. This is the
|
||||
// self-healing behavior that makes rejoin work: teardown +
|
||||
// rebuild clears whatever HAL state locked up the callback.
|
||||
let current_read_idx = b.playout.read_idx.load(std::sync::atomic::Ordering::Relaxed);
|
||||
let last_read_idx = b.playout_last_read_idx.load(std::sync::atomic::Ordering::Relaxed);
|
||||
let current_read_idx = b
|
||||
.playout
|
||||
.read_idx
|
||||
.load(std::sync::atomic::Ordering::Relaxed);
|
||||
let last_read_idx = b
|
||||
.playout_last_read_idx
|
||||
.load(std::sync::atomic::Ordering::Relaxed);
|
||||
if current_read_idx == last_read_idx {
|
||||
let stall = b.playout_stall_writes.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
let stall = b
|
||||
.playout_stall_writes
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
if stall >= 50 {
|
||||
// Callback hasn't drained anything in ~1 second.
|
||||
// Force a stream restart.
|
||||
unsafe {
|
||||
android_log("playout STALL detected (50 writes, read_idx unchanged) — restarting Oboe streams");
|
||||
android_log(
|
||||
"playout STALL detected (50 writes, read_idx unchanged) — restarting Oboe streams",
|
||||
);
|
||||
}
|
||||
b.playout_stall_writes.store(0, std::sync::atomic::Ordering::Relaxed);
|
||||
b.playout_stall_writes
|
||||
.store(0, std::sync::atomic::Ordering::Relaxed);
|
||||
// Release the started lock, stop, re-start.
|
||||
// This is the same logic as the Rust-side
|
||||
// audio_stop() + audio_start() but done inline
|
||||
@@ -319,10 +363,18 @@ pub unsafe extern "C" fn wzp_native_audio_write_playout(input: *const i16, in_le
|
||||
}
|
||||
}
|
||||
// Clear the rings so the restart doesn't read stale data
|
||||
b.playout.write_idx.store(0, std::sync::atomic::Ordering::Relaxed);
|
||||
b.playout.read_idx.store(0, std::sync::atomic::Ordering::Relaxed);
|
||||
b.capture.write_idx.store(0, std::sync::atomic::Ordering::Relaxed);
|
||||
b.capture.read_idx.store(0, std::sync::atomic::Ordering::Relaxed);
|
||||
b.playout
|
||||
.write_idx
|
||||
.store(0, std::sync::atomic::Ordering::Relaxed);
|
||||
b.playout
|
||||
.read_idx
|
||||
.store(0, std::sync::atomic::Ordering::Relaxed);
|
||||
b.capture
|
||||
.write_idx
|
||||
.store(0, std::sync::atomic::Ordering::Relaxed);
|
||||
b.capture
|
||||
.read_idx
|
||||
.store(0, std::sync::atomic::Ordering::Relaxed);
|
||||
// Re-start (stall detector — always non-BT mode)
|
||||
let config = WzpOboeConfig {
|
||||
sample_rate: 48_000,
|
||||
@@ -345,30 +397,49 @@ pub unsafe extern "C" fn wzp_native_audio_write_playout(input: *const i16, in_le
|
||||
if let Ok(mut started) = b.started.lock() {
|
||||
*started = true;
|
||||
}
|
||||
unsafe { android_log("playout restart OK — Oboe streams rebuilt"); }
|
||||
unsafe {
|
||||
android_log("playout restart OK — Oboe streams rebuilt");
|
||||
}
|
||||
} else {
|
||||
unsafe { android_log(&format!("playout restart FAILED: {ret}")); }
|
||||
unsafe {
|
||||
android_log(&format!("playout restart FAILED: {ret}"));
|
||||
}
|
||||
}
|
||||
b.playout_last_read_idx.store(0, std::sync::atomic::Ordering::Relaxed);
|
||||
b.playout_last_read_idx
|
||||
.store(0, std::sync::atomic::Ordering::Relaxed);
|
||||
return 0; // caller will retry on next frame
|
||||
}
|
||||
} else {
|
||||
// read_idx advanced — callback is alive, reset counter
|
||||
b.playout_stall_writes.store(0, std::sync::atomic::Ordering::Relaxed);
|
||||
b.playout_last_read_idx.store(current_read_idx, std::sync::atomic::Ordering::Relaxed);
|
||||
b.playout_stall_writes
|
||||
.store(0, std::sync::atomic::Ordering::Relaxed);
|
||||
b.playout_last_read_idx
|
||||
.store(current_read_idx, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
|
||||
let before_w = b.playout.write_idx.load(std::sync::atomic::Ordering::Relaxed);
|
||||
let before_r = b.playout.read_idx.load(std::sync::atomic::Ordering::Relaxed);
|
||||
let before_w = b
|
||||
.playout
|
||||
.write_idx
|
||||
.load(std::sync::atomic::Ordering::Relaxed);
|
||||
let before_r = b
|
||||
.playout
|
||||
.read_idx
|
||||
.load(std::sync::atomic::Ordering::Relaxed);
|
||||
let written = b.playout.write(slice);
|
||||
// First few writes: log ring state + sample range so we can compare what
|
||||
// engine.rs hands us to what the C++ playout callback reads.
|
||||
let first_writes = b.playout_write_log_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
let first_writes = b
|
||||
.playout_write_log_count
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
if first_writes < 3 || first_writes % 50 == 0 {
|
||||
let (mut lo, mut hi, mut sumsq) = (i16::MAX, i16::MIN, 0i64);
|
||||
for &s in slice.iter() {
|
||||
if s < lo { lo = s; }
|
||||
if s > hi { hi = s; }
|
||||
if s < lo {
|
||||
lo = s;
|
||||
}
|
||||
if s > hi {
|
||||
hi = s;
|
||||
}
|
||||
sumsq += (s as i64) * (s as i64);
|
||||
}
|
||||
let rms = (sumsq as f64 / slice.len() as f64).sqrt() as i32;
|
||||
@@ -376,7 +447,8 @@ pub unsafe extern "C" fn wzp_native_audio_write_playout(input: *const i16, in_le
|
||||
let avail_r_after = b.playout.available_read();
|
||||
let msg = format!(
|
||||
"playout WRITE #{first_writes}: in_len={} written={} range=[{lo}..{hi}] rms={rms} before_w={before_w} before_r={before_r} avail_read_after={avail_r_after} avail_write_after={avail_w_after}",
|
||||
slice.len(), written
|
||||
slice.len(),
|
||||
written
|
||||
);
|
||||
unsafe {
|
||||
android_log(msg.as_str());
|
||||
@@ -400,7 +472,9 @@ unsafe fn android_log(msg: &str) {
|
||||
let mut buf = Vec::with_capacity(msg.len() + 1);
|
||||
buf.extend_from_slice(msg.as_bytes());
|
||||
buf.push(0);
|
||||
unsafe { __android_log_write(4, tag.as_ptr(), buf.as_ptr()); }
|
||||
unsafe {
|
||||
__android_log_write(4, tag.as_ptr(), buf.as_ptr());
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "android"))]
|
||||
|
||||
@@ -20,3 +20,4 @@ tracing = "0.1"
|
||||
[dev-dependencies]
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
serde_json = "1"
|
||||
bincode = "1"
|
||||
|
||||
@@ -7,10 +7,11 @@
|
||||
//! Control (GCC).
|
||||
|
||||
use std::collections::VecDeque;
|
||||
use std::time::Instant;
|
||||
use std::sync::atomic::{AtomicU64, Ordering::Relaxed};
|
||||
use std::time::{Instant, SystemTime, UNIX_EPOCH};
|
||||
|
||||
use crate::packet::QualityReport;
|
||||
use crate::QualityProfile;
|
||||
use crate::packet::QualityReport;
|
||||
|
||||
/// Network congestion state derived from delay and loss signals.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
@@ -158,6 +159,16 @@ pub struct BandwidthEstimator {
|
||||
loss_detector: LossBasedDetector,
|
||||
/// Last update timestamp.
|
||||
last_update: Option<Instant>,
|
||||
|
||||
// ── Transport-feedback BWE (T2.2) ──
|
||||
/// Congestion-window-derived bandwidth estimate in bits per second.
|
||||
cwnd_bps: AtomicU64,
|
||||
/// Peer REMB (Receiver Estimated Maximum Bitrate) in bits per second.
|
||||
peer_remb_bps: AtomicU64,
|
||||
/// EWMA-smoothed bandwidth estimate in bits per second.
|
||||
smoothed_bps: AtomicU64,
|
||||
/// Last time `smoothed_bps` was updated (UNIX epoch millis).
|
||||
last_smoothed_ms: AtomicU64,
|
||||
}
|
||||
|
||||
/// Multiplicative decrease factor applied on congestion (15% reduction).
|
||||
@@ -179,6 +190,10 @@ impl BandwidthEstimator {
|
||||
delay_detector: DelayBasedDetector::new(),
|
||||
loss_detector: LossBasedDetector::new(),
|
||||
last_update: None,
|
||||
cwnd_bps: AtomicU64::new(0),
|
||||
peer_remb_bps: AtomicU64::new(u64::MAX),
|
||||
smoothed_bps: AtomicU64::new(0),
|
||||
last_smoothed_ms: AtomicU64::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -250,6 +265,64 @@ impl BandwidthEstimator {
|
||||
QualityProfile::CATASTROPHIC
|
||||
}
|
||||
}
|
||||
|
||||
// ── Transport-feedback BWE (T2.2) ──
|
||||
|
||||
/// Update from QUIC path stats.
|
||||
///
|
||||
/// Computes `cwnd_bps = cwnd_bytes * 8 / rtt_s` and feeds it into the
|
||||
/// smoothed estimate.
|
||||
pub fn update_from_path(&self, cwnd_bytes: u64, _bytes_in_flight: u64, rtt_ms: u32) {
|
||||
let rtt_s = rtt_ms.max(1) as f64 / 1000.0;
|
||||
let cwnd_bps = ((cwnd_bytes * 8) as f64 / rtt_s) as u64;
|
||||
self.cwnd_bps.store(cwnd_bps, Relaxed);
|
||||
self.update_smoothed(cwnd_bps);
|
||||
}
|
||||
|
||||
/// Update from a peer's `TransportFeedback` REMB value.
|
||||
pub fn update_from_peer(&self, fb_remb_bps: u32) {
|
||||
let remb = fb_remb_bps as u64;
|
||||
self.peer_remb_bps.store(remb, Relaxed);
|
||||
self.update_smoothed(remb);
|
||||
}
|
||||
|
||||
/// Target sending bitrate in bits per second.
|
||||
///
|
||||
/// Returns 90% of the minimum between the congestion-window estimate
|
||||
/// and the peer REMB estimate.
|
||||
pub fn target_send_bps(&self) -> u64 {
|
||||
let cwnd = self.cwnd_bps.load(Relaxed);
|
||||
let remb = self.peer_remb_bps.load(Relaxed);
|
||||
let m = cwnd.min(remb);
|
||||
(m as f64 * 0.9) as u64
|
||||
}
|
||||
|
||||
/// EWMA-smoothed bandwidth estimate in bits per second.
|
||||
pub fn smoothed_bps(&self) -> u64 {
|
||||
self.smoothed_bps.load(Relaxed)
|
||||
}
|
||||
|
||||
/// Apply EWMA smoothing with a 2-second half-life.
|
||||
fn update_smoothed(&self, new_bps: u64) {
|
||||
let now_ms = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_millis() as u64;
|
||||
let last_ms = self.last_smoothed_ms.load(Relaxed);
|
||||
let dt_ms = now_ms.saturating_sub(last_ms);
|
||||
|
||||
let current = self.smoothed_bps.load(Relaxed);
|
||||
let updated = if current == 0 || dt_ms == 0 {
|
||||
new_bps
|
||||
} else {
|
||||
let alpha = 1.0 - 0.5_f64.powf(dt_ms as f64 / 2000.0);
|
||||
let s = current as f64 * (1.0 - alpha) + new_bps as f64 * alpha;
|
||||
s as u64
|
||||
};
|
||||
|
||||
self.smoothed_bps.store(updated, Relaxed);
|
||||
self.last_smoothed_ms.store(now_ms, Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -396,10 +469,7 @@ mod tests {
|
||||
|
||||
// Below 8 => CATASTROPHIC
|
||||
let bwe_cat = BandwidthEstimator::new(7.9, 2.0, 100.0);
|
||||
assert_eq!(
|
||||
bwe_cat.recommended_profile(),
|
||||
QualityProfile::CATASTROPHIC
|
||||
);
|
||||
assert_eq!(bwe_cat.recommended_profile(), QualityProfile::CATASTROPHIC);
|
||||
|
||||
// High bandwidth
|
||||
let bwe_high = BandwidthEstimator::new(80.0, 2.0, 100.0);
|
||||
@@ -413,7 +483,7 @@ mod tests {
|
||||
// Build a QualityReport with moderate loss and RTT.
|
||||
let report = QualityReport {
|
||||
loss_pct: (10.0_f32 / 100.0 * 255.0) as u8, // ~10% loss
|
||||
rtt_4ms: 25, // 100ms RTT
|
||||
rtt_4ms: 25, // 100ms RTT
|
||||
jitter_ms: 10,
|
||||
bitrate_cap_kbps: 200,
|
||||
};
|
||||
@@ -451,4 +521,46 @@ mod tests {
|
||||
}
|
||||
assert!(det.is_congested());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn target_send_bps_uses_min_of_cwnd_and_remb() {
|
||||
let bwe = BandwidthEstimator::new(50.0, 2.0, 100.0);
|
||||
// cwnd_bps = 100_000, remb = 200_000 → min = 100_000 → 90%
|
||||
bwe.update_from_path(1250, 0, 100); // 1250*8 / 0.1 = 100_000
|
||||
bwe.update_from_peer(200_000);
|
||||
assert_eq!(bwe.target_send_bps(), 90_000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn target_send_bps_with_zero_cwnd_uses_remb() {
|
||||
let bwe = BandwidthEstimator::new(50.0, 2.0, 100.0);
|
||||
// Default cwnd is 0, remb is u64::MAX (default).
|
||||
// 0.min(u64::MAX) = 0 → 90% = 0
|
||||
assert_eq!(bwe.target_send_bps(), 0);
|
||||
|
||||
bwe.update_from_peer(100_000);
|
||||
// cwnd still 0
|
||||
assert_eq!(bwe.target_send_bps(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn smoothed_bps_ewma_converges() {
|
||||
let bwe = BandwidthEstimator::new(50.0, 2.0, 100.0);
|
||||
bwe.update_from_path(1250, 0, 100); // 100_000 bps
|
||||
let s1 = bwe.smoothed_bps();
|
||||
assert_eq!(s1, 100_000);
|
||||
|
||||
// Immediately update with same value — dt ≈ 0, so should stay at 100_000
|
||||
bwe.update_from_path(1250, 0, 100);
|
||||
let s2 = bwe.smoothed_bps();
|
||||
assert_eq!(s2, 100_000);
|
||||
|
||||
// Sleep a bit so dt is non-zero, then update with a much higher value.
|
||||
std::thread::sleep(std::time::Duration::from_millis(100));
|
||||
bwe.update_from_path(12500, 0, 100); // 1_000_000 bps
|
||||
let s3 = bwe.smoothed_bps();
|
||||
assert!(s3 > 100_000, "smoothed should increase toward 1M: {s3}");
|
||||
// With 100ms dt, alpha ≈ 0.03, so smoothed should be ~100k * 0.97 + 1M * 0.03 ≈ 127k
|
||||
assert!(s3 < 500_000, "smoothed should not jump too far: {s3}");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,8 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Identifies the audio codec and bitrate configuration.
|
||||
///
|
||||
/// Encoded as 4 bits in the media packet header.
|
||||
/// Encoded as 4 bits in the v1 media packet header, and as a full 8-bit
|
||||
/// value in the v2 [`MediaHeaderV2`](crate::MediaHeaderV2).
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
#[repr(u8)]
|
||||
pub enum CodecId {
|
||||
@@ -24,6 +25,16 @@ pub enum CodecId {
|
||||
Opus48k = 7,
|
||||
/// Opus at 64kbps (studio high)
|
||||
Opus64k = 8,
|
||||
/// H.264 baseline profile (video).
|
||||
H264Baseline = 9,
|
||||
// Reserved for video codecs; implementations land in PRD-video-multicodec.
|
||||
// 10 => H264 main
|
||||
// 11 => H265 main
|
||||
// 13 => VP9
|
||||
/// AV1 main profile (video).
|
||||
Av1Main = 12,
|
||||
/// H.265 main profile (video).
|
||||
H265Main = 11,
|
||||
}
|
||||
|
||||
impl CodecId {
|
||||
@@ -39,6 +50,7 @@ impl CodecId {
|
||||
Self::Codec2_3200 => 3_200,
|
||||
Self::Codec2_1200 => 1_200,
|
||||
Self::ComfortNoise => 0,
|
||||
Self::H264Baseline | Self::H265Main | Self::Av1Main => 2_000_000,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,16 +62,22 @@ impl CodecId {
|
||||
Self::Codec2_3200 => 20,
|
||||
Self::Codec2_1200 => 40,
|
||||
Self::ComfortNoise => 20,
|
||||
Self::H264Baseline | Self::H265Main | Self::Av1Main => 33,
|
||||
}
|
||||
}
|
||||
|
||||
/// Sample rate expected by this codec.
|
||||
pub const fn sample_rate_hz(self) -> u32 {
|
||||
match self {
|
||||
Self::Opus24k | Self::Opus16k | Self::Opus6k
|
||||
| Self::Opus32k | Self::Opus48k | Self::Opus64k => 48_000,
|
||||
Self::Opus24k
|
||||
| Self::Opus16k
|
||||
| Self::Opus6k
|
||||
| Self::Opus32k
|
||||
| Self::Opus48k
|
||||
| Self::Opus64k => 48_000,
|
||||
Self::Codec2_3200 | Self::Codec2_1200 => 8_000,
|
||||
Self::ComfortNoise => 48_000,
|
||||
Self::H264Baseline | Self::H265Main | Self::Av1Main => 48_000,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -75,6 +93,9 @@ impl CodecId {
|
||||
6 => Some(Self::Opus32k),
|
||||
7 => Some(Self::Opus48k),
|
||||
8 => Some(Self::Opus64k),
|
||||
9 => Some(Self::H264Baseline),
|
||||
11 => Some(Self::H265Main),
|
||||
12 => Some(Self::Av1Main),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
@@ -84,10 +105,22 @@ impl CodecId {
|
||||
self as u8
|
||||
}
|
||||
|
||||
/// Returns true if this is a video codec variant.
|
||||
pub const fn is_video(self) -> bool {
|
||||
matches!(self, Self::H264Baseline | Self::H265Main | Self::Av1Main)
|
||||
}
|
||||
|
||||
/// Returns true if this is an Opus variant.
|
||||
pub const fn is_opus(self) -> bool {
|
||||
matches!(self, Self::Opus6k | Self::Opus16k | Self::Opus24k
|
||||
| Self::Opus32k | Self::Opus48k | Self::Opus64k)
|
||||
matches!(
|
||||
self,
|
||||
Self::Opus6k
|
||||
| Self::Opus16k
|
||||
| Self::Opus24k
|
||||
| Self::Opus32k
|
||||
| Self::Opus48k
|
||||
| Self::Opus64k
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,6 +135,18 @@ pub struct QualityProfile {
|
||||
pub frame_duration_ms: u8,
|
||||
/// Number of source frames per FEC block.
|
||||
pub frames_per_block: u8,
|
||||
/// Bandwidth-allocation priority between audio and video.
|
||||
#[serde(default)]
|
||||
pub priority_mode: crate::PriorityMode,
|
||||
/// Target video bitrate in kbps (set by quality controller, not handshake).
|
||||
#[serde(default)]
|
||||
pub video_bitrate_kbps: Option<u32>,
|
||||
/// Target video resolution as (width, height).
|
||||
#[serde(default)]
|
||||
pub video_resolution: Option<(u16, u16)>,
|
||||
/// Target video frame rate.
|
||||
#[serde(default)]
|
||||
pub video_fps: Option<u8>,
|
||||
}
|
||||
|
||||
impl QualityProfile {
|
||||
@@ -111,6 +156,10 @@ impl QualityProfile {
|
||||
fec_ratio: 0.2,
|
||||
frame_duration_ms: 20,
|
||||
frames_per_block: 5,
|
||||
priority_mode: crate::PriorityMode::AudioFirst,
|
||||
video_bitrate_kbps: None,
|
||||
video_resolution: None,
|
||||
video_fps: None,
|
||||
};
|
||||
|
||||
/// Degraded conditions: Opus 6kbps, moderate FEC.
|
||||
@@ -119,6 +168,10 @@ impl QualityProfile {
|
||||
fec_ratio: 0.5,
|
||||
frame_duration_ms: 40,
|
||||
frames_per_block: 10,
|
||||
priority_mode: crate::PriorityMode::AudioFirst,
|
||||
video_bitrate_kbps: None,
|
||||
video_resolution: None,
|
||||
video_fps: None,
|
||||
};
|
||||
|
||||
/// Catastrophic conditions: Codec2 1.2kbps, heavy FEC.
|
||||
@@ -127,6 +180,10 @@ impl QualityProfile {
|
||||
fec_ratio: 1.0,
|
||||
frame_duration_ms: 40,
|
||||
frames_per_block: 8,
|
||||
priority_mode: crate::PriorityMode::AudioFirst,
|
||||
video_bitrate_kbps: None,
|
||||
video_resolution: None,
|
||||
video_fps: None,
|
||||
};
|
||||
|
||||
/// Studio low: Opus 32kbps, minimal FEC.
|
||||
@@ -135,6 +192,10 @@ impl QualityProfile {
|
||||
fec_ratio: 0.1,
|
||||
frame_duration_ms: 20,
|
||||
frames_per_block: 5,
|
||||
priority_mode: crate::PriorityMode::AudioFirst,
|
||||
video_bitrate_kbps: None,
|
||||
video_resolution: None,
|
||||
video_fps: None,
|
||||
};
|
||||
|
||||
/// Studio: Opus 48kbps, minimal FEC.
|
||||
@@ -143,6 +204,10 @@ impl QualityProfile {
|
||||
fec_ratio: 0.1,
|
||||
frame_duration_ms: 20,
|
||||
frames_per_block: 5,
|
||||
priority_mode: crate::PriorityMode::AudioFirst,
|
||||
video_bitrate_kbps: None,
|
||||
video_resolution: None,
|
||||
video_fps: None,
|
||||
};
|
||||
|
||||
/// Studio high: Opus 64kbps, minimal FEC.
|
||||
@@ -151,6 +216,10 @@ impl QualityProfile {
|
||||
fec_ratio: 0.1,
|
||||
frame_duration_ms: 20,
|
||||
frames_per_block: 5,
|
||||
priority_mode: crate::PriorityMode::AudioFirst,
|
||||
video_bitrate_kbps: None,
|
||||
video_resolution: None,
|
||||
video_fps: None,
|
||||
};
|
||||
|
||||
/// Estimated total bandwidth in kbps including FEC overhead.
|
||||
@@ -159,3 +228,46 @@ impl QualityProfile {
|
||||
base * (1.0 + self.fec_ratio)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{CodecId, QualityProfile};
|
||||
use crate::PriorityMode;
|
||||
|
||||
#[test]
|
||||
fn codec_id_unknown_values_rejected() {
|
||||
for v in [10u8, 13].iter().copied().chain(14u8..=255) {
|
||||
assert!(CodecId::from_wire(v).is_none(), "v={v}");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn h265_main_roundtrips() {
|
||||
assert_eq!(CodecId::H265Main.to_wire(), 11);
|
||||
assert_eq!(CodecId::from_wire(11), Some(CodecId::H265Main));
|
||||
assert!(CodecId::H265Main.is_video());
|
||||
assert_eq!(CodecId::H265Main.bitrate_bps(), 2_000_000);
|
||||
assert_eq!(CodecId::H265Main.frame_duration_ms(), 33);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn av1_main_roundtrips() {
|
||||
assert_eq!(CodecId::Av1Main.to_wire(), 12);
|
||||
assert_eq!(CodecId::from_wire(12), Some(CodecId::Av1Main));
|
||||
assert!(CodecId::Av1Main.is_video());
|
||||
assert_eq!(CodecId::Av1Main.bitrate_bps(), 2_000_000);
|
||||
assert_eq!(CodecId::Av1Main.frame_duration_ms(), 33);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quality_profile_backward_compat_old_json() {
|
||||
// Old JSON emitted before T5.1 has no priority_mode or video fields.
|
||||
let old_json =
|
||||
r#"{"codec":"Opus24k","fec_ratio":0.2,"frame_duration_ms":20,"frames_per_block":5}"#;
|
||||
let parsed: QualityProfile = serde_json::from_str(old_json).unwrap();
|
||||
assert_eq!(parsed.priority_mode, PriorityMode::AudioFirst);
|
||||
assert_eq!(parsed.video_bitrate_kbps, None);
|
||||
assert_eq!(parsed.video_resolution, None);
|
||||
assert_eq!(parsed.video_fps, None);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,6 +10,10 @@
|
||||
//! prediction): when jitter variance spikes >30% over a 200 ms window — typical
|
||||
//! of Starlink satellite handovers — it temporarily boosts DRED to the maximum
|
||||
//! allowed for the current codec before packets actually start dropping.
|
||||
//!
|
||||
//! See also: [`crate::quality`] for discrete tier classification that drives
|
||||
//! codec switching. DredTuner operates within a tier, adjusting DRED
|
||||
//! parameters continuously based on live network metrics.
|
||||
|
||||
use crate::CodecId;
|
||||
|
||||
@@ -45,7 +49,7 @@ fn baseline_dred_frames(codec: CodecId) -> u8 {
|
||||
match codec {
|
||||
CodecId::Opus32k | CodecId::Opus48k | CodecId::Opus64k => 10, // 100 ms
|
||||
CodecId::Opus16k | CodecId::Opus24k => 20, // 200 ms
|
||||
CodecId::Opus6k => 50, // 500 ms
|
||||
CodecId::Opus6k => 50, // 500 ms
|
||||
_ => 0,
|
||||
}
|
||||
}
|
||||
@@ -124,7 +128,11 @@ impl DredTuner {
|
||||
self.initialized = true;
|
||||
} else {
|
||||
// Fast-up (alpha=0.3), slow-down (alpha=0.05) asymmetric EWMA
|
||||
let alpha = if jitter_f > self.jitter_ewma { 0.3 } else { 0.05 };
|
||||
let alpha = if jitter_f > self.jitter_ewma {
|
||||
0.3
|
||||
} else {
|
||||
0.05
|
||||
};
|
||||
self.jitter_ewma = alpha * jitter_f + (1.0 - alpha) * self.jitter_ewma;
|
||||
}
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ pub enum CryptoError {
|
||||
#[error("rekey failed: {0}")]
|
||||
RekeyFailed(String),
|
||||
#[error("anti-replay: duplicate or old packet (seq={seq})")]
|
||||
ReplayDetected { seq: u16 },
|
||||
ReplayDetected { seq: u32 },
|
||||
#[error("internal crypto error: {0}")]
|
||||
Internal(String),
|
||||
}
|
||||
|
||||
@@ -81,9 +81,7 @@ impl AdaptivePlayoutDelay {
|
||||
let jitter = (actual_delta - expected_delta).abs();
|
||||
|
||||
// Spike detection: check before EMA update
|
||||
if self.jitter_ema > 0.0
|
||||
&& jitter > self.jitter_ema * self.spike_threshold_multiplier
|
||||
{
|
||||
if self.jitter_ema > 0.0 && jitter > self.jitter_ema * self.spike_threshold_multiplier {
|
||||
self.spike_detected_at = Some(Instant::now());
|
||||
}
|
||||
|
||||
@@ -107,10 +105,8 @@ impl AdaptivePlayoutDelay {
|
||||
self.target_delay = self.max_delay;
|
||||
} else {
|
||||
// Convert jitter estimate to target delay in packets
|
||||
let raw_target =
|
||||
(self.jitter_ema / FRAME_DURATION_MS).ceil() + self.safety_margin;
|
||||
self.target_delay =
|
||||
(raw_target as usize).clamp(self.min_delay, self.max_delay);
|
||||
let raw_target = (self.jitter_ema / FRAME_DURATION_MS).ceil() + self.safety_margin;
|
||||
self.target_delay = (raw_target as usize).clamp(self.min_delay, self.max_delay);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -162,9 +158,9 @@ impl AdaptivePlayoutDelay {
|
||||
/// Manages packet reordering, gap detection, and signals when PLC is needed.
|
||||
pub struct JitterBuffer {
|
||||
/// Packets waiting to be consumed, ordered by sequence number.
|
||||
buffer: BTreeMap<u16, MediaPacket>,
|
||||
buffer: BTreeMap<u32, MediaPacket>,
|
||||
/// Next sequence number expected for playout.
|
||||
next_playout_seq: u16,
|
||||
next_playout_seq: u32,
|
||||
/// Maximum buffer depth in number of packets.
|
||||
max_depth: usize,
|
||||
/// Target buffer depth (adaptive, based on jitter).
|
||||
@@ -204,7 +200,7 @@ pub enum PlayoutResult {
|
||||
/// A packet is available for playout.
|
||||
Packet(MediaPacket),
|
||||
/// The expected packet is missing — decoder should generate PLC.
|
||||
Missing { seq: u16 },
|
||||
Missing { seq: u32 },
|
||||
/// Buffer is empty or not yet filled to target depth.
|
||||
NotReady,
|
||||
}
|
||||
@@ -278,9 +274,18 @@ impl JitterBuffer {
|
||||
// federation room — reset instead of dropping.
|
||||
if self.stats.packets_played > 0 && seq_before(seq, self.next_playout_seq) {
|
||||
let backward_distance = self.next_playout_seq.wrapping_sub(seq);
|
||||
tracing::warn!(seq, next = self.next_playout_seq, backward_distance, "jitter: backward seq detected");
|
||||
tracing::warn!(
|
||||
seq,
|
||||
next = self.next_playout_seq,
|
||||
backward_distance,
|
||||
"jitter: backward seq detected"
|
||||
);
|
||||
if backward_distance > 100 {
|
||||
tracing::info!(seq, next = self.next_playout_seq, "jitter: RESET — new sender detected");
|
||||
tracing::info!(
|
||||
seq,
|
||||
next = self.next_playout_seq,
|
||||
"jitter: RESET — new sender detected"
|
||||
);
|
||||
self.buffer.clear();
|
||||
self.next_playout_seq = seq;
|
||||
self.stats.packets_late = 0;
|
||||
@@ -428,9 +433,18 @@ impl JitterBuffer {
|
||||
// federation room — reset instead of dropping.
|
||||
if self.stats.packets_played > 0 && seq_before(seq, self.next_playout_seq) {
|
||||
let backward_distance = self.next_playout_seq.wrapping_sub(seq);
|
||||
tracing::warn!(seq, next = self.next_playout_seq, backward_distance, "jitter: backward seq detected");
|
||||
tracing::warn!(
|
||||
seq,
|
||||
next = self.next_playout_seq,
|
||||
backward_distance,
|
||||
"jitter: backward seq detected"
|
||||
);
|
||||
if backward_distance > 100 {
|
||||
tracing::info!(seq, next = self.next_playout_seq, "jitter: RESET — new sender detected");
|
||||
tracing::info!(
|
||||
seq,
|
||||
next = self.next_playout_seq,
|
||||
"jitter: RESET — new sender detected"
|
||||
);
|
||||
self.buffer.clear();
|
||||
self.next_playout_seq = seq;
|
||||
self.stats.packets_late = 0;
|
||||
@@ -489,7 +503,7 @@ impl JitterBuffer {
|
||||
|
||||
/// Sequence number comparison with wrapping (RFC 1982 serial number arithmetic).
|
||||
/// Returns true if `a` comes before `b` in sequence space.
|
||||
fn seq_before(a: u16, b: u16) -> bool {
|
||||
fn seq_before(a: u32, b: u32) -> bool {
|
||||
let diff = b.wrapping_sub(a);
|
||||
diff > 0 && diff < 0x8000
|
||||
}
|
||||
@@ -497,24 +511,23 @@ fn seq_before(a: u16, b: u16) -> bool {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::CodecId;
|
||||
use crate::MediaType;
|
||||
use crate::packet::{MediaHeader, MediaPacket};
|
||||
use bytes::Bytes;
|
||||
use crate::CodecId;
|
||||
|
||||
fn make_packet(seq: u16) -> MediaPacket {
|
||||
fn make_packet(seq: u32) -> MediaPacket {
|
||||
MediaPacket {
|
||||
header: MediaHeader {
|
||||
version: 0,
|
||||
is_repair: false,
|
||||
version: 2,
|
||||
flags: 0,
|
||||
media_type: MediaType::Audio,
|
||||
codec_id: CodecId::Opus24k,
|
||||
has_quality_report: false,
|
||||
fec_ratio_encoded: 0,
|
||||
stream_id: 0,
|
||||
fec_ratio: 0,
|
||||
seq,
|
||||
timestamp: seq as u32 * 20,
|
||||
timestamp: seq * 20,
|
||||
fec_block: 0,
|
||||
fec_symbol: 0,
|
||||
reserved: 0,
|
||||
csrc_count: 0,
|
||||
},
|
||||
payload: Bytes::from(vec![0u8; 60]),
|
||||
quality_report: None,
|
||||
@@ -598,7 +611,7 @@ mod tests {
|
||||
fn seq_before_wrapping() {
|
||||
assert!(seq_before(0, 1));
|
||||
assert!(seq_before(65534, 65535));
|
||||
assert!(seq_before(65535, 0)); // wrap
|
||||
assert!(seq_before(u32::MAX, 0)); // wrap
|
||||
assert!(!seq_before(1, 0));
|
||||
assert!(!seq_before(5, 5)); // equal
|
||||
}
|
||||
@@ -800,7 +813,7 @@ mod tests {
|
||||
let mut jb = JitterBuffer::new_adaptive(3, 50);
|
||||
|
||||
// Push packets with consistent timing
|
||||
for i in 0u16..20 {
|
||||
for i in 0u32..20 {
|
||||
let pkt = make_packet(i);
|
||||
let arrival_ms = i as u64 * 20;
|
||||
jb.push_with_arrival(pkt, arrival_ms);
|
||||
|
||||
@@ -17,21 +17,25 @@ pub mod codec_id;
|
||||
pub mod dred_tuner;
|
||||
pub mod error;
|
||||
pub mod jitter;
|
||||
pub mod media_type;
|
||||
pub mod packet;
|
||||
pub mod priority_mode;
|
||||
pub mod quality;
|
||||
pub mod session;
|
||||
pub mod traits;
|
||||
|
||||
// Re-export key types at crate root for convenience.
|
||||
pub use codec_id::{CodecId, QualityProfile};
|
||||
pub use error::*;
|
||||
pub use packet::{
|
||||
CallAcceptMode, HangupReason, MediaHeader, MediaPacket, MiniFrameContext, MiniHeader,
|
||||
QualityReport, RoomParticipant, SignalMessage, TrunkEntry, TrunkFrame, FRAME_TYPE_FULL,
|
||||
FRAME_TYPE_MINI,
|
||||
};
|
||||
pub use bandwidth::{BandwidthEstimator, CongestionState};
|
||||
pub use codec_id::{CodecId, QualityProfile};
|
||||
pub use dred_tuner::{DredTuner, DredTuning};
|
||||
pub use error::*;
|
||||
pub use media_type::MediaType;
|
||||
pub use packet::{
|
||||
CallAcceptMode, FRAME_TYPE_FULL, FRAME_TYPE_MINI, HangupReason, MediaHeader, MediaHeaderV2,
|
||||
MediaPacket, MiniFrameContext, MiniFrameContextV2, MiniHeader, MiniHeaderV2, PresenceUser,
|
||||
QualityReport, RoomParticipant, SignalMessage, TrunkEntry, TrunkFrame, default_signal_version,
|
||||
};
|
||||
pub use priority_mode::PriorityMode;
|
||||
pub use quality::{AdaptiveQualityController, NetworkContext, Tier};
|
||||
pub use session::{Session, SessionEvent, SessionState};
|
||||
pub use traits::*;
|
||||
|
||||
57
crates/wzp-proto/src/media_type.rs
Normal file
57
crates/wzp-proto/src/media_type.rs
Normal file
@@ -0,0 +1,57 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Media stream type carried in a v2 [`MediaHeaderV2`](crate::MediaHeaderV2).
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
#[repr(u8)]
|
||||
pub enum MediaType {
|
||||
/// Encoded speech / music (Opus, Codec2, ComfortNoise).
|
||||
Audio = 0,
|
||||
/// Encoded video access unit (H.264, H.265, AV1; PRD-video-multicodec).
|
||||
Video = 1,
|
||||
/// Opaque payload not interpreted by the relay (reserved).
|
||||
Data = 2,
|
||||
/// In-band control message carried on the media plane (reserved).
|
||||
Control = 3,
|
||||
}
|
||||
|
||||
impl MediaType {
|
||||
/// Encode to the wire byte representation (`self as u8`).
|
||||
pub const fn to_wire(self) -> u8 {
|
||||
self as u8
|
||||
}
|
||||
|
||||
/// Decode from a wire byte. Returns `None` for values outside 0..=3.
|
||||
pub const fn from_wire(v: u8) -> Option<Self> {
|
||||
match v {
|
||||
0 => Some(Self::Audio),
|
||||
1 => Some(Self::Video),
|
||||
2 => Some(Self::Data),
|
||||
3 => Some(Self::Control),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn media_type_roundtrip() {
|
||||
for mt in [
|
||||
MediaType::Audio,
|
||||
MediaType::Video,
|
||||
MediaType::Data,
|
||||
MediaType::Control,
|
||||
] {
|
||||
assert_eq!(MediaType::from_wire(mt.to_wire()), Some(mt));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn media_type_unknown_rejected() {
|
||||
for v in 4u8..=255 {
|
||||
assert!(MediaType::from_wire(v).is_none(), "v={v}");
|
||||
}
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
34
crates/wzp-proto/src/priority_mode.rs
Normal file
34
crates/wzp-proto/src/priority_mode.rs
Normal file
@@ -0,0 +1,34 @@
|
||||
//! Priority mode for bandwidth allocation between audio and video.
|
||||
//!
|
||||
//! See `docs/PRD/PRD-video-quality-priority.md` for the full design.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Bandwidth-allocation policy between audio and video.
|
||||
///
|
||||
/// Carried on [`QualityProfile`](crate::QualityProfile) and mutable at
|
||||
/// runtime via [`SignalMessage::SetPriorityMode`](crate::SignalMessage).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
|
||||
pub enum PriorityMode {
|
||||
/// Audio gets its floor first; video gets the remainder.
|
||||
/// Default for voice/video calls.
|
||||
#[default]
|
||||
AudioFirst,
|
||||
/// Video gets its floor first; audio degrades to Opus 16k floor.
|
||||
VideoFirst,
|
||||
/// Audio clamped to 16 kbps (intelligible speech); video gets remainder.
|
||||
/// Falls back to slide mode when bandwidth drops below SD floor.
|
||||
ScreenShare,
|
||||
/// Proportional split (~15 % audio, ~85 % video).
|
||||
Balanced,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn priority_mode_default_is_audio_first() {
|
||||
assert_eq!(PriorityMode::default(), PriorityMode::AudioFirst);
|
||||
}
|
||||
}
|
||||
@@ -1,24 +1,40 @@
|
||||
//! See also: [`crate::dred_tuner`] for continuous DRED tuning within a tier.
|
||||
|
||||
use std::collections::VecDeque;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use crate::BandwidthEstimator;
|
||||
use crate::QualityProfile;
|
||||
use crate::packet::QualityReport;
|
||||
use crate::traits::QualityController;
|
||||
use crate::QualityProfile;
|
||||
|
||||
/// Network quality tier — drives codec and FEC selection.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
///
|
||||
/// 5-tier range from studio quality down to catastrophic:
|
||||
/// Studio64k > Studio48k > Studio32k > Good > Degraded > Catastrophic
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
|
||||
pub enum Tier {
|
||||
/// loss < 10%, RTT < 400ms
|
||||
Good,
|
||||
/// loss 10-40% OR RTT 400-600ms
|
||||
Degraded,
|
||||
/// loss > 40% OR RTT > 600ms
|
||||
Catastrophic,
|
||||
/// loss >= 15% OR RTT >= 200ms — Codec2 1.2k
|
||||
Catastrophic = 0,
|
||||
/// loss < 15% AND RTT < 200ms — Opus 6k
|
||||
Degraded = 1,
|
||||
/// loss < 5% AND RTT < 100ms — Opus 24k
|
||||
Good = 2,
|
||||
/// loss < 2% AND RTT < 80ms — Opus 32k
|
||||
Studio32k = 3,
|
||||
/// loss < 1% AND RTT < 50ms — Opus 48k
|
||||
Studio48k = 4,
|
||||
/// loss < 1% AND RTT < 30ms — Opus 64k
|
||||
Studio64k = 5,
|
||||
}
|
||||
|
||||
impl Tier {
|
||||
pub fn profile(self) -> QualityProfile {
|
||||
match self {
|
||||
Self::Studio64k => QualityProfile::STUDIO_64K,
|
||||
Self::Studio48k => QualityProfile::STUDIO_48K,
|
||||
Self::Studio32k => QualityProfile::STUDIO_32K,
|
||||
Self::Good => QualityProfile::GOOD,
|
||||
Self::Degraded => QualityProfile::DEGRADED,
|
||||
Self::Catastrophic => QualityProfile::CATASTROPHIC,
|
||||
@@ -39,7 +55,7 @@ impl Tier {
|
||||
NetworkContext::CellularLte
|
||||
| NetworkContext::Cellular5g
|
||||
| NetworkContext::Cellular3g => {
|
||||
// Tighter thresholds for cellular networks
|
||||
// Tighter thresholds for cellular — no studio tiers
|
||||
if loss > 25.0 || rtt > 500 {
|
||||
Self::Catastrophic
|
||||
} else if loss > 8.0 || rtt > 300 {
|
||||
@@ -49,13 +65,18 @@ impl Tier {
|
||||
}
|
||||
}
|
||||
NetworkContext::WiFi | NetworkContext::Unknown => {
|
||||
// Original thresholds
|
||||
if loss > 40.0 || rtt > 600 {
|
||||
if loss >= 15.0 || rtt >= 200 {
|
||||
Self::Catastrophic
|
||||
} else if loss > 10.0 || rtt > 400 {
|
||||
} else if loss >= 5.0 || rtt >= 100 {
|
||||
Self::Degraded
|
||||
} else {
|
||||
} else if loss >= 2.0 || rtt >= 80 {
|
||||
Self::Good
|
||||
} else if loss >= 1.0 || rtt >= 50 {
|
||||
Self::Studio32k
|
||||
} else if rtt >= 30 {
|
||||
Self::Studio48k
|
||||
} else {
|
||||
Self::Studio64k
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -64,29 +85,32 @@ impl Tier {
|
||||
/// Return the next lower (worse) tier, or None if already at the worst.
|
||||
pub fn downgrade(self) -> Option<Tier> {
|
||||
match self {
|
||||
Self::Studio64k => Some(Self::Studio48k),
|
||||
Self::Studio48k => Some(Self::Studio32k),
|
||||
Self::Studio32k => Some(Self::Good),
|
||||
Self::Good => Some(Self::Degraded),
|
||||
Self::Degraded => Some(Self::Catastrophic),
|
||||
Self::Catastrophic => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether this is a studio tier (above Good).
|
||||
pub fn is_studio(self) -> bool {
|
||||
matches!(self, Self::Studio64k | Self::Studio48k | Self::Studio32k)
|
||||
}
|
||||
}
|
||||
|
||||
/// Describes the network transport type for context-aware quality decisions.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
|
||||
pub enum NetworkContext {
|
||||
WiFi,
|
||||
CellularLte,
|
||||
Cellular5g,
|
||||
Cellular3g,
|
||||
#[default]
|
||||
Unknown,
|
||||
}
|
||||
|
||||
impl Default for NetworkContext {
|
||||
fn default() -> Self {
|
||||
Self::Unknown
|
||||
}
|
||||
}
|
||||
|
||||
/// Adaptive quality controller with hysteresis to prevent tier flapping.
|
||||
///
|
||||
/// - Downgrade: 3 consecutive reports in a worse tier (2 on cellular)
|
||||
@@ -108,20 +132,50 @@ pub struct AdaptiveQualityController {
|
||||
fec_boost_until: Option<Instant>,
|
||||
/// FEC boost amount to add during handoff recovery window.
|
||||
fec_boost_amount: f32,
|
||||
/// Probing state: when Some, we're actively testing a higher tier.
|
||||
probe: Option<ProbeState>,
|
||||
/// Time spent stable at the current tier (for probe trigger).
|
||||
stable_since: Option<Instant>,
|
||||
/// Optional bandwidth estimator for BWE-guarded upgrades.
|
||||
bwe: Option<Arc<BandwidthEstimator>>,
|
||||
}
|
||||
|
||||
/// Threshold for downgrading (fast reaction to degradation).
|
||||
const DOWNGRADE_THRESHOLD: u32 = 3;
|
||||
/// Threshold for downgrading on cellular networks (even faster).
|
||||
const CELLULAR_DOWNGRADE_THRESHOLD: u32 = 2;
|
||||
/// Threshold for upgrading (slow, cautious improvement).
|
||||
const UPGRADE_THRESHOLD: u32 = 10;
|
||||
/// Threshold for upgrading from Catastrophic/Degraded to Good.
|
||||
const UPGRADE_THRESHOLD: u32 = 5;
|
||||
/// Threshold for upgrading into studio tiers (very conservative).
|
||||
const STUDIO_UPGRADE_THRESHOLD: u32 = 10;
|
||||
/// Maximum history window size.
|
||||
const HISTORY_SIZE: usize = 20;
|
||||
/// Default FEC boost amount during handoff recovery.
|
||||
const DEFAULT_FEC_BOOST: f32 = 0.2;
|
||||
/// Duration of FEC boost after a network handoff.
|
||||
const FEC_BOOST_DURATION_SECS: u64 = 10;
|
||||
/// Minimum time stable at current tier before probing upward (30 seconds).
|
||||
const PROBE_STABLE_SECS: u64 = 30;
|
||||
/// Duration of a probe window (5 seconds — ~25 quality reports at 1/s).
|
||||
const PROBE_DURATION_SECS: u64 = 5;
|
||||
/// Maximum bad reports during probe before aborting (1 out of ~5 = 20%).
|
||||
const PROBE_MAX_BAD: u32 = 1;
|
||||
/// Cooldown after a failed probe before trying again (60 seconds).
|
||||
const PROBE_COOLDOWN_SECS: u64 = 60;
|
||||
|
||||
/// Active bandwidth probe state.
|
||||
struct ProbeState {
|
||||
/// The tier we're probing (one step above current).
|
||||
target_tier: Tier,
|
||||
/// Profile to apply during probe.
|
||||
target_profile: QualityProfile,
|
||||
/// When the probe started.
|
||||
started: Instant,
|
||||
/// Reports observed during probe.
|
||||
probe_reports: u32,
|
||||
/// Bad reports during probe (loss/RTT exceeded target tier thresholds).
|
||||
bad_reports: u32,
|
||||
}
|
||||
|
||||
impl AdaptiveQualityController {
|
||||
pub fn new() -> Self {
|
||||
@@ -135,6 +189,9 @@ impl AdaptiveQualityController {
|
||||
network_context: NetworkContext::default(),
|
||||
fec_boost_until: None,
|
||||
fec_boost_amount: DEFAULT_FEC_BOOST,
|
||||
probe: None,
|
||||
stable_since: None,
|
||||
bwe: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -174,6 +231,10 @@ impl AdaptiveQualityController {
|
||||
self.forced = false;
|
||||
}
|
||||
|
||||
// Cancel any active probe
|
||||
self.probe = None;
|
||||
self.stable_since = None;
|
||||
|
||||
// Activate FEC boost for any network change
|
||||
self.fec_boost_until = Some(Instant::now() + Duration::from_secs(FEC_BOOST_DURATION_SECS));
|
||||
}
|
||||
@@ -194,6 +255,19 @@ impl AdaptiveQualityController {
|
||||
pub fn reset_counters(&mut self) {
|
||||
self.consecutive_up = 0;
|
||||
self.consecutive_down = 0;
|
||||
self.probe = None;
|
||||
self.stable_since = None;
|
||||
}
|
||||
|
||||
/// Attach a bandwidth estimator for BWE-guarded tier transitions.
|
||||
pub fn set_bandwidth_estimator(&mut self, bwe: Arc<BandwidthEstimator>) {
|
||||
self.bwe = Some(bwe);
|
||||
}
|
||||
|
||||
/// Return the bitrate ceiling (in bps) for a given tier, including FEC overhead.
|
||||
fn tier_ceiling_bps(tier: Tier) -> u64 {
|
||||
let kbps = tier.profile().total_bitrate_kbps();
|
||||
(kbps * 1000.0) as u64
|
||||
}
|
||||
|
||||
/// Get the effective downgrade threshold based on network context.
|
||||
@@ -213,16 +287,13 @@ impl AdaptiveQualityController {
|
||||
return None;
|
||||
}
|
||||
|
||||
let is_worse = match (self.current_tier, observed_tier) {
|
||||
(Tier::Good, Tier::Degraded | Tier::Catastrophic) => true,
|
||||
(Tier::Degraded, Tier::Catastrophic) => true,
|
||||
_ => false,
|
||||
};
|
||||
let is_worse = observed_tier < self.current_tier;
|
||||
|
||||
if is_worse {
|
||||
self.consecutive_up = 0;
|
||||
self.consecutive_down += 1;
|
||||
if self.consecutive_down >= self.downgrade_threshold() {
|
||||
// Jump directly to the observed tier (don't step one-at-a-time on downgrade)
|
||||
self.current_tier = observed_tier;
|
||||
self.current_profile = observed_tier.profile();
|
||||
self.consecutive_down = 0;
|
||||
@@ -232,22 +303,123 @@ impl AdaptiveQualityController {
|
||||
// Better conditions
|
||||
self.consecutive_down = 0;
|
||||
self.consecutive_up += 1;
|
||||
if self.consecutive_up >= UPGRADE_THRESHOLD {
|
||||
// Studio tiers require more consecutive good reports
|
||||
let threshold = if self.current_tier >= Tier::Good {
|
||||
STUDIO_UPGRADE_THRESHOLD
|
||||
} else {
|
||||
UPGRADE_THRESHOLD
|
||||
};
|
||||
if self.consecutive_up >= threshold {
|
||||
// Only upgrade one step at a time
|
||||
let next_tier = match self.current_tier {
|
||||
Tier::Catastrophic => Tier::Degraded,
|
||||
Tier::Degraded => Tier::Good,
|
||||
Tier::Good => return None,
|
||||
};
|
||||
self.current_tier = next_tier;
|
||||
self.current_profile = next_tier.profile();
|
||||
self.consecutive_up = 0;
|
||||
return Some(self.current_profile);
|
||||
if let Some(next_tier) = self.upgrade_one_step() {
|
||||
// BWE guard: require 130% headroom over target tier bitrate
|
||||
if let Some(ref bwe) = self.bwe {
|
||||
let required = (Self::tier_ceiling_bps(next_tier) * 130) / 100;
|
||||
if bwe.target_send_bps() < required {
|
||||
// Insufficient bandwidth — reset counter to prevent flapping
|
||||
self.consecutive_up = 0;
|
||||
return None;
|
||||
}
|
||||
}
|
||||
self.current_tier = next_tier;
|
||||
self.current_profile = next_tier.profile();
|
||||
self.consecutive_up = 0;
|
||||
return Some(self.current_profile);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Check whether to start, continue, or conclude a bandwidth probe.
|
||||
///
|
||||
/// Called from `observe()` when no hysteresis transition fired.
|
||||
fn check_probe(&mut self, observed_tier: Tier) -> Option<QualityProfile> {
|
||||
// Don't probe if forced, or if already at highest tier, or on cellular
|
||||
if self.forced || self.current_tier == Tier::Studio64k {
|
||||
return None;
|
||||
}
|
||||
if matches!(
|
||||
self.network_context,
|
||||
NetworkContext::CellularLte | NetworkContext::Cellular5g | NetworkContext::Cellular3g
|
||||
) {
|
||||
return None;
|
||||
}
|
||||
|
||||
// If we have an active probe, evaluate it
|
||||
if let Some(ref mut probe) = self.probe {
|
||||
probe.probe_reports += 1;
|
||||
|
||||
// Check if the observed tier meets the probe target
|
||||
if observed_tier < probe.target_tier {
|
||||
probe.bad_reports += 1;
|
||||
}
|
||||
|
||||
// Probe failed: too many bad reports
|
||||
if probe.bad_reports > PROBE_MAX_BAD {
|
||||
let _failed_probe = self.probe.take();
|
||||
// Reset stable_since to trigger cooldown
|
||||
self.stable_since = Some(Instant::now() + Duration::from_secs(PROBE_COOLDOWN_SECS));
|
||||
return None; // stay at current tier
|
||||
}
|
||||
|
||||
// Probe succeeded: enough good reports within the window
|
||||
if probe.started.elapsed() >= Duration::from_secs(PROBE_DURATION_SECS) {
|
||||
let target = probe.target_tier;
|
||||
let profile = probe.target_profile;
|
||||
self.probe.take();
|
||||
self.current_tier = target;
|
||||
self.current_profile = profile;
|
||||
self.consecutive_up = 0;
|
||||
self.stable_since = Some(Instant::now());
|
||||
return Some(profile);
|
||||
}
|
||||
|
||||
return None; // probe still running
|
||||
}
|
||||
|
||||
// No active probe — check if we should start one
|
||||
if observed_tier >= self.current_tier {
|
||||
// Track stability
|
||||
if self.stable_since.is_none() {
|
||||
self.stable_since = Some(Instant::now());
|
||||
}
|
||||
|
||||
if let Some(stable_since) = self.stable_since {
|
||||
if stable_since.elapsed() >= Duration::from_secs(PROBE_STABLE_SECS) {
|
||||
// Stable long enough — start probing
|
||||
if let Some(next) = self.upgrade_one_step() {
|
||||
self.probe = Some(ProbeState {
|
||||
target_tier: next,
|
||||
target_profile: next.profile(),
|
||||
started: Instant::now(),
|
||||
probe_reports: 0,
|
||||
bad_reports: 0,
|
||||
});
|
||||
// Return the probe profile so the encoder switches
|
||||
return Some(next.profile());
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Conditions degraded — reset stability timer
|
||||
self.stable_since = None;
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn upgrade_one_step(&self) -> Option<Tier> {
|
||||
match self.current_tier {
|
||||
Tier::Catastrophic => Some(Tier::Degraded),
|
||||
Tier::Degraded => Some(Tier::Good),
|
||||
Tier::Good => Some(Tier::Studio32k),
|
||||
Tier::Studio32k => Some(Tier::Studio48k),
|
||||
Tier::Studio48k => Some(Tier::Studio64k),
|
||||
Tier::Studio64k => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AdaptiveQualityController {
|
||||
@@ -269,7 +441,17 @@ impl QualityController for AdaptiveQualityController {
|
||||
}
|
||||
|
||||
let observed = Tier::classify_with_context(report, self.network_context);
|
||||
self.try_transition(observed)
|
||||
|
||||
// First check for downgrades/upgrades via hysteresis
|
||||
if let Some(profile) = self.try_transition(observed) {
|
||||
// Cancel any active probe on tier change
|
||||
self.probe.take();
|
||||
self.stable_since = None;
|
||||
return Some(profile);
|
||||
}
|
||||
|
||||
// Then check probing
|
||||
self.check_probe(observed)
|
||||
}
|
||||
|
||||
fn force_profile(&mut self, profile: QualityProfile) {
|
||||
@@ -331,25 +513,33 @@ mod tests {
|
||||
}
|
||||
assert_eq!(ctrl.tier(), Tier::Catastrophic);
|
||||
|
||||
// 9 good reports — not enough
|
||||
let good = make_report(2.0, 100);
|
||||
for _ in 0..9 {
|
||||
// 4 good reports — not enough (threshold is 5)
|
||||
let good = make_report(0.5, 20); // studio-quality report
|
||||
for _ in 0..4 {
|
||||
assert!(ctrl.observe(&good).is_none());
|
||||
}
|
||||
assert_eq!(ctrl.tier(), Tier::Catastrophic);
|
||||
|
||||
// 10th good report triggers upgrade (one step: Catastrophic → Degraded)
|
||||
// 5th good report triggers upgrade (one step: Catastrophic → Degraded)
|
||||
let result = ctrl.observe(&good);
|
||||
assert!(result.is_some());
|
||||
assert_eq!(ctrl.tier(), Tier::Degraded);
|
||||
|
||||
// Need another 10 to go from Degraded → Good
|
||||
for _ in 0..9 {
|
||||
// Another 5 to go from Degraded → Good
|
||||
for _ in 0..4 {
|
||||
assert!(ctrl.observe(&good).is_none());
|
||||
}
|
||||
let result = ctrl.observe(&good);
|
||||
assert!(result.is_some());
|
||||
assert_eq!(ctrl.tier(), Tier::Good);
|
||||
|
||||
// Studio upgrades need 10 consecutive — Good → Studio32k
|
||||
for _ in 0..9 {
|
||||
assert!(ctrl.observe(&good).is_none());
|
||||
}
|
||||
let result = ctrl.observe(&good);
|
||||
assert!(result.is_some());
|
||||
assert_eq!(ctrl.tier(), Tier::Studio32k);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -364,13 +554,78 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bwe_guard_blocks_upgrade_when_bandwidth_insufficient() {
|
||||
let mut ctrl = AdaptiveQualityController::new();
|
||||
|
||||
// Force to catastrophic
|
||||
let bad = make_report(50.0, 300);
|
||||
for _ in 0..3 {
|
||||
ctrl.observe(&bad);
|
||||
}
|
||||
assert_eq!(ctrl.tier(), Tier::Catastrophic);
|
||||
|
||||
// Attach a BWE with very low headroom.
|
||||
// Degraded tier needs 6kbps * 1.5 FEC = 9kbps → 130% = 11.7kbps.
|
||||
// Set target_send_bps ≈ 9_000 (below 11_700 threshold).
|
||||
let bwe = Arc::new(BandwidthEstimator::new(1000.0, 1.0, 100_000.0));
|
||||
bwe.update_from_path(1_000_000, 0, 10); // high cwnd
|
||||
bwe.update_from_peer(10_000); // low remb → target = 9_000
|
||||
ctrl.set_bandwidth_estimator(bwe.clone());
|
||||
|
||||
let good = make_report(0.5, 20);
|
||||
for _ in 0..5 {
|
||||
assert!(
|
||||
ctrl.observe(&good).is_none(),
|
||||
"upgrade should be blocked by low BWE"
|
||||
);
|
||||
}
|
||||
assert_eq!(
|
||||
ctrl.tier(),
|
||||
Tier::Catastrophic,
|
||||
"should remain at Catastrophic"
|
||||
);
|
||||
|
||||
// Raise BWE well above the 130% threshold
|
||||
bwe.update_from_peer(100_000); // target ≈ 90_000 bps
|
||||
|
||||
// Counter was reset, need another 5 good reports
|
||||
for _ in 0..4 {
|
||||
assert!(ctrl.observe(&good).is_none());
|
||||
}
|
||||
let result = ctrl.observe(&good);
|
||||
assert!(
|
||||
result.is_some(),
|
||||
"upgrade should proceed with sufficient BWE"
|
||||
);
|
||||
assert_eq!(ctrl.tier(), Tier::Degraded);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tier_classification() {
|
||||
assert_eq!(Tier::classify(&make_report(5.0, 200)), Tier::Good);
|
||||
assert_eq!(Tier::classify(&make_report(15.0, 200)), Tier::Degraded);
|
||||
assert_eq!(Tier::classify(&make_report(5.0, 500)), Tier::Degraded);
|
||||
assert_eq!(Tier::classify(&make_report(50.0, 200)), Tier::Catastrophic);
|
||||
assert_eq!(Tier::classify(&make_report(5.0, 700)), Tier::Catastrophic);
|
||||
// Studio tiers
|
||||
assert_eq!(Tier::classify(&make_report(0.5, 20)), Tier::Studio64k);
|
||||
assert_eq!(Tier::classify(&make_report(0.5, 40)), Tier::Studio48k);
|
||||
assert_eq!(Tier::classify(&make_report(1.5, 60)), Tier::Studio32k);
|
||||
// Good/Degraded/Catastrophic
|
||||
assert_eq!(Tier::classify(&make_report(3.0, 90)), Tier::Good);
|
||||
assert_eq!(Tier::classify(&make_report(6.0, 120)), Tier::Degraded);
|
||||
assert_eq!(Tier::classify(&make_report(16.0, 120)), Tier::Catastrophic);
|
||||
assert_eq!(Tier::classify(&make_report(5.0, 200)), Tier::Catastrophic);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn studio_tier_boundaries() {
|
||||
// loss < 1% AND RTT < 30ms → Studio64k
|
||||
assert_eq!(Tier::classify(&make_report(0.9, 28)), Tier::Studio64k);
|
||||
// loss < 1% AND RTT 30-49ms → Studio48k
|
||||
assert_eq!(Tier::classify(&make_report(0.9, 32)), Tier::Studio48k);
|
||||
// loss < 2% AND RTT < 80ms → Studio32k (but loss >= 1%)
|
||||
assert_eq!(Tier::classify(&make_report(1.5, 40)), Tier::Studio32k);
|
||||
// loss >= 2% → Good (use 2.5 to survive u8 quantization)
|
||||
assert_eq!(Tier::classify(&make_report(2.5, 40)), Tier::Good);
|
||||
// RTT 80ms → Good
|
||||
assert_eq!(Tier::classify(&make_report(0.5, 80)), Tier::Good);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
@@ -379,8 +634,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn cellular_tighter_thresholds() {
|
||||
// 12% loss: Good on WiFi, Degraded on cellular
|
||||
let report = make_report(12.0, 200);
|
||||
// 9% loss: Degraded on both WiFi (>=5%) and cellular (>=8%)
|
||||
let report = make_report(9.0, 80);
|
||||
assert_eq!(
|
||||
Tier::classify_with_context(&report, NetworkContext::WiFi),
|
||||
Tier::Degraded
|
||||
@@ -390,22 +645,22 @@ mod tests {
|
||||
Tier::Degraded
|
||||
);
|
||||
|
||||
// 9% loss: Good on WiFi, Degraded on cellular
|
||||
let report = make_report(9.0, 200);
|
||||
// 6% loss, low RTT: Degraded on WiFi (>=5%), Good on cellular (<8%)
|
||||
let report = make_report(6.0, 80);
|
||||
assert_eq!(
|
||||
Tier::classify_with_context(&report, NetworkContext::WiFi),
|
||||
Tier::Degraded
|
||||
);
|
||||
assert_eq!(
|
||||
Tier::classify_with_context(&report, NetworkContext::CellularLte),
|
||||
Tier::Good
|
||||
);
|
||||
assert_eq!(
|
||||
Tier::classify_with_context(&report, NetworkContext::CellularLte),
|
||||
Tier::Degraded
|
||||
);
|
||||
|
||||
// 30% loss: Degraded on WiFi, Catastrophic on cellular
|
||||
let report = make_report(30.0, 200);
|
||||
// 30% loss: Catastrophic on WiFi (>=15%), Catastrophic on cellular (>=25%)
|
||||
let report = make_report(30.0, 80);
|
||||
assert_eq!(
|
||||
Tier::classify_with_context(&report, NetworkContext::WiFi),
|
||||
Tier::Degraded
|
||||
Tier::Catastrophic
|
||||
);
|
||||
assert_eq!(
|
||||
Tier::classify_with_context(&report, NetworkContext::Cellular3g),
|
||||
@@ -415,15 +670,29 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn cellular_rtt_thresholds() {
|
||||
// RTT 350ms: Good on WiFi, Degraded on cellular
|
||||
let report = make_report(2.0, 348); // rtt_4ms rounds so use 348
|
||||
// RTT 150ms: Degraded on WiFi (>=100ms), Good on cellular (<300ms and loss<8%)
|
||||
let report = make_report(2.0, 148);
|
||||
assert_eq!(
|
||||
Tier::classify_with_context(&report, NetworkContext::WiFi),
|
||||
Tier::Good
|
||||
Tier::Degraded
|
||||
);
|
||||
assert_eq!(
|
||||
Tier::classify_with_context(&report, NetworkContext::CellularLte),
|
||||
Tier::Degraded
|
||||
Tier::Good
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cellular_no_studio_tiers() {
|
||||
// Even with perfect network, cellular stays at Good (no studio)
|
||||
let report = make_report(0.0, 10);
|
||||
assert_eq!(
|
||||
Tier::classify_with_context(&report, NetworkContext::CellularLte),
|
||||
Tier::Good
|
||||
);
|
||||
assert_eq!(
|
||||
Tier::classify_with_context(&report, NetworkContext::WiFi),
|
||||
Tier::Studio64k
|
||||
);
|
||||
}
|
||||
|
||||
@@ -469,6 +738,9 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn tier_downgrade() {
|
||||
assert_eq!(Tier::Studio64k.downgrade(), Some(Tier::Studio48k));
|
||||
assert_eq!(Tier::Studio48k.downgrade(), Some(Tier::Studio32k));
|
||||
assert_eq!(Tier::Studio32k.downgrade(), Some(Tier::Good));
|
||||
assert_eq!(Tier::Good.downgrade(), Some(Tier::Degraded));
|
||||
assert_eq!(Tier::Degraded.downgrade(), Some(Tier::Catastrophic));
|
||||
assert_eq!(Tier::Catastrophic.downgrade(), None);
|
||||
@@ -478,4 +750,103 @@ mod tests {
|
||||
fn network_context_default() {
|
||||
assert_eq!(NetworkContext::default(), NetworkContext::Unknown);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// Bandwidth probing tests
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn probe_triggers_after_stable_period() {
|
||||
let mut ctrl = AdaptiveQualityController::new();
|
||||
let excellent = make_report(0.3, 20); // would classify as Studio64k
|
||||
|
||||
// Starts at Good. Fast-forward stability by setting stable_since directly.
|
||||
ctrl.stable_since = Some(Instant::now() - Duration::from_secs(31));
|
||||
|
||||
// One excellent report should trigger a probe (Good → Studio32k)
|
||||
let result = ctrl.observe(&excellent);
|
||||
assert!(result.is_some(), "should start probe after 30s stable");
|
||||
assert!(ctrl.probe.is_some(), "probe should be active");
|
||||
assert_eq!(ctrl.probe.as_ref().unwrap().target_tier, Tier::Studio32k);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn probe_succeeds_after_window() {
|
||||
let mut ctrl = AdaptiveQualityController::new();
|
||||
ctrl.stable_since = Some(Instant::now() - Duration::from_secs(31));
|
||||
|
||||
let excellent = make_report(0.3, 20);
|
||||
|
||||
// Trigger probe start
|
||||
let result = ctrl.observe(&excellent);
|
||||
assert!(result.is_some());
|
||||
|
||||
// Simulate probe window elapsed by backdating started
|
||||
ctrl.probe.as_mut().unwrap().started =
|
||||
Instant::now() - Duration::from_secs(PROBE_DURATION_SECS);
|
||||
|
||||
// Next good report should finalize the probe
|
||||
let result = ctrl.observe(&excellent);
|
||||
assert!(result.is_some(), "probe should succeed");
|
||||
assert_eq!(ctrl.current_tier, Tier::Studio32k);
|
||||
assert!(ctrl.probe.is_none(), "probe should be cleared");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn probe_fails_on_bad_reports() {
|
||||
let mut ctrl = AdaptiveQualityController::new();
|
||||
// Put controller at Studio32k, pretend we've been stable
|
||||
ctrl.current_tier = Tier::Studio32k;
|
||||
ctrl.current_profile = Tier::Studio32k.profile();
|
||||
ctrl.stable_since = Some(Instant::now() - Duration::from_secs(31));
|
||||
|
||||
// Start a probe to Studio48k
|
||||
let excellent = make_report(0.3, 20);
|
||||
let result = ctrl.observe(&excellent);
|
||||
assert!(result.is_some()); // probe started
|
||||
assert_eq!(ctrl.probe.as_ref().unwrap().target_tier, Tier::Studio48k);
|
||||
|
||||
// Feed bad reports (loss too high for Studio48k)
|
||||
let degraded = make_report(3.0, 100);
|
||||
ctrl.observe(°raded); // first bad
|
||||
ctrl.observe(°raded); // second bad — exceeds PROBE_MAX_BAD (1)
|
||||
|
||||
// Probe should be cancelled
|
||||
assert!(
|
||||
ctrl.probe.is_none(),
|
||||
"probe should be cancelled after bad reports"
|
||||
);
|
||||
// Should still be at Studio32k (not upgraded)
|
||||
assert_eq!(ctrl.current_tier, Tier::Studio32k);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_probe_on_cellular() {
|
||||
let mut ctrl = AdaptiveQualityController::new();
|
||||
ctrl.signal_network_change(NetworkContext::CellularLte);
|
||||
ctrl.current_tier = Tier::Good;
|
||||
ctrl.current_profile = Tier::Good.profile();
|
||||
ctrl.stable_since = Some(Instant::now() - Duration::from_secs(60));
|
||||
|
||||
let good = make_report(0.5, 40);
|
||||
let result = ctrl.observe(&good);
|
||||
// Should NOT probe on cellular
|
||||
assert!(ctrl.probe.is_none(), "should not probe on cellular");
|
||||
assert!(result.is_none() || ctrl.current_tier == Tier::Good);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_probe_at_highest_tier() {
|
||||
let mut ctrl = AdaptiveQualityController::new();
|
||||
ctrl.current_tier = Tier::Studio64k;
|
||||
ctrl.current_profile = Tier::Studio64k.profile();
|
||||
ctrl.stable_since = Some(Instant::now() - Duration::from_secs(60));
|
||||
|
||||
let excellent = make_report(0.1, 10);
|
||||
let result = ctrl.observe(&excellent);
|
||||
assert!(
|
||||
result.is_none(),
|
||||
"should not probe when already at Studio64k"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,18 +61,34 @@ pub trait FecEncoder: Send + Sync {
|
||||
/// Add a source symbol (one audio frame) to the current block.
|
||||
fn add_source_symbol(&mut self, data: &[u8]) -> Result<(), FecError>;
|
||||
|
||||
/// Add a source symbol and mark whether it belongs to a keyframe.
|
||||
///
|
||||
/// When the block contains at least one keyframe source symbol,
|
||||
/// [`generate_repair`] uses the configured keyframe ratio instead of the
|
||||
/// nominal ratio.
|
||||
///
|
||||
/// Default implementation delegates to [`add_source_symbol`] and ignores
|
||||
/// the keyframe flag.
|
||||
fn add_source_symbol_with_keyframe(
|
||||
&mut self,
|
||||
data: &[u8],
|
||||
_is_keyframe: bool,
|
||||
) -> Result<(), FecError> {
|
||||
self.add_source_symbol(data)
|
||||
}
|
||||
|
||||
/// Generate repair symbols for the current block.
|
||||
///
|
||||
/// `ratio` is the repair overhead (e.g., 0.5 = 50% more symbols than source).
|
||||
/// Returns `(fec_symbol_index, repair_data)` pairs.
|
||||
fn generate_repair(&mut self, ratio: f32) -> Result<Vec<(u8, Vec<u8>)>, FecError>;
|
||||
fn generate_repair(&mut self, ratio: f32) -> Result<Vec<(u16, Vec<u8>)>, FecError>;
|
||||
|
||||
/// Finalize the current block and start a new one.
|
||||
/// Returns the block ID of the finalized block.
|
||||
fn finalize_block(&mut self) -> Result<u8, FecError>;
|
||||
fn finalize_block(&mut self) -> Result<u16, FecError>;
|
||||
|
||||
/// Current block ID being built.
|
||||
fn current_block_id(&self) -> u8;
|
||||
fn current_block_id(&self) -> u16;
|
||||
|
||||
/// Number of source symbols in the current block.
|
||||
fn current_block_size(&self) -> usize;
|
||||
@@ -83,8 +99,8 @@ pub trait FecDecoder: Send + Sync {
|
||||
/// Feed a received symbol (source or repair) into the decoder.
|
||||
fn add_symbol(
|
||||
&mut self,
|
||||
block_id: u8,
|
||||
symbol_index: u8,
|
||||
block_id: u16,
|
||||
symbol_index: u16,
|
||||
is_repair: bool,
|
||||
data: &[u8],
|
||||
) -> Result<(), FecError>;
|
||||
@@ -93,10 +109,10 @@ pub trait FecDecoder: Send + Sync {
|
||||
///
|
||||
/// Returns `None` if not yet decodable (insufficient symbols).
|
||||
/// Returns `Some(Vec<source_frames>)` on success.
|
||||
fn try_decode(&mut self, block_id: u8) -> Result<Option<Vec<Vec<u8>>>, FecError>;
|
||||
fn try_decode(&mut self, block_id: u16) -> Result<Option<Vec<Vec<u8>>>, FecError>;
|
||||
|
||||
/// Drop state for blocks older than `block_id`.
|
||||
fn expire_before(&mut self, block_id: u8);
|
||||
fn expire_before(&mut self, block_id: u16);
|
||||
}
|
||||
|
||||
// ─── Crypto Traits ───────────────────────────────────────────────────────────
|
||||
|
||||
@@ -20,6 +20,7 @@ bytes = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
toml = "0.8"
|
||||
anyhow = "1"
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
reqwest = { version = "0.12", features = ["json"] }
|
||||
serde_json = "1"
|
||||
rustls = { version = "0.23", default-features = false, features = ["ring", "std"] }
|
||||
@@ -28,6 +29,7 @@ prometheus = "0.13"
|
||||
axum = { version = "0.7", default-features = false, features = ["tokio", "http1", "ws"] }
|
||||
tower-http = { version = "0.6", features = ["fs"] }
|
||||
futures-util = "0.3"
|
||||
dashmap = "6"
|
||||
dirs = "6"
|
||||
sha2 = { workspace = true }
|
||||
chrono = "0.4"
|
||||
|
||||
@@ -7,9 +7,7 @@ fn main() {
|
||||
.output();
|
||||
|
||||
let hash = match output {
|
||||
Ok(o) if o.status.success() => {
|
||||
String::from_utf8_lossy(&o.stdout).trim().to_string()
|
||||
}
|
||||
Ok(o) if o.status.success() => String::from_utf8_lossy(&o.stdout).trim().to_string(),
|
||||
_ => "unknown".to_string(),
|
||||
};
|
||||
|
||||
|
||||
467
crates/wzp-relay/src/audio_scorer.rs
Normal file
467
crates/wzp-relay/src/audio_scorer.rs
Normal file
@@ -0,0 +1,467 @@
|
||||
//! Tier F audio scorer — behavioural entropy detection for abuse mitigation.
|
||||
//!
|
||||
//! Computes a `legitimacy ∈ [0, 1]` score over a 10–30 s observation window.
|
||||
//! Features: IAT CoV, payload-size bimodality, silence fraction, bitrate
|
||||
//! deviation, and Q-flag cadence.
|
||||
|
||||
use std::collections::VecDeque;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use wzp_proto::{CodecId, MediaHeader, MediaType};
|
||||
|
||||
use crate::verdict::Verdict;
|
||||
|
||||
/// Maximum samples kept in rolling windows.
|
||||
const MAX_IAT_SAMPLES: usize = 200;
|
||||
const MAX_SIZE_SAMPLES: usize = 200;
|
||||
const MAX_Q_INTERVALS: usize = 32;
|
||||
|
||||
/// Silence threshold: payload below this many bytes is treated as silence / CN.
|
||||
const SILENCE_SIZE_THRESHOLD: usize = 16;
|
||||
|
||||
/// Observation window for bitrate tracking.
|
||||
const BITRATE_WINDOW_SECS: u64 = 30;
|
||||
|
||||
// Number of payload-size histogram bins.
|
||||
// (SIZE_BINS reserved for future histogram-based bimodality)
|
||||
|
||||
/// Audio-specific behavioural scorer (Tier F).
|
||||
pub struct AudioScorer {
|
||||
/// Rolling inter-arrival times.
|
||||
iat_samples: VecDeque<Duration>,
|
||||
last_arrival: Option<Instant>,
|
||||
|
||||
/// Rolling payload sizes.
|
||||
size_samples: VecDeque<usize>,
|
||||
|
||||
/// Count of packets below silence threshold.
|
||||
silence_packets: u32,
|
||||
/// Total packets observed in current window.
|
||||
total_packets: u32,
|
||||
|
||||
/// Bitrate window.
|
||||
window_start: Instant,
|
||||
window_bytes: u64,
|
||||
|
||||
/// Q-flag arrival intervals.
|
||||
q_intervals: VecDeque<Duration>,
|
||||
last_q_flag: Option<Instant>,
|
||||
|
||||
/// Codec declared at first packet (used for nominal bitrate baseline).
|
||||
declared_codec: Option<CodecId>,
|
||||
}
|
||||
|
||||
impl AudioScorer {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
iat_samples: VecDeque::with_capacity(MAX_IAT_SAMPLES),
|
||||
last_arrival: None,
|
||||
size_samples: VecDeque::with_capacity(MAX_SIZE_SAMPLES),
|
||||
silence_packets: 0,
|
||||
total_packets: 0,
|
||||
window_start: Instant::now(),
|
||||
window_bytes: 0,
|
||||
q_intervals: VecDeque::with_capacity(MAX_Q_INTERVALS),
|
||||
last_q_flag: None,
|
||||
declared_codec: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Feed one packet into the scorer.
|
||||
pub fn observe(&mut self, header: &MediaHeader, payload_len: usize, now: Instant) {
|
||||
// Ignore non-audio traffic.
|
||||
if header.media_type != MediaType::Audio {
|
||||
return;
|
||||
}
|
||||
|
||||
if self.declared_codec.is_none() {
|
||||
self.declared_codec = Some(header.codec_id);
|
||||
}
|
||||
|
||||
// IAT
|
||||
if let Some(last) = self.last_arrival {
|
||||
let iat = now.saturating_duration_since(last);
|
||||
self.iat_samples.push_back(iat);
|
||||
if self.iat_samples.len() > MAX_IAT_SAMPLES {
|
||||
self.iat_samples.pop_front();
|
||||
}
|
||||
}
|
||||
self.last_arrival = Some(now);
|
||||
|
||||
// Payload size
|
||||
self.size_samples.push_back(payload_len);
|
||||
if self.size_samples.len() > MAX_SIZE_SAMPLES {
|
||||
self.size_samples.pop_front();
|
||||
}
|
||||
|
||||
// Silence fraction
|
||||
self.total_packets += 1;
|
||||
if payload_len <= SILENCE_SIZE_THRESHOLD {
|
||||
self.silence_packets += 1;
|
||||
}
|
||||
|
||||
// Bitrate window
|
||||
if now.duration_since(self.window_start) >= Duration::from_secs(BITRATE_WINDOW_SECS) {
|
||||
self.window_start = now;
|
||||
self.window_bytes = 0;
|
||||
}
|
||||
self.window_bytes += (MediaHeader::WIRE_SIZE + payload_len) as u64;
|
||||
|
||||
// Q-flag cadence
|
||||
if header.has_quality() {
|
||||
if let Some(last) = self.last_q_flag {
|
||||
let interval = now.saturating_duration_since(last);
|
||||
self.q_intervals.push_back(interval);
|
||||
if self.q_intervals.len() > MAX_Q_INTERVALS {
|
||||
self.q_intervals.pop_front();
|
||||
}
|
||||
}
|
||||
self.last_q_flag = Some(now);
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute legitimacy score ∈ [0, 1].
|
||||
///
|
||||
/// Higher = more legitimate. Returns `None` when insufficient samples
|
||||
/// have been collected (< 20 packets).
|
||||
pub fn legitimacy(&self) -> Option<f32> {
|
||||
if self.total_packets < 20 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut score = 1.0f32;
|
||||
|
||||
// 1. IAT CoV penalty
|
||||
if let Some(cov) = self.iat_cov() {
|
||||
if cov > 0.4 {
|
||||
let penalty = ((cov - 0.4) / 0.6).min(1.0) * 0.25;
|
||||
score -= penalty as f32;
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Silence fraction penalty
|
||||
let silence_fraction = self.silence_fraction();
|
||||
if silence_fraction < 0.02 {
|
||||
let penalty = ((0.02 - silence_fraction) / 0.02).min(1.0) * 0.25;
|
||||
score -= penalty as f32;
|
||||
} else if silence_fraction > 0.60 {
|
||||
// Too much silence can also be suspicious (stuffed payloads)
|
||||
let penalty = ((silence_fraction - 0.60) / 0.40).min(1.0) * 0.15;
|
||||
score -= penalty as f32;
|
||||
}
|
||||
|
||||
// 3. Bitrate deviation penalty
|
||||
if let Some(ratio) = self.bitrate_ratio() {
|
||||
if ratio > 1.20 {
|
||||
let penalty = ((ratio - 1.20) / 0.80).min(1.0) * 0.25;
|
||||
score -= penalty as f32;
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Q-flag cadence penalty
|
||||
if let Some(cv) = self.q_flag_cv() {
|
||||
// High variability in Q-flag spacing = suspicious
|
||||
if cv > 0.5 {
|
||||
let penalty = ((cv - 0.5) / 0.5).min(1.0) * 0.15;
|
||||
score -= penalty as f32;
|
||||
}
|
||||
} else {
|
||||
// No Q flags seen at all — mildly suspicious after many packets
|
||||
if self.total_packets > 100 {
|
||||
score -= 0.10;
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Payload-size bimodality bonus/penalty
|
||||
if let Some(bimodality) = self.size_bimodality() {
|
||||
// Bimodality score: 0 = unimodal, 1 = strongly bimodal
|
||||
// Legitimate audio is bimodal (speech + silence)
|
||||
if bimodality < 0.2 {
|
||||
score -= 0.10;
|
||||
}
|
||||
}
|
||||
|
||||
Some(score.clamp(0.0, 1.0))
|
||||
}
|
||||
|
||||
/// Map legitimacy score to a [`Verdict`].
|
||||
pub fn verdict(&self) -> Option<Verdict> {
|
||||
self.legitimacy().map(|s| {
|
||||
if s >= 0.7 {
|
||||
Verdict::Legitimate
|
||||
} else if s >= 0.3 {
|
||||
Verdict::Suspect
|
||||
} else {
|
||||
Verdict::Abusive
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------
|
||||
// Feature extractors
|
||||
// ------------------------------------------------------------------
|
||||
|
||||
/// Coefficient of variation of inter-arrival times.
|
||||
fn iat_cov(&self) -> Option<f64> {
|
||||
if self.iat_samples.len() < 10 {
|
||||
return None;
|
||||
}
|
||||
let mean = self
|
||||
.iat_samples
|
||||
.iter()
|
||||
.map(|d| d.as_secs_f64())
|
||||
.sum::<f64>()
|
||||
/ self.iat_samples.len() as f64;
|
||||
if mean == 0.0 {
|
||||
return None;
|
||||
}
|
||||
let variance = self
|
||||
.iat_samples
|
||||
.iter()
|
||||
.map(|d| {
|
||||
let diff = d.as_secs_f64() - mean;
|
||||
diff * diff
|
||||
})
|
||||
.sum::<f64>()
|
||||
/ self.iat_samples.len() as f64;
|
||||
let std = variance.sqrt();
|
||||
Some(std / mean)
|
||||
}
|
||||
|
||||
/// Fraction of packets that are silence / comfort-noise sized.
|
||||
fn silence_fraction(&self) -> f64 {
|
||||
if self.total_packets == 0 {
|
||||
return 0.0;
|
||||
}
|
||||
self.silence_packets as f64 / self.total_packets as f64
|
||||
}
|
||||
|
||||
/// Ratio of observed bitrate to nominal bitrate over the 30 s window.
|
||||
fn bitrate_ratio(&self) -> Option<f64> {
|
||||
let codec = self.declared_codec?;
|
||||
let nominal_bps = codec.bitrate_bps() as f64;
|
||||
if nominal_bps == 0.0 {
|
||||
return None;
|
||||
}
|
||||
let observed_bps = self.window_bytes as f64 * 8.0 / BITRATE_WINDOW_SECS as f64;
|
||||
Some(observed_bps / nominal_bps)
|
||||
}
|
||||
|
||||
/// Coefficient of variation of Q-flag intervals.
|
||||
fn q_flag_cv(&self) -> Option<f64> {
|
||||
if self.q_intervals.len() < 3 {
|
||||
return None;
|
||||
}
|
||||
let mean = self
|
||||
.q_intervals
|
||||
.iter()
|
||||
.map(|d| d.as_secs_f64())
|
||||
.sum::<f64>()
|
||||
/ self.q_intervals.len() as f64;
|
||||
if mean == 0.0 {
|
||||
return None;
|
||||
}
|
||||
let variance = self
|
||||
.q_intervals
|
||||
.iter()
|
||||
.map(|d| {
|
||||
let diff = d.as_secs_f64() - mean;
|
||||
diff * diff
|
||||
})
|
||||
.sum::<f64>()
|
||||
/ self.q_intervals.len() as f64;
|
||||
let std = variance.sqrt();
|
||||
Some(std / mean)
|
||||
}
|
||||
|
||||
/// Simple bimodality score based on a 2-bin histogram.
|
||||
///
|
||||
/// Splits payload sizes into "small" (≤ threshold) and "large" bins.
|
||||
/// Returns a score in [0, 1] where 1 = strongly bimodal.
|
||||
fn size_bimodality(&self) -> Option<f64> {
|
||||
if self.size_samples.len() < 20 {
|
||||
return None;
|
||||
}
|
||||
let small = self
|
||||
.size_samples
|
||||
.iter()
|
||||
.filter(|&&s| s <= SILENCE_SIZE_THRESHOLD)
|
||||
.count();
|
||||
let large = self.size_samples.len() - small;
|
||||
let total = self.size_samples.len() as f64;
|
||||
let p_small = small as f64 / total;
|
||||
let _p_large = large as f64 / total;
|
||||
// Max bimodality when both bins are equally populated (~0.5 each)
|
||||
let bimodality = 1.0 - (p_small - 0.5).abs() * 2.0;
|
||||
Some(bimodality)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AudioScorer {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn audio_header(payload_len: usize, has_quality: bool) -> MediaHeader {
|
||||
MediaHeader {
|
||||
version: 2,
|
||||
flags: if has_quality { 0x40 } else { 0 },
|
||||
media_type: MediaType::Audio,
|
||||
codec_id: CodecId::Opus24k,
|
||||
stream_id: 0,
|
||||
fec_ratio: 0,
|
||||
seq: 0,
|
||||
timestamp: 0,
|
||||
fec_block: 0,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn audio_scorer_ignores_video() {
|
||||
let mut scorer = AudioScorer::new();
|
||||
let mut h = audio_header(100, false);
|
||||
h.media_type = MediaType::Video;
|
||||
scorer.observe(&h, 100, Instant::now());
|
||||
assert_eq!(scorer.total_packets, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn audio_scorer_counts_packets() {
|
||||
let mut scorer = AudioScorer::new();
|
||||
for i in 0..25 {
|
||||
let h = audio_header(100, false);
|
||||
scorer.observe(&h, 100, Instant::now() + Duration::from_millis(i * 20));
|
||||
}
|
||||
assert_eq!(scorer.total_packets, 25);
|
||||
assert!(scorer.legitimacy().is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn audio_scorer_legitimate_traffic() {
|
||||
let mut scorer = AudioScorer::new();
|
||||
let base = Instant::now();
|
||||
// Simulate 200 packets of legitimate audio:
|
||||
// ~20 ms IAT, mixed speech (100 B) and silence (8 B), periodic Q flags.
|
||||
for i in 0..200 {
|
||||
let payload = if i % 3 == 0 { 8 } else { 100 };
|
||||
let has_q = i % 10 == 0;
|
||||
let h = audio_header(payload, has_q);
|
||||
scorer.observe(&h, payload, base + Duration::from_millis(i * 20));
|
||||
}
|
||||
let leg = scorer.legitimacy().unwrap();
|
||||
assert!(
|
||||
leg >= 0.7,
|
||||
"legitimate traffic should score ≥ 0.7, got {leg}"
|
||||
);
|
||||
assert_eq!(scorer.verdict(), Some(Verdict::Legitimate));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn audio_scorer_abusive_uniform_iat() {
|
||||
let mut scorer = AudioScorer::new();
|
||||
let base = Instant::now();
|
||||
// Uniform IAT (no jitter), all same size, no Q flags — tunnel-like
|
||||
for i in 0..200 {
|
||||
let h = audio_header(200, false);
|
||||
scorer.observe(&h, 200, base + Duration::from_millis(i * 20));
|
||||
}
|
||||
let leg = scorer.legitimacy().unwrap();
|
||||
assert!(
|
||||
leg < 0.6,
|
||||
"uniform tunnel-like traffic should score < 0.6, got {leg}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn audio_scorer_abusive_no_silence() {
|
||||
let mut scorer = AudioScorer::new();
|
||||
let base = Instant::now();
|
||||
// No silence packets at all, very regular IAT
|
||||
for i in 0..200 {
|
||||
let h = audio_header(150, false);
|
||||
scorer.observe(&h, 150, base + Duration::from_millis(i * 20));
|
||||
}
|
||||
let leg = scorer.legitimacy().unwrap();
|
||||
assert!(
|
||||
leg < 0.6,
|
||||
"no-silence traffic should score < 0.6, got {leg}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn audio_scorer_insufficient_samples() {
|
||||
let scorer = AudioScorer::new();
|
||||
assert_eq!(scorer.legitimacy(), None);
|
||||
assert_eq!(scorer.verdict(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn silence_fraction_computed_correctly() {
|
||||
let mut scorer = AudioScorer::new();
|
||||
let base = Instant::now();
|
||||
for i in 0..100 {
|
||||
let payload = if i < 30 { 8 } else { 100 };
|
||||
let h = audio_header(payload, false);
|
||||
scorer.observe(&h, payload, base + Duration::from_millis(i * 20));
|
||||
}
|
||||
assert!((scorer.silence_fraction() - 0.30).abs() < 0.01);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bitrate_ratio_saturates_when_no_codec() {
|
||||
let scorer = AudioScorer::new();
|
||||
assert_eq!(scorer.bitrate_ratio(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn q_flag_cv_regular_spacing() {
|
||||
let mut scorer = AudioScorer::new();
|
||||
let base = Instant::now();
|
||||
for i in 0..50 {
|
||||
let has_q = i % 5 == 0;
|
||||
let h = audio_header(100, has_q);
|
||||
scorer.observe(&h, 100, base + Duration::from_millis(i * 20));
|
||||
}
|
||||
let cv = scorer.q_flag_cv().unwrap();
|
||||
assert!(
|
||||
cv < 0.1,
|
||||
"regular Q-flag spacing should have CV < 0.1, got {cv}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn size_bimodality_for_mixed_traffic() {
|
||||
let mut scorer = AudioScorer::new();
|
||||
let base = Instant::now();
|
||||
for i in 0..100 {
|
||||
let payload = if i % 2 == 0 { 8 } else { 120 };
|
||||
let h = audio_header(payload, false);
|
||||
scorer.observe(&h, payload, base + Duration::from_millis(i * 20));
|
||||
}
|
||||
let bim = scorer.size_bimodality().unwrap();
|
||||
assert!(
|
||||
bim > 0.8,
|
||||
"perfectly mixed small/large should be highly bimodal, got {bim}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn size_bimodality_for_uniform_traffic() {
|
||||
let mut scorer = AudioScorer::new();
|
||||
let base = Instant::now();
|
||||
for i in 0..100 {
|
||||
let h = audio_header(100, false);
|
||||
scorer.observe(&h, 100, base + Duration::from_millis(i * 20));
|
||||
}
|
||||
let bim = scorer.size_bimodality().unwrap();
|
||||
assert!(
|
||||
bim < 0.3,
|
||||
"uniform size traffic should be unimodal, got {bim}"
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -32,10 +32,7 @@ pub struct AuthenticatedClient {
|
||||
///
|
||||
/// Calls `POST {auth_url}` with `{ "token": "..." }`.
|
||||
/// Returns the client identity if valid, or an error string.
|
||||
pub async fn validate_token(
|
||||
auth_url: &str,
|
||||
token: &str,
|
||||
) -> Result<AuthenticatedClient, String> {
|
||||
pub async fn validate_token(auth_url: &str, token: &str) -> Result<AuthenticatedClient, String> {
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(5))
|
||||
.build()
|
||||
|
||||
@@ -61,6 +61,13 @@ pub struct DirectCall {
|
||||
/// interface addresses from the `DirectCallAnswer`. Cross-
|
||||
/// wired into the caller's `CallSetup.peer_local_addrs`.
|
||||
pub callee_local_addrs: Vec<String>,
|
||||
/// Phase 8 (Tailscale-inspired): caller's port-mapped
|
||||
/// external address from NAT-PMP/PCP/UPnP. Cross-wired
|
||||
/// into callee's `CallSetup.peer_mapped_addr`.
|
||||
pub caller_mapped_addr: Option<String>,
|
||||
/// Phase 8: callee's port-mapped external address.
|
||||
/// Cross-wired into caller's `CallSetup.peer_mapped_addr`.
|
||||
pub callee_mapped_addr: Option<String>,
|
||||
}
|
||||
|
||||
/// Registry of active direct calls.
|
||||
@@ -76,7 +83,12 @@ impl CallRegistry {
|
||||
}
|
||||
|
||||
/// Create a new pending call. Returns the call_id.
|
||||
pub fn create_call(&mut self, call_id: String, caller_fp: String, callee_fp: String) -> &DirectCall {
|
||||
pub fn create_call(
|
||||
&mut self,
|
||||
call_id: String,
|
||||
caller_fp: String,
|
||||
callee_fp: String,
|
||||
) -> &DirectCall {
|
||||
let call = DirectCall {
|
||||
call_id: call_id.clone(),
|
||||
caller_fingerprint: caller_fp,
|
||||
@@ -92,6 +104,8 @@ impl CallRegistry {
|
||||
peer_relay_fp: None,
|
||||
caller_local_addrs: Vec::new(),
|
||||
callee_local_addrs: Vec::new(),
|
||||
caller_mapped_addr: None,
|
||||
callee_mapped_addr: None,
|
||||
};
|
||||
self.calls.insert(call_id.clone(), call);
|
||||
self.calls.get(&call_id).unwrap()
|
||||
@@ -142,6 +156,22 @@ impl CallRegistry {
|
||||
}
|
||||
}
|
||||
|
||||
/// Phase 8: stash the caller's port-mapped address from
|
||||
/// the `DirectCallOffer`.
|
||||
pub fn set_caller_mapped_addr(&mut self, call_id: &str, addr: Option<String>) {
|
||||
if let Some(call) = self.calls.get_mut(call_id) {
|
||||
call.caller_mapped_addr = addr;
|
||||
}
|
||||
}
|
||||
|
||||
/// Phase 8: stash the callee's port-mapped address from
|
||||
/// the `DirectCallAnswer`.
|
||||
pub fn set_callee_mapped_addr(&mut self, call_id: &str, addr: Option<String>) {
|
||||
if let Some(call) = self.calls.get_mut(call_id) {
|
||||
call.callee_mapped_addr = addr;
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a call by ID.
|
||||
pub fn get(&self, call_id: &str) -> Option<&DirectCall> {
|
||||
self.calls.get(call_id)
|
||||
@@ -164,7 +194,12 @@ impl CallRegistry {
|
||||
}
|
||||
|
||||
/// Transition to Active state.
|
||||
pub fn set_active(&mut self, call_id: &str, mode: wzp_proto::CallAcceptMode, room: String) -> bool {
|
||||
pub fn set_active(
|
||||
&mut self,
|
||||
call_id: &str,
|
||||
mode: wzp_proto::CallAcceptMode,
|
||||
room: String,
|
||||
) -> bool {
|
||||
if let Some(call) = self.calls.get_mut(call_id) {
|
||||
if call.state == DirectCallState::Pending || call.state == DirectCallState::Ringing {
|
||||
call.state = DirectCallState::Active;
|
||||
@@ -188,7 +223,8 @@ impl CallRegistry {
|
||||
|
||||
/// Find active/pending calls involving a fingerprint.
|
||||
pub fn calls_for_fingerprint(&self, fp: &str) -> Vec<&DirectCall> {
|
||||
self.calls.values()
|
||||
self.calls
|
||||
.values()
|
||||
.filter(|c| {
|
||||
c.state != DirectCallState::Ended
|
||||
&& (c.caller_fingerprint == fp || c.callee_fingerprint == fp)
|
||||
@@ -211,22 +247,25 @@ impl CallRegistry {
|
||||
/// Returns call IDs of expired calls.
|
||||
pub fn expire_stale(&mut self, timeout: Duration) -> Vec<DirectCall> {
|
||||
let now = Instant::now();
|
||||
let expired: Vec<String> = self.calls.iter()
|
||||
let expired: Vec<String> = self
|
||||
.calls
|
||||
.iter()
|
||||
.filter(|(_, c)| {
|
||||
c.state == DirectCallState::Pending
|
||||
&& now.duration_since(c.created_at) > timeout
|
||||
c.state == DirectCallState::Pending && now.duration_since(c.created_at) > timeout
|
||||
})
|
||||
.map(|(id, _)| id.clone())
|
||||
.collect();
|
||||
|
||||
expired.into_iter()
|
||||
expired
|
||||
.into_iter()
|
||||
.filter_map(|id| self.calls.remove(&id))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Number of active (non-ended) calls.
|
||||
pub fn active_count(&self) -> usize {
|
||||
self.calls.values()
|
||||
self.calls
|
||||
.values()
|
||||
.filter(|c| c.state != DirectCallState::Ended)
|
||||
.count()
|
||||
}
|
||||
@@ -245,9 +284,16 @@ mod tests {
|
||||
assert!(reg.set_ringing("c1"));
|
||||
assert_eq!(reg.get("c1").unwrap().state, DirectCallState::Ringing);
|
||||
|
||||
assert!(reg.set_active("c1", wzp_proto::CallAcceptMode::AcceptGeneric, "_call:c1".into()));
|
||||
assert!(reg.set_active(
|
||||
"c1",
|
||||
wzp_proto::CallAcceptMode::AcceptGeneric,
|
||||
"_call:c1".into()
|
||||
));
|
||||
assert_eq!(reg.get("c1").unwrap().state, DirectCallState::Active);
|
||||
assert_eq!(reg.get("c1").unwrap().room_name.as_deref(), Some("_call:c1"));
|
||||
assert_eq!(
|
||||
reg.get("c1").unwrap().room_name.as_deref(),
|
||||
Some("_call:c1")
|
||||
);
|
||||
|
||||
let ended = reg.end_call("c1").unwrap();
|
||||
assert_eq!(ended.state, DirectCallState::Ended);
|
||||
@@ -304,10 +350,7 @@ mod tests {
|
||||
// Both addrs are independently readable — the relay uses
|
||||
// them to cross-wire peer_direct_addr in CallSetup.
|
||||
let c = reg.get("c1").unwrap();
|
||||
assert_eq!(
|
||||
c.caller_reflexive_addr.as_deref(),
|
||||
Some("192.0.2.1:4433")
|
||||
);
|
||||
assert_eq!(c.caller_reflexive_addr.as_deref(), Some("192.0.2.1:4433"));
|
||||
assert_eq!(
|
||||
c.callee_reflexive_addr.as_deref(),
|
||||
Some("198.51.100.9:4433")
|
||||
@@ -340,6 +383,49 @@ mod tests {
|
||||
reg.set_peer_relay_fp("does-not-exist", Some("x".into()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn call_registry_stores_mapped_addrs() {
|
||||
let mut reg = CallRegistry::new();
|
||||
reg.create_call("c1".into(), "alice".into(), "bob".into());
|
||||
|
||||
// Default: both mapped addrs are None.
|
||||
let c = reg.get("c1").unwrap();
|
||||
assert!(c.caller_mapped_addr.is_none());
|
||||
assert!(c.callee_mapped_addr.is_none());
|
||||
|
||||
// Caller advertises its port-mapped addr via DirectCallOffer.
|
||||
reg.set_caller_mapped_addr("c1", Some("203.0.113.5:12345".into()));
|
||||
assert_eq!(
|
||||
reg.get("c1").unwrap().caller_mapped_addr.as_deref(),
|
||||
Some("203.0.113.5:12345")
|
||||
);
|
||||
|
||||
// Callee responds with its mapped addr.
|
||||
reg.set_callee_mapped_addr("c1", Some("198.51.100.9:54321".into()));
|
||||
assert_eq!(
|
||||
reg.get("c1").unwrap().callee_mapped_addr.as_deref(),
|
||||
Some("198.51.100.9:54321")
|
||||
);
|
||||
|
||||
// Both addrs readable — relay uses them to cross-wire
|
||||
// peer_mapped_addr in CallSetup.
|
||||
let c = reg.get("c1").unwrap();
|
||||
assert_eq!(c.caller_mapped_addr.as_deref(), Some("203.0.113.5:12345"));
|
||||
assert_eq!(c.callee_mapped_addr.as_deref(), Some("198.51.100.9:54321"));
|
||||
|
||||
// Setter on unknown call is a no-op.
|
||||
reg.set_caller_mapped_addr("nope", Some("x".into()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn call_registry_clearing_mapped_addr_works() {
|
||||
let mut reg = CallRegistry::new();
|
||||
reg.create_call("c1".into(), "alice".into(), "bob".into());
|
||||
reg.set_caller_mapped_addr("c1", Some("1.2.3.4:5".into()));
|
||||
reg.set_caller_mapped_addr("c1", None);
|
||||
assert!(reg.get("c1").unwrap().caller_mapped_addr.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn call_registry_clearing_reflex_addr_works() {
|
||||
// Passing None to the setter must clear a previously-set value
|
||||
|
||||
@@ -87,6 +87,14 @@ pub struct RelayConfig {
|
||||
/// Unlike [[peers]], no url is needed — the peer connects to us.
|
||||
#[serde(default)]
|
||||
pub trusted: Vec<TrustedConfig>,
|
||||
/// Phase 8: geographic region identifier (e.g., "us-east", "eu-west").
|
||||
/// Sent to clients in `RegisterPresenceAck.relay_region` so they can
|
||||
/// build a relay map for automatic selection.
|
||||
pub region: Option<String>,
|
||||
/// Phase 8: externally-advertised address for this relay. Used to
|
||||
/// populate `available_relays` in `RegisterPresenceAck`. If not set,
|
||||
/// `listen_addr` is used.
|
||||
pub advertised_addr: Option<SocketAddr>,
|
||||
/// Debug tap: log packet headers for matching rooms ("*" = all rooms).
|
||||
/// Activated via --debug-tap <room> or debug_tap = "room" in TOML.
|
||||
pub debug_tap: Option<String>,
|
||||
@@ -114,6 +122,8 @@ impl Default for RelayConfig {
|
||||
peers: Vec::new(),
|
||||
global_rooms: Vec::new(),
|
||||
trusted: Vec::new(),
|
||||
region: None,
|
||||
advertised_addr: None,
|
||||
debug_tap: None,
|
||||
event_log: None,
|
||||
}
|
||||
@@ -135,7 +145,10 @@ pub struct RelayInfo {
|
||||
}
|
||||
|
||||
/// Load config from path, or create a personalized example config if it doesn't exist.
|
||||
pub fn load_or_create_config(path: &str, info: Option<&RelayInfo>) -> Result<RelayConfig, anyhow::Error> {
|
||||
pub fn load_or_create_config(
|
||||
path: &str,
|
||||
info: Option<&RelayInfo>,
|
||||
) -> Result<RelayConfig, anyhow::Error> {
|
||||
let p = std::path::Path::new(path);
|
||||
if p.exists() {
|
||||
return load_config(path);
|
||||
@@ -154,7 +167,9 @@ pub fn load_or_create_config(path: &str, info: Option<&RelayInfo>) -> Result<Rel
|
||||
|
||||
/// Generate an example TOML config, personalized with this relay's info if available.
|
||||
fn generate_example_config(info: Option<&RelayInfo>) -> String {
|
||||
let listen = info.map(|i| i.listen_addr.as_str()).unwrap_or("0.0.0.0:4433");
|
||||
let listen = info
|
||||
.map(|i| i.listen_addr.as_str())
|
||||
.unwrap_or("0.0.0.0:4433");
|
||||
let peer_example = if let Some(i) = info {
|
||||
let ip = i.public_ip.as_deref().unwrap_or("this-relay-ip");
|
||||
format!(
|
||||
|
||||
544
crates/wzp-relay/src/conformance.rs
Normal file
544
crates/wzp-relay/src/conformance.rs
Normal file
@@ -0,0 +1,544 @@
|
||||
//! Relay conformance metering — Tier A/B/C/D/E enforcement.
|
||||
//!
|
||||
//! Each participant gets a [`ConformanceMeter`] that tracks per-second
|
||||
//! traffic against the declared codec's nominal bitrate ceiling.
|
||||
//! Violations are logged and counted but do **not** drop packets
|
||||
//! (observe-only mode).
|
||||
|
||||
use std::collections::VecDeque;
|
||||
use std::time::{Duration, Instant};
|
||||
use wzp_proto::{CodecId, MediaHeader};
|
||||
|
||||
/// Rolling window size for timestamp-drift detection (Tier C).
|
||||
const DRIFT_WINDOW_SIZE: usize = 200;
|
||||
|
||||
/// Kinds of conformance violation detected by the relay.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum Violation {
|
||||
/// Cumulative bitrate in the current 1 s window exceeds the Tier A ceiling.
|
||||
BitrateExceeded,
|
||||
/// Packet rate exceeds the per-codec safety limit (Tier B).
|
||||
PacketRateExceeded,
|
||||
/// Timestamp jumped backwards or forwards suspiciously (Tier C).
|
||||
TimestampDrift,
|
||||
/// Sustained payload size exceeds 2× the typical bound for the declared codec (Tier D).
|
||||
PayloadSizeExceeded,
|
||||
/// Per-session token-bucket rate cap exceeded (Tier E).
|
||||
RateCapExceeded,
|
||||
}
|
||||
|
||||
/// Error type returned when a [`TokenBucket`] does not hold enough tokens.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub struct TokenExhausted;
|
||||
|
||||
/// Simple token bucket for per-session rate capping (Tier E).
|
||||
///
|
||||
/// Tokens represent bytes. The bucket refills at `refill_per_sec` bytes per
|
||||
/// second, up to `capacity`. A packet is allowed only if the bucket holds
|
||||
/// enough tokens for its size.
|
||||
pub struct TokenBucket {
|
||||
capacity: u64,
|
||||
tokens: f64,
|
||||
refill_per_sec: u64,
|
||||
last_refill: Instant,
|
||||
}
|
||||
|
||||
impl TokenBucket {
|
||||
/// Create a new bucket with the given byte capacity and refill rate.
|
||||
pub fn new(capacity: u64, refill_per_sec: u64) -> Self {
|
||||
Self {
|
||||
capacity,
|
||||
tokens: capacity as f64,
|
||||
refill_per_sec,
|
||||
last_refill: Instant::now(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Per-session audio cap: 256 kbps with 30 s @ 2× burst.
|
||||
/// Capacity = 30 s × 64 KB/s = 1_920_000 bytes.
|
||||
pub fn for_audio_session() -> Self {
|
||||
let refill_per_sec = 256_000 / 8; // 32_000 bytes/sec
|
||||
let capacity = refill_per_sec * 30 * 2; // 1_920_000 bytes
|
||||
Self::new(capacity, refill_per_sec)
|
||||
}
|
||||
|
||||
/// Attempt to consume `bytes` from the bucket.
|
||||
///
|
||||
/// Refills based on elapsed time since the last call, then deducts the
|
||||
/// cost. Returns `Ok(())` if enough tokens were available,
|
||||
/// `Err(TokenExhausted)` otherwise.
|
||||
pub fn try_consume(&mut self, bytes: u64, now: Instant) -> Result<(), TokenExhausted> {
|
||||
let elapsed = now.duration_since(self.last_refill);
|
||||
self.last_refill = now;
|
||||
self.tokens += elapsed.as_secs_f64() * self.refill_per_sec as f64;
|
||||
if self.tokens > self.capacity as f64 {
|
||||
self.tokens = self.capacity as f64;
|
||||
}
|
||||
if self.tokens >= bytes as f64 {
|
||||
self.tokens -= bytes as f64;
|
||||
Ok(())
|
||||
} else {
|
||||
Err(TokenExhausted)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Per-participant traffic conformance meter.
|
||||
pub struct ConformanceMeter {
|
||||
window_start: Instant,
|
||||
bytes_in_window: u64,
|
||||
packets_in_window: u64,
|
||||
/// Rolling (seq, timestamp) pairs for drift detection.
|
||||
drift_window: VecDeque<(u32, u32)>,
|
||||
/// EWMA of payload size for Tier D sanity checks.
|
||||
ewma_payload_size: f64,
|
||||
/// Optional token bucket for Tier E per-session rate cap.
|
||||
token_bucket: Option<TokenBucket>,
|
||||
}
|
||||
|
||||
impl ConformanceMeter {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
window_start: Instant::now(),
|
||||
bytes_in_window: 0,
|
||||
packets_in_window: 0,
|
||||
drift_window: VecDeque::with_capacity(DRIFT_WINDOW_SIZE),
|
||||
ewma_payload_size: 0.0,
|
||||
token_bucket: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a meter with a Tier E token bucket for per-session rate capping.
|
||||
pub fn with_token_bucket(bucket: TokenBucket) -> Self {
|
||||
let mut meter = Self::new();
|
||||
meter.token_bucket = Some(bucket);
|
||||
meter
|
||||
}
|
||||
|
||||
/// Inspect an incoming media packet and accumulate it against the
|
||||
/// current 1-second window. Returns [`Err(Violation)`] when a limit
|
||||
/// is crossed.
|
||||
pub fn observe(
|
||||
&mut self,
|
||||
header: &MediaHeader,
|
||||
payload_len: usize,
|
||||
now: Instant,
|
||||
) -> Result<(), Violation> {
|
||||
// Roll the window forward if a second has elapsed.
|
||||
if now.duration_since(self.window_start) >= Duration::from_secs(1) {
|
||||
self.window_start = now;
|
||||
self.bytes_in_window = 0;
|
||||
self.packets_in_window = 0;
|
||||
}
|
||||
|
||||
let packet_size = (MediaHeader::WIRE_SIZE + payload_len) as u64;
|
||||
self.bytes_in_window += packet_size;
|
||||
self.packets_in_window += 1;
|
||||
|
||||
// Tier A — bitrate ceiling.
|
||||
let ceiling = ceiling_bps(header.codec_id);
|
||||
let max_bytes_per_sec = ceiling / 8;
|
||||
if self.bytes_in_window > max_bytes_per_sec {
|
||||
return Err(Violation::BitrateExceeded);
|
||||
}
|
||||
|
||||
// Tier B — packet-rate ceiling.
|
||||
let max_pps = max_pps(header.codec_id);
|
||||
let pps_threshold = (max_pps as f32 * 1.5) as u64;
|
||||
if self.packets_in_window > pps_threshold {
|
||||
return Err(Violation::PacketRateExceeded);
|
||||
}
|
||||
|
||||
// Tier C — timestamp drift.
|
||||
self.drift_window.push_back((header.seq, header.timestamp));
|
||||
if self.drift_window.len() > DRIFT_WINDOW_SIZE {
|
||||
self.drift_window.pop_front();
|
||||
}
|
||||
if self.drift_window.len() >= 2 {
|
||||
let (first_seq, first_ts) = self.drift_window.front().copied().unwrap();
|
||||
let (last_seq, last_ts) = self.drift_window.back().copied().unwrap();
|
||||
|
||||
let ds = last_seq.wrapping_sub(first_seq) as f64;
|
||||
let dt = last_ts.wrapping_sub(first_ts) as f64;
|
||||
|
||||
if ds > 0.0 {
|
||||
let avg_ms_per_packet = dt / ds;
|
||||
let frame_ms = header.codec_id.frame_duration_ms() as f64;
|
||||
let min_ratio = frame_ms * 0.5;
|
||||
let max_ratio = frame_ms * 2.0;
|
||||
if avg_ms_per_packet < min_ratio || avg_ms_per_packet > max_ratio {
|
||||
return Err(Violation::TimestampDrift);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Tier D — payload-size sanity (EWMA).
|
||||
let alpha = 0.05; // ~20-packet smoothing
|
||||
self.ewma_payload_size =
|
||||
alpha * payload_len as f64 + (1.0 - alpha) * self.ewma_payload_size;
|
||||
let bound = payload_size_bound(header.codec_id);
|
||||
if self.ewma_payload_size > (bound * 2) as f64 {
|
||||
return Err(Violation::PayloadSizeExceeded);
|
||||
}
|
||||
|
||||
// Tier E — per-session token-bucket rate cap.
|
||||
if let Some(ref mut bucket) = self.token_bucket {
|
||||
let packet_size = (MediaHeader::WIRE_SIZE + payload_len) as u64;
|
||||
if bucket.try_consume(packet_size, now).is_err() {
|
||||
return Err(Violation::RateCapExceeded);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ConformanceMeter {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the Tier A bitrate ceiling for a given codec.
|
||||
///
|
||||
/// Formula:
|
||||
/// nominal_bitrate * 3 (FEC 2.0 overhead) * 115 / 100 (15% safety margin)
|
||||
/// with a floor of 2 kbps.
|
||||
pub fn ceiling_bps(codec: CodecId) -> u64 {
|
||||
let nominal = codec.bitrate_bps() as u64;
|
||||
(nominal * 3 * 115 / 100).max(2_000)
|
||||
}
|
||||
|
||||
/// Compute the Tier B packet-rate ceiling for a given codec.
|
||||
///
|
||||
/// Formula:
|
||||
/// 1000 / frame_duration_ms * 3 (FEC overhead factor)
|
||||
pub fn max_pps(codec: CodecId) -> u32 {
|
||||
let fd = codec.frame_duration_ms() as u32;
|
||||
if fd == 0 {
|
||||
return 0;
|
||||
}
|
||||
(1000 / fd) * 3
|
||||
}
|
||||
|
||||
/// Typical per-codec payload size bound in bytes (Tier D).
|
||||
///
|
||||
/// These are empirical upper bounds for a single audio frame at the codec's
|
||||
/// nominal configuration. The EWMA must not exceed 2× this value.
|
||||
pub fn payload_size_bound(codec: CodecId) -> usize {
|
||||
match codec {
|
||||
CodecId::Opus64k => 320,
|
||||
CodecId::Opus48k => 240,
|
||||
CodecId::Opus32k => 200,
|
||||
CodecId::Opus24k => 160,
|
||||
CodecId::Opus16k => 100,
|
||||
CodecId::Opus6k => 90,
|
||||
CodecId::Codec2_3200 => 30,
|
||||
CodecId::Codec2_1200 => 30,
|
||||
CodecId::ComfortNoise => 16,
|
||||
CodecId::H264Baseline | CodecId::H265Main | CodecId::Av1Main => 1400,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use wzp_proto::MediaType;
|
||||
|
||||
fn make_header(codec_id: CodecId) -> MediaHeader {
|
||||
MediaHeader {
|
||||
version: 2,
|
||||
flags: 0,
|
||||
media_type: MediaType::Audio,
|
||||
codec_id,
|
||||
seq: 0,
|
||||
timestamp: 0,
|
||||
fec_block: 0,
|
||||
stream_id: 0,
|
||||
fec_ratio: 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn make_header_with_seq_ts(codec_id: CodecId, seq: u32, timestamp: u32) -> MediaHeader {
|
||||
MediaHeader {
|
||||
version: 2,
|
||||
flags: 0,
|
||||
media_type: MediaType::Audio,
|
||||
codec_id,
|
||||
seq,
|
||||
timestamp,
|
||||
fec_block: 0,
|
||||
stream_id: 0,
|
||||
fec_ratio: 0,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bitrate_exceeded_for_opus24k() {
|
||||
let mut meter = ConformanceMeter::new();
|
||||
let header = make_header(CodecId::Opus24k);
|
||||
|
||||
// Ceiling for Opus24k = 24_000 * 3 * 115 / 100 = 82_800 bps
|
||||
// = 10_350 bytes/sec. 1 MB/s = 125_000 bytes/packet will blow past
|
||||
// that in a single packet.
|
||||
let now = Instant::now();
|
||||
let result = meter.observe(&header, 1_000_000, now);
|
||||
assert_eq!(result, Err(Violation::BitrateExceeded));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn small_packets_stay_within_ceiling() {
|
||||
let mut meter = ConformanceMeter::new();
|
||||
let header = make_header(CodecId::Opus24k);
|
||||
|
||||
// Ceiling = 82_800 bps = 10_350 bytes/sec.
|
||||
// Each packet = 16-byte header + 80 bytes = 96 bytes.
|
||||
// 100 packets = 9_600 bytes < 10_350.
|
||||
let now = Instant::now();
|
||||
for _ in 0..100 {
|
||||
assert!(meter.observe(&header, 80, now).is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn window_resets_after_one_second() {
|
||||
let mut meter = ConformanceMeter::new();
|
||||
let header = make_header(CodecId::Opus24k);
|
||||
|
||||
// Fill the window to just under the limit.
|
||||
// Use 300-byte payloads (under Tier D 2× bound of 320 for Opus24k).
|
||||
let t0 = Instant::now();
|
||||
for _ in 0..32 {
|
||||
assert!(meter.observe(&header, 300, t0).is_ok());
|
||||
}
|
||||
// 32 * (header wire size + 300) ≈ 32 * 316 = 10_112 bytes < 10_350
|
||||
|
||||
// Same packets 1.1 seconds later should be fine because the window
|
||||
// rolls over.
|
||||
let t1 = t0 + Duration::from_millis(1_100);
|
||||
for _ in 0..32 {
|
||||
assert!(meter.observe(&header, 300, t1).is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ceiling_bps_floor() {
|
||||
// ComfortNoise has 0 nominal bitrate, so the floor kicks in.
|
||||
assert_eq!(ceiling_bps(CodecId::ComfortNoise), 2_000);
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------
|
||||
// Tier B — packet rate
|
||||
// ------------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn packet_rate_exceeded() {
|
||||
let mut meter = ConformanceMeter::new();
|
||||
// Opus24k: max_pps = 1000/20 * 3 = 150. Threshold = 150 * 1.5 = 225.
|
||||
let header = make_header(CodecId::Opus24k);
|
||||
let now = Instant::now();
|
||||
for _ in 0..225 {
|
||||
assert!(meter.observe(&header, 10, now).is_ok());
|
||||
}
|
||||
// 226th packet should trip the limit.
|
||||
assert_eq!(
|
||||
meter.observe(&header, 10, now),
|
||||
Err(Violation::PacketRateExceeded)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn packet_rate_within_limit() {
|
||||
let mut meter = ConformanceMeter::new();
|
||||
// Opus6k: max_pps = 1000/40 * 3 = 75. Threshold = 75 * 1.5 = 112.
|
||||
// Use 0-byte payload so bitrate ceiling (2_587 bytes/sec) is not the
|
||||
// limiting factor. 112 packets × 16 bytes = 1_792 bytes < 2_587.
|
||||
let header = make_header(CodecId::Opus6k);
|
||||
let now = Instant::now();
|
||||
for _ in 0..112 {
|
||||
assert!(meter.observe(&header, 0, now).is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------
|
||||
// Tier C — timestamp drift
|
||||
// ------------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn timestamp_drift_detected_when_too_fast() {
|
||||
let mut meter = ConformanceMeter::new();
|
||||
// Opus24k frame_duration = 20 ms.
|
||||
// Acceptable range: [10, 40] ms per packet.
|
||||
// Send packets with timestamp advancing by 5 ms each (too fast).
|
||||
let now = Instant::now();
|
||||
let mut drift_seen = false;
|
||||
for i in 0..200 {
|
||||
let header = make_header_with_seq_ts(CodecId::Opus24k, i, i * 5);
|
||||
match meter.observe(&header, 10, now) {
|
||||
Ok(()) => {}
|
||||
Err(Violation::TimestampDrift) => drift_seen = true,
|
||||
Err(other) => panic!("unexpected violation: {other:?}"),
|
||||
}
|
||||
}
|
||||
assert!(drift_seen, "expected TimestampDrift to be detected");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn timestamp_drift_detected_when_too_slow() {
|
||||
let mut meter = ConformanceMeter::new();
|
||||
// Opus24k frame_duration = 20 ms.
|
||||
// Acceptable range: [10, 40] ms per packet.
|
||||
// Send packets with timestamp advancing by 50 ms each (too slow).
|
||||
let now = Instant::now();
|
||||
let mut drift_seen = false;
|
||||
for i in 0..200 {
|
||||
let header = make_header_with_seq_ts(CodecId::Opus24k, i, i * 50);
|
||||
match meter.observe(&header, 10, now) {
|
||||
Ok(()) => {}
|
||||
Err(Violation::TimestampDrift) => drift_seen = true,
|
||||
Err(other) => panic!("unexpected violation: {other:?}"),
|
||||
}
|
||||
}
|
||||
assert!(drift_seen, "expected TimestampDrift to be detected");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn timestamp_normal_no_drift() {
|
||||
let mut meter = ConformanceMeter::new();
|
||||
// Opus24k frame_duration = 20 ms.
|
||||
// Send 200 packets with timestamp advancing by exactly 20 ms each.
|
||||
let now = Instant::now();
|
||||
for i in 0..200 {
|
||||
let header = make_header_with_seq_ts(CodecId::Opus24k, i, i * 20);
|
||||
assert!(meter.observe(&header, 10, now).is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn timestamp_drift_not_checked_before_two_packets() {
|
||||
let mut meter = ConformanceMeter::new();
|
||||
let now = Instant::now();
|
||||
// Single packet with wild timestamp — should not trigger drift.
|
||||
let header = make_header_with_seq_ts(CodecId::Opus24k, 0, 999_999);
|
||||
assert!(meter.observe(&header, 10, now).is_ok());
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------
|
||||
// Tier D — payload-size sanity
|
||||
// ------------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn conformance_tier_d() {
|
||||
let mut meter = ConformanceMeter::new();
|
||||
let header = make_header(CodecId::Codec2_1200);
|
||||
let now = Instant::now();
|
||||
|
||||
// Codec2_1200 bound = 30 bytes. 2× bound = 60 bytes.
|
||||
// Feed 1400-byte payloads — EWMA should cross 60 within a few packets.
|
||||
let mut flagged = false;
|
||||
for _ in 0..200 {
|
||||
if meter.observe(&header, 1400, now).is_err() {
|
||||
flagged = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
assert!(
|
||||
flagged,
|
||||
"expected PayloadSizeExceeded for 1400-byte Codec2_1200 payloads"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn payload_size_normal_stays_within_bound() {
|
||||
let mut meter = ConformanceMeter::new();
|
||||
let header = make_header(CodecId::Opus24k);
|
||||
let now = Instant::now();
|
||||
|
||||
// Opus24k bound = 160 bytes. 2× bound = 320 bytes.
|
||||
// Feed 150-byte payloads — well within the 2× limit.
|
||||
// Limit to 10 packets so the 1-second bitrate window (10_350 bytes)
|
||||
// is not exhausted: 10 * (16 + 150) = 1_660 < 10_350.
|
||||
for _ in 0..10 {
|
||||
assert!(
|
||||
meter.observe(&header, 150, now).is_ok(),
|
||||
"150-byte Opus24k payloads should stay within Tier D limit"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------
|
||||
// Tier E — token-bucket rate cap
|
||||
// ------------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn token_bucket_small_burst_ok() {
|
||||
let mut bucket = TokenBucket::new(100_000, 32_000);
|
||||
let now = Instant::now();
|
||||
// 50 KB burst fits inside 100 KB capacity.
|
||||
assert!(bucket.try_consume(50_000, now).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn token_bucket_large_burst_fails() {
|
||||
let mut bucket = TokenBucket::new(100_000, 32_000);
|
||||
let now = Instant::now();
|
||||
// 1 MB exceeds 100 KB capacity.
|
||||
assert!(bucket.try_consume(1_000_000, now).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn token_bucket_refills_over_time() {
|
||||
let mut bucket = TokenBucket::new(100_000, 32_000);
|
||||
let t0 = Instant::now();
|
||||
// Drain the bucket.
|
||||
assert!(bucket.try_consume(100_000, t0).is_ok());
|
||||
// Immediately try again — should fail.
|
||||
assert!(bucket.try_consume(10_000, t0).is_err());
|
||||
// Wait 1 second — bucket refills 32_000 bytes.
|
||||
let t1 = t0 + Duration::from_secs(1);
|
||||
assert!(bucket.try_consume(30_000, t1).is_ok());
|
||||
// 40_000 is more than the 32_000 refilled.
|
||||
assert!(bucket.try_consume(40_000, t1).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn token_bucket_sustained_rate_balanced() {
|
||||
let mut bucket = TokenBucket::new(1_000_000, 32_000);
|
||||
let t0 = Instant::now();
|
||||
// Send 32 KB every second for 5 seconds — exactly at refill rate.
|
||||
// The bucket should never empty because each second it refills
|
||||
// exactly what was consumed.
|
||||
for i in 0..5 {
|
||||
let t = t0 + Duration::from_secs(i);
|
||||
assert!(
|
||||
bucket.try_consume(32_000, t).is_ok(),
|
||||
"32 KB/s sustained should stay within bucket limit"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conformance_tier_e_integration() {
|
||||
// Use Opus64k (high bitrate ceiling + high payload bound) so Tiers
|
||||
// A/B/D never fire on the small bursts used here. Only Tier E.
|
||||
let mut meter = ConformanceMeter::with_token_bucket(TokenBucket::new(1_000, 500));
|
||||
let header = make_header(CodecId::Opus64k);
|
||||
let now = Instant::now();
|
||||
|
||||
// Two 500-byte (wire) packets = 1_000 bytes — exactly the bucket cap.
|
||||
assert!(
|
||||
meter
|
||||
.observe(&header, 500 - MediaHeader::WIRE_SIZE, now)
|
||||
.is_ok()
|
||||
);
|
||||
assert!(
|
||||
meter
|
||||
.observe(&header, 500 - MediaHeader::WIRE_SIZE, now)
|
||||
.is_ok()
|
||||
);
|
||||
|
||||
// Third packet exceeds the 1_000-byte cap.
|
||||
let result = meter.observe(&header, 10, now);
|
||||
assert_eq!(result, Err(Violation::RateCapExceeded));
|
||||
}
|
||||
}
|
||||
@@ -25,16 +25,13 @@ pub struct Event {
|
||||
pub src: Option<String>,
|
||||
/// Packet sequence number.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub seq: Option<u16>,
|
||||
pub seq: Option<u32>,
|
||||
/// Codec identifier.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub codec: Option<String>,
|
||||
/// FEC block ID.
|
||||
/// FEC block ID (low byte) and symbol index (high byte).
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub fec_block: Option<u8>,
|
||||
/// FEC symbol index.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub fec_sym: Option<u8>,
|
||||
pub fec_block: Option<u16>,
|
||||
/// Is FEC repair packet.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub repair: Option<bool>,
|
||||
@@ -60,7 +57,9 @@ pub struct Event {
|
||||
|
||||
impl Event {
|
||||
fn now() -> String {
|
||||
chrono::Utc::now().format("%Y-%m-%dT%H:%M:%S%.6fZ").to_string()
|
||||
chrono::Utc::now()
|
||||
.format("%Y-%m-%dT%H:%M:%S%.6fZ")
|
||||
.to_string()
|
||||
}
|
||||
|
||||
/// Create a minimal event with just type and timestamp.
|
||||
@@ -73,7 +72,6 @@ impl Event {
|
||||
seq: None,
|
||||
codec: None,
|
||||
fec_block: None,
|
||||
fec_sym: None,
|
||||
repair: None,
|
||||
len: None,
|
||||
to_count: None,
|
||||
@@ -85,33 +83,59 @@ impl Event {
|
||||
}
|
||||
|
||||
/// Set room.
|
||||
pub fn room(mut self, room: &str) -> Self { self.room = Some(room.to_string()); self }
|
||||
pub fn room(mut self, room: &str) -> Self {
|
||||
self.room = Some(room.to_string());
|
||||
self
|
||||
}
|
||||
/// Set source.
|
||||
pub fn src(mut self, src: &str) -> Self { self.src = Some(src.to_string()); self }
|
||||
pub fn src(mut self, src: &str) -> Self {
|
||||
self.src = Some(src.to_string());
|
||||
self
|
||||
}
|
||||
/// Set packet header fields from a MediaPacket.
|
||||
pub fn packet(mut self, pkt: &wzp_proto::MediaPacket) -> Self {
|
||||
self.seq = Some(pkt.header.seq);
|
||||
self.codec = Some(format!("{:?}", pkt.header.codec_id));
|
||||
self.fec_block = Some(pkt.header.fec_block);
|
||||
self.fec_sym = Some(pkt.header.fec_symbol);
|
||||
self.repair = Some(pkt.header.is_repair);
|
||||
self.repair = Some(pkt.header.is_repair());
|
||||
self.len = Some(pkt.payload.len());
|
||||
self
|
||||
}
|
||||
/// Set seq only (when full packet not available).
|
||||
pub fn seq(mut self, seq: u16) -> Self { self.seq = Some(seq); self }
|
||||
pub fn seq(mut self, seq: u32) -> Self {
|
||||
self.seq = Some(seq);
|
||||
self
|
||||
}
|
||||
/// Set payload length.
|
||||
pub fn len(mut self, len: usize) -> Self { self.len = Some(len); self }
|
||||
pub fn len(mut self, len: usize) -> Self {
|
||||
self.len = Some(len);
|
||||
self
|
||||
}
|
||||
/// Set recipient count.
|
||||
pub fn to_count(mut self, n: usize) -> Self { self.to_count = Some(n); self }
|
||||
pub fn to_count(mut self, n: usize) -> Self {
|
||||
self.to_count = Some(n);
|
||||
self
|
||||
}
|
||||
/// Set peer label.
|
||||
pub fn peer(mut self, peer: &str) -> Self { self.peer = Some(peer.to_string()); self }
|
||||
pub fn peer(mut self, peer: &str) -> Self {
|
||||
self.peer = Some(peer.to_string());
|
||||
self
|
||||
}
|
||||
/// Set drop reason.
|
||||
pub fn reason(mut self, reason: &str) -> Self { self.reason = Some(reason.to_string()); self }
|
||||
pub fn reason(mut self, reason: &str) -> Self {
|
||||
self.reason = Some(reason.to_string());
|
||||
self
|
||||
}
|
||||
/// Set presence action.
|
||||
pub fn action(mut self, action: &str) -> Self { self.action = Some(action.to_string()); self }
|
||||
pub fn action(mut self, action: &str) -> Self {
|
||||
self.action = Some(action.to_string());
|
||||
self
|
||||
}
|
||||
/// Set participant count.
|
||||
pub fn participants(mut self, n: usize) -> Self { self.participants = Some(n); self }
|
||||
pub fn participants(mut self, n: usize) -> Self {
|
||||
self.participants = Some(n);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle for emitting events. Cheap to clone.
|
||||
@@ -181,8 +205,12 @@ async fn writer_task(path: PathBuf, mut rx: mpsc::UnboundedReceiver<Event>) {
|
||||
while let Some(event) = rx.recv().await {
|
||||
match serde_json::to_string(&event) {
|
||||
Ok(json) => {
|
||||
if writer.write_all(json.as_bytes()).await.is_err() { break; }
|
||||
if writer.write_all(b"\n").await.is_err() { break; }
|
||||
if writer.write_all(json.as_bytes()).await.is_err() {
|
||||
break;
|
||||
}
|
||||
if writer.write_all(b"\n").await.is_err() {
|
||||
break;
|
||||
}
|
||||
count += 1;
|
||||
// Flush every 100 events
|
||||
if count % 100 == 0 {
|
||||
|
||||
@@ -11,11 +11,11 @@ use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use bytes::Bytes;
|
||||
use sha2::{Sha256, Digest};
|
||||
use sha2::{Digest, Sha256};
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
use wzp_proto::{MediaTransport, SignalMessage};
|
||||
use wzp_proto::{MediaTransport, SignalMessage, default_signal_version};
|
||||
use wzp_transport::QuinnTransport;
|
||||
|
||||
use crate::config::{PeerConfig, TrustedConfig};
|
||||
@@ -56,13 +56,14 @@ impl Deduplicator {
|
||||
}
|
||||
|
||||
/// Returns true if this packet is a duplicate (already seen within TTL).
|
||||
fn is_dup(&mut self, room_hash: &[u8; 8], seq: u16, extra: u64) -> bool {
|
||||
fn is_dup(&mut self, room_hash: &[u8; 8], seq: u32, extra: u64) -> bool {
|
||||
let key = u64::from_be_bytes(*room_hash) ^ (seq as u64) ^ extra;
|
||||
let now = Instant::now();
|
||||
|
||||
// Periodic cleanup (every ~256 packets)
|
||||
if self.entries.len() > 256 {
|
||||
self.entries.retain(|_, ts| now.duration_since(*ts) < self.ttl);
|
||||
self.entries
|
||||
.retain(|_, ts| now.duration_since(*ts) < self.ttl);
|
||||
}
|
||||
|
||||
if let Some(ts) = self.entries.get(&key) {
|
||||
@@ -134,7 +135,7 @@ pub struct FederationManager {
|
||||
peers: Vec<PeerConfig>,
|
||||
trusted: Vec<TrustedConfig>,
|
||||
global_rooms: HashSet<String>,
|
||||
room_mgr: Arc<Mutex<RoomManager>>,
|
||||
room_mgr: Arc<RoomManager>,
|
||||
endpoint: quinn::Endpoint,
|
||||
local_tls_fp: String,
|
||||
metrics: Arc<crate::metrics::RelayMetrics>,
|
||||
@@ -161,7 +162,7 @@ impl FederationManager {
|
||||
peers: Vec<PeerConfig>,
|
||||
trusted: Vec<TrustedConfig>,
|
||||
global_rooms: HashSet<String>,
|
||||
room_mgr: Arc<Mutex<RoomManager>>,
|
||||
room_mgr: Arc<RoomManager>,
|
||||
endpoint: quinn::Endpoint,
|
||||
local_tls_fp: String,
|
||||
metrics: Arc<crate::metrics::RelayMetrics>,
|
||||
@@ -213,16 +214,22 @@ impl FederationManager {
|
||||
/// `origin_relay_fp` against its own fp and drops self-sourced
|
||||
/// forwards.
|
||||
pub async fn broadcast_signal(&self, msg: &wzp_proto::SignalMessage) -> usize {
|
||||
let links = self.peer_links.lock().await;
|
||||
let peers: Vec<(String, String, Arc<QuinnTransport>)> = {
|
||||
let links = self.peer_links.lock().await;
|
||||
links
|
||||
.iter()
|
||||
.map(|(fp, l)| (fp.clone(), l.label.clone(), l.transport.clone()))
|
||||
.collect()
|
||||
}; // lock released
|
||||
let mut count = 0;
|
||||
for (fp, link) in links.iter() {
|
||||
match link.transport.send_signal(msg).await {
|
||||
for (fp, label, transport) in &peers {
|
||||
match transport.send_signal(msg).await {
|
||||
Ok(()) => {
|
||||
count += 1;
|
||||
tracing::debug!(peer = %link.label, %fp, "federation: broadcast signal ok");
|
||||
tracing::debug!(peer = %label, %fp, "federation: broadcast signal ok");
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(peer = %link.label, %fp, error = %e, "federation: broadcast signal failed");
|
||||
tracing::warn!(peer = %label, %fp, error = %e, "federation: broadcast signal failed");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -243,10 +250,12 @@ impl FederationManager {
|
||||
msg: &wzp_proto::SignalMessage,
|
||||
) -> Result<(), String> {
|
||||
let normalized = normalize_fp(peer_relay_fp);
|
||||
let links = self.peer_links.lock().await;
|
||||
match links.get(&normalized) {
|
||||
Some(link) => link
|
||||
.transport
|
||||
let transport = {
|
||||
let links = self.peer_links.lock().await;
|
||||
links.get(&normalized).map(|l| l.transport.clone())
|
||||
}; // lock released
|
||||
match transport {
|
||||
Some(t) => t
|
||||
.send_signal(msg)
|
||||
.await
|
||||
.map_err(|e| format!("send to peer {normalized}: {e}")),
|
||||
@@ -295,9 +304,10 @@ impl FederationManager {
|
||||
return Some(room.to_string());
|
||||
}
|
||||
// Hashed match (desktop clients hash room names for SNI privacy)
|
||||
self.global_rooms.iter().find(|name| {
|
||||
wzp_crypto::hash_room_name(name) == room
|
||||
}).map(|s| s.to_string())
|
||||
self.global_rooms
|
||||
.iter()
|
||||
.find(|name| wzp_crypto::hash_room_name(name) == room)
|
||||
.map(|s| s.to_string())
|
||||
}
|
||||
|
||||
/// Get the canonical federation room hash for a room.
|
||||
@@ -333,10 +343,7 @@ impl FederationManager {
|
||||
}
|
||||
|
||||
// Room event dispatcher
|
||||
let room_events = {
|
||||
let mgr = self.room_mgr.lock().await;
|
||||
mgr.subscribe_events()
|
||||
};
|
||||
let room_events = self.room_mgr.subscribe_events();
|
||||
let this = self.clone();
|
||||
handles.push(tokio::spawn(async move {
|
||||
run_room_event_dispatcher(this, room_events).await;
|
||||
@@ -369,7 +376,10 @@ impl FederationManager {
|
||||
|
||||
/// Get all remote participants for a room from all peer links.
|
||||
/// Deduplicates by fingerprint (same participant may appear via multiple links).
|
||||
pub async fn get_remote_participants(&self, room: &str) -> Vec<wzp_proto::packet::RoomParticipant> {
|
||||
pub async fn get_remote_participants(
|
||||
&self,
|
||||
room: &str,
|
||||
) -> Vec<wzp_proto::packet::RoomParticipant> {
|
||||
let canonical = self.resolve_global_room(room);
|
||||
let links = self.peer_links.lock().await;
|
||||
let mut result = Vec::new();
|
||||
@@ -405,21 +415,35 @@ impl FederationManager {
|
||||
/// the other room-tagged helpers and for future per-room-name logging
|
||||
/// or rate limiting; the body currently forwards on `room_hash` alone
|
||||
/// because that's what the wire format carries.
|
||||
pub async fn forward_to_peers(&self, _room_name: &str, room_hash: &[u8; 8], media_data: &Bytes) {
|
||||
let links = self.peer_links.lock().await;
|
||||
if links.is_empty() {
|
||||
return;
|
||||
}
|
||||
for (_fp, link) in links.iter() {
|
||||
pub async fn forward_to_peers(
|
||||
&self,
|
||||
_room_name: &str,
|
||||
room_hash: &[u8; 8],
|
||||
media_data: &Bytes,
|
||||
) {
|
||||
let peers: Vec<(String, Arc<QuinnTransport>)> = {
|
||||
let links = self.peer_links.lock().await;
|
||||
if links.is_empty() {
|
||||
return;
|
||||
}
|
||||
links
|
||||
.values()
|
||||
.map(|l| (l.label.clone(), l.transport.clone()))
|
||||
.collect()
|
||||
}; // lock released
|
||||
|
||||
for (label, transport) in &peers {
|
||||
let mut tagged = Vec::with_capacity(8 + media_data.len());
|
||||
tagged.extend_from_slice(room_hash);
|
||||
tagged.extend_from_slice(media_data);
|
||||
match link.transport.send_raw_datagram(&tagged) {
|
||||
match transport.send_raw_datagram(&tagged) {
|
||||
Ok(()) => {
|
||||
self.metrics.federation_packets_forwarded
|
||||
.with_label_values(&[&link.label, "out"]).inc();
|
||||
self.metrics
|
||||
.federation_packets_forwarded
|
||||
.with_label_values(&[label, "out"])
|
||||
.inc();
|
||||
}
|
||||
Err(e) => warn!(peer = %link.label, "federation send error: {e}"),
|
||||
Err(e) => warn!(peer = %label, "federation send error: {e}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -427,20 +451,25 @@ impl FederationManager {
|
||||
// ── Trust verification (kept from previous implementation) ──
|
||||
|
||||
pub fn find_peer_by_fingerprint(&self, fp: &str) -> Option<&PeerConfig> {
|
||||
self.peers.iter().find(|p| normalize_fp(&p.fingerprint) == normalize_fp(fp))
|
||||
self.peers
|
||||
.iter()
|
||||
.find(|p| normalize_fp(&p.fingerprint) == normalize_fp(fp))
|
||||
}
|
||||
|
||||
pub fn find_peer_by_addr(&self, addr: SocketAddr) -> Option<&PeerConfig> {
|
||||
let addr_ip = addr.ip();
|
||||
self.peers.iter().find(|p| {
|
||||
p.url.parse::<SocketAddr>()
|
||||
p.url
|
||||
.parse::<SocketAddr>()
|
||||
.map(|sa| sa.ip() == addr_ip)
|
||||
.unwrap_or(false)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn find_trusted_by_fingerprint(&self, fp: &str) -> Option<&TrustedConfig> {
|
||||
self.trusted.iter().find(|t| normalize_fp(&t.fingerprint) == normalize_fp(fp))
|
||||
self.trusted
|
||||
.iter()
|
||||
.find(|t| normalize_fp(&t.fingerprint) == normalize_fp(fp))
|
||||
}
|
||||
|
||||
pub fn check_inbound_trust(&self, addr: SocketAddr, hello_fp: &str) -> Option<String> {
|
||||
@@ -448,7 +477,12 @@ impl FederationManager {
|
||||
return Some(peer.label.clone().unwrap_or_else(|| peer.url.clone()));
|
||||
}
|
||||
if let Some(trusted) = self.find_trusted_by_fingerprint(hello_fp) {
|
||||
return Some(trusted.label.clone().unwrap_or_else(|| hello_fp[..16].to_string()));
|
||||
return Some(
|
||||
trusted
|
||||
.label
|
||||
.clone()
|
||||
.unwrap_or_else(|| hello_fp[..16].to_string()),
|
||||
);
|
||||
}
|
||||
None
|
||||
}
|
||||
@@ -467,7 +501,8 @@ pub async fn run_federation_media_egress(
|
||||
if count == 1 || count % 250 == 0 {
|
||||
info!(room = %out.room_name, count, "federation egress: forwarding media");
|
||||
}
|
||||
fm.forward_to_peers(&out.room_name, &out.room_hash, &out.data).await;
|
||||
fm.forward_to_peers(&out.room_name, &out.room_hash, &out.data)
|
||||
.await;
|
||||
}
|
||||
info!(total = count, "federation egress task ended");
|
||||
}
|
||||
@@ -483,25 +518,35 @@ async fn run_room_event_dispatcher(
|
||||
match events.recv().await {
|
||||
Ok(RoomEvent::LocalJoin { room }) => {
|
||||
if fm.is_global_room(&room) {
|
||||
let participants = {
|
||||
let mgr = fm.room_mgr.lock().await;
|
||||
mgr.local_participant_list(&room)
|
||||
};
|
||||
let participants = fm.room_mgr.local_participant_list(&room);
|
||||
info!(room = %room, count = participants.len(), "global room now active, announcing to peers");
|
||||
let msg = SignalMessage::GlobalRoomActive { room, participants };
|
||||
let links = fm.peer_links.lock().await;
|
||||
for link in links.values() {
|
||||
let _ = link.transport.send_signal(&msg).await;
|
||||
let msg = SignalMessage::GlobalRoomActive {
|
||||
version: default_signal_version(),
|
||||
room,
|
||||
participants,
|
||||
};
|
||||
let transports: Vec<Arc<QuinnTransport>> = {
|
||||
let links = fm.peer_links.lock().await;
|
||||
links.values().map(|l| l.transport.clone()).collect()
|
||||
};
|
||||
for t in &transports {
|
||||
let _ = t.send_signal(&msg).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(RoomEvent::LocalLeave { room }) => {
|
||||
if fm.is_global_room(&room) {
|
||||
info!(room = %room, "global room now inactive, announcing to peers");
|
||||
let msg = SignalMessage::GlobalRoomInactive { room };
|
||||
let links = fm.peer_links.lock().await;
|
||||
for link in links.values() {
|
||||
let _ = link.transport.send_signal(&msg).await;
|
||||
let msg = SignalMessage::GlobalRoomInactive {
|
||||
version: default_signal_version(),
|
||||
room,
|
||||
};
|
||||
let transports: Vec<Arc<QuinnTransport>> = {
|
||||
let links = fm.peer_links.lock().await;
|
||||
links.values().map(|l| l.transport.clone()).collect()
|
||||
};
|
||||
for t in &transports {
|
||||
let _ = t.send_signal(&msg).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -529,7 +574,9 @@ async fn run_stale_presence_sweeper(fm: Arc<FederationManager>) {
|
||||
let links = fm.peer_links.lock().await;
|
||||
let mut stale = Vec::new();
|
||||
for (fp, link) in links.iter() {
|
||||
if link.last_seen.elapsed() > stale_threshold && !link.remote_participants.is_empty() {
|
||||
if link.last_seen.elapsed() > stale_threshold
|
||||
&& !link.remote_participants.is_empty()
|
||||
{
|
||||
for room in link.remote_participants.keys() {
|
||||
stale.push((fp.clone(), room.clone()));
|
||||
}
|
||||
@@ -560,20 +607,20 @@ async fn run_stale_presence_sweeper(fm: Arc<FederationManager>) {
|
||||
|
||||
// Broadcast updated RoomUpdate for affected rooms
|
||||
for room in &affected_rooms {
|
||||
let mgr = fm.room_mgr.lock().await;
|
||||
for local_room in mgr.active_rooms() {
|
||||
if fm.resolve_global_room(&local_room) == fm.resolve_global_room(room) {
|
||||
let mut all_participants = mgr.local_participant_list(&local_room);
|
||||
let remote = fm.get_remote_participants(&local_room).await;
|
||||
let active = fm.room_mgr.active_rooms();
|
||||
for local_room in &active {
|
||||
if fm.resolve_global_room(local_room) == fm.resolve_global_room(room) {
|
||||
let mut all_participants = fm.room_mgr.local_participant_list(local_room);
|
||||
let remote = fm.get_remote_participants(local_room).await;
|
||||
all_participants.extend(remote);
|
||||
let mut seen = HashSet::new();
|
||||
all_participants.retain(|p| seen.insert(p.fingerprint.clone()));
|
||||
let update = SignalMessage::RoomUpdate {
|
||||
version: default_signal_version(),
|
||||
count: all_participants.len() as u32,
|
||||
participants: all_participants,
|
||||
};
|
||||
let senders = mgr.local_senders(&local_room);
|
||||
drop(mgr);
|
||||
let senders = fm.room_mgr.local_senders(local_room);
|
||||
room::broadcast_signal(&senders, &update).await;
|
||||
info!(room = %room, "swept stale presence — broadcast updated RoomUpdate");
|
||||
break;
|
||||
@@ -609,7 +656,10 @@ async fn run_peer_loop(fm: Arc<FederationManager>, peer: PeerConfig) {
|
||||
}
|
||||
|
||||
/// Connect to a peer relay and send hello.
|
||||
async fn connect_to_peer(fm: &FederationManager, peer: &PeerConfig) -> Result<Arc<QuinnTransport>, anyhow::Error> {
|
||||
async fn connect_to_peer(
|
||||
fm: &FederationManager,
|
||||
peer: &PeerConfig,
|
||||
) -> Result<Arc<QuinnTransport>, anyhow::Error> {
|
||||
let addr: SocketAddr = peer.url.parse()?;
|
||||
let client_cfg = wzp_transport::client_config();
|
||||
let conn = wzp_transport::connect(&fm.endpoint, addr, "_federation", client_cfg).await?;
|
||||
@@ -617,9 +667,12 @@ async fn connect_to_peer(fm: &FederationManager, peer: &PeerConfig) -> Result<Ar
|
||||
|
||||
// Send hello with our TLS fingerprint
|
||||
let hello = SignalMessage::FederationHello {
|
||||
version: default_signal_version(),
|
||||
tls_fingerprint: fm.local_tls_fp.clone(),
|
||||
};
|
||||
transport.send_signal(&hello).await
|
||||
transport
|
||||
.send_signal(&hello)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("federation hello send failed: {e}"))?;
|
||||
|
||||
info!(peer_url = %peer.url, label = ?peer.label, "federation: connected (hello sent)");
|
||||
@@ -636,31 +689,40 @@ async fn run_federation_link(
|
||||
peer_label: String,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
// Register peer link + metrics
|
||||
fm.metrics.federation_peer_status.with_label_values(&[&peer_label]).set(1);
|
||||
fm.metrics
|
||||
.federation_peer_status
|
||||
.with_label_values(&[&peer_label])
|
||||
.set(1);
|
||||
{
|
||||
let mut links = fm.peer_links.lock().await;
|
||||
links.insert(peer_fp.clone(), PeerLink {
|
||||
transport: transport.clone(),
|
||||
label: peer_label.clone(),
|
||||
active_rooms: HashSet::new(),
|
||||
remote_participants: HashMap::new(),
|
||||
last_seen: Instant::now(),
|
||||
});
|
||||
links.insert(
|
||||
peer_fp.clone(),
|
||||
PeerLink {
|
||||
transport: transport.clone(),
|
||||
label: peer_label.clone(),
|
||||
active_rooms: HashSet::new(),
|
||||
remote_participants: HashMap::new(),
|
||||
last_seen: Instant::now(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
// Announce our currently active global rooms to this new peer
|
||||
// Collect all announcements first, then send (avoid holding locks across await)
|
||||
let announcements = {
|
||||
let mgr = fm.room_mgr.lock().await;
|
||||
let active = mgr.active_rooms();
|
||||
let active = fm.room_mgr.active_rooms();
|
||||
let mut msgs = Vec::new();
|
||||
|
||||
// Local rooms
|
||||
for room_name in &active {
|
||||
if fm.is_global_room(room_name) {
|
||||
let participants = mgr.local_participant_list(room_name);
|
||||
let participants = fm.room_mgr.local_participant_list(room_name);
|
||||
info!(peer = %peer_label, room = %room_name, participants = participants.len(), "announcing local global room to new peer");
|
||||
msgs.push(SignalMessage::GlobalRoomActive { room: room_name.clone(), participants });
|
||||
msgs.push(SignalMessage::GlobalRoomActive {
|
||||
version: default_signal_version(),
|
||||
room: room_name.clone(),
|
||||
participants,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -672,6 +734,7 @@ async fn run_federation_link(
|
||||
if fm.is_global_room(room) {
|
||||
info!(peer = %peer_label, room = %room, via = %link.label, "propagating remote room to new peer");
|
||||
msgs.push(SignalMessage::GlobalRoomActive {
|
||||
version: default_signal_version(),
|
||||
room: room.clone(),
|
||||
participants: participants.clone(),
|
||||
});
|
||||
@@ -756,7 +819,10 @@ async fn run_federation_link(
|
||||
}
|
||||
|
||||
// Cleanup: remove peer link + metrics
|
||||
fm.metrics.federation_peer_status.with_label_values(&[&peer_label]).set(0);
|
||||
fm.metrics
|
||||
.federation_peer_status
|
||||
.with_label_values(&[&peer_label])
|
||||
.set(0);
|
||||
{
|
||||
let mut links = fm.peer_links.lock().await;
|
||||
links.remove(&peer_fp);
|
||||
@@ -782,7 +848,9 @@ async fn handle_signal(
|
||||
}
|
||||
|
||||
match msg {
|
||||
SignalMessage::GlobalRoomActive { room, participants } => {
|
||||
SignalMessage::GlobalRoomActive {
|
||||
room, participants, ..
|
||||
} => {
|
||||
if fm.is_global_room(&room) {
|
||||
info!(peer = %peer_label, room = %room, remote_participants = participants.len(), "peer has global room active");
|
||||
let mut links = fm.peer_links.lock().await;
|
||||
@@ -794,56 +862,74 @@ async fn handle_signal(
|
||||
fm.metrics.federation_active_rooms.set(total as i64);
|
||||
if let Some(link) = links.get_mut(peer_fp) {
|
||||
// Tag remote participants with their relay label
|
||||
let tagged: Vec<_> = participants.iter().map(|p| {
|
||||
let mut tagged = p.clone();
|
||||
if tagged.relay_label.is_none() {
|
||||
tagged.relay_label = Some(link.label.clone());
|
||||
}
|
||||
tagged
|
||||
}).collect();
|
||||
let tagged: Vec<_> = participants
|
||||
.iter()
|
||||
.map(|p| {
|
||||
let mut tagged = p.clone();
|
||||
if tagged.relay_label.is_none() {
|
||||
tagged.relay_label = Some(link.label.clone());
|
||||
}
|
||||
tagged
|
||||
})
|
||||
.collect();
|
||||
link.remote_participants.insert(room.clone(), tagged);
|
||||
}
|
||||
// Propagate to other peers (with relay labels preserved)
|
||||
let tagged_for_propagation = if let Some(link) = links.get(peer_fp) {
|
||||
let label = link.label.clone();
|
||||
participants.iter().map(|p| {
|
||||
let mut t = p.clone();
|
||||
if t.relay_label.is_none() {
|
||||
t.relay_label = Some(label.clone());
|
||||
}
|
||||
t
|
||||
}).collect::<Vec<_>>()
|
||||
participants
|
||||
.iter()
|
||||
.map(|p| {
|
||||
let mut t = p.clone();
|
||||
if t.relay_label.is_none() {
|
||||
t.relay_label = Some(label.clone());
|
||||
}
|
||||
t
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
} else {
|
||||
participants.clone()
|
||||
};
|
||||
for (fp, link) in links.iter() {
|
||||
if fp != peer_fp {
|
||||
let _ = link.transport.send_signal(&SignalMessage::GlobalRoomActive {
|
||||
room: room.clone(),
|
||||
participants: tagged_for_propagation.clone(),
|
||||
}).await;
|
||||
let _ = link
|
||||
.transport
|
||||
.send_signal(&SignalMessage::GlobalRoomActive {
|
||||
version: default_signal_version(),
|
||||
room: room.clone(),
|
||||
participants: tagged_for_propagation.clone(),
|
||||
})
|
||||
.await;
|
||||
}
|
||||
}
|
||||
drop(links);
|
||||
|
||||
// Broadcast updated RoomUpdate to local clients in this room
|
||||
// Find the local room name (may be hashed or raw)
|
||||
let mgr = fm.room_mgr.lock().await;
|
||||
for local_room in mgr.active_rooms() {
|
||||
if fm.is_global_room(&local_room) && fm.resolve_global_room(&local_room) == fm.resolve_global_room(&room) {
|
||||
let active = fm.room_mgr.active_rooms();
|
||||
for local_room in &active {
|
||||
if fm.is_global_room(local_room)
|
||||
&& fm.resolve_global_room(local_room) == fm.resolve_global_room(&room)
|
||||
{
|
||||
// Build merged participant list: local + all remote (deduped)
|
||||
let mut all_participants = mgr.local_participant_list(&local_room);
|
||||
let links = fm.peer_links.lock().await;
|
||||
for link in links.values() {
|
||||
if let Some(ref canonical) = fm.resolve_global_room(&local_room) {
|
||||
if let Some(remote) = link.remote_participants.get(canonical.as_str()) {
|
||||
all_participants.extend(remote.iter().cloned());
|
||||
}
|
||||
// Also check raw room name, but only if different from canonical
|
||||
if canonical != &local_room {
|
||||
if let Some(remote) = link.remote_participants.get(&local_room) {
|
||||
let mut all_participants = fm.room_mgr.local_participant_list(local_room);
|
||||
{
|
||||
let links = fm.peer_links.lock().await;
|
||||
for link in links.values() {
|
||||
if let Some(ref canonical) = fm.resolve_global_room(local_room) {
|
||||
if let Some(remote) =
|
||||
link.remote_participants.get(canonical.as_str())
|
||||
{
|
||||
all_participants.extend(remote.iter().cloned());
|
||||
}
|
||||
// Also check raw room name, but only if different from canonical
|
||||
if canonical != local_room {
|
||||
if let Some(remote) =
|
||||
link.remote_participants.get(local_room)
|
||||
{
|
||||
all_participants.extend(remote.iter().cloned());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -851,19 +937,18 @@ async fn handle_signal(
|
||||
let mut seen = HashSet::new();
|
||||
all_participants.retain(|p| seen.insert(p.fingerprint.clone()));
|
||||
let update = SignalMessage::RoomUpdate {
|
||||
version: default_signal_version(),
|
||||
count: all_participants.len() as u32,
|
||||
participants: all_participants,
|
||||
};
|
||||
let senders = mgr.local_senders(&local_room);
|
||||
drop(links);
|
||||
drop(mgr);
|
||||
let senders = fm.room_mgr.local_senders(local_room);
|
||||
room::broadcast_signal(&senders, &update).await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
SignalMessage::GlobalRoomInactive { room } => {
|
||||
SignalMessage::GlobalRoomInactive { room, .. } => {
|
||||
info!(peer = %peer_label, room = %room, "peer global room now inactive");
|
||||
let mut links = fm.peer_links.lock().await;
|
||||
if let Some(link) = links.get_mut(peer_fp) {
|
||||
@@ -885,7 +970,9 @@ async fn handle_signal(
|
||||
let canonical = fm.resolve_global_room(&room);
|
||||
let mut result = Vec::new();
|
||||
for (fp, link) in links.iter() {
|
||||
if fp == peer_fp { continue; }
|
||||
if fp == peer_fp {
|
||||
continue;
|
||||
}
|
||||
if let Some(ref c) = canonical {
|
||||
if let Some(remote) = link.remote_participants.get(c.as_str()) {
|
||||
result.extend(remote.iter().cloned());
|
||||
@@ -899,14 +986,16 @@ async fn handle_signal(
|
||||
|
||||
// Propagate to other peers: send updated GlobalRoomActive with revised list,
|
||||
// or GlobalRoomInactive if no participants remain anywhere
|
||||
let local_active = {
|
||||
let mgr = fm.room_mgr.lock().await;
|
||||
mgr.active_rooms().iter().any(|r| fm.resolve_global_room(r) == fm.resolve_global_room(&room))
|
||||
};
|
||||
let local_active = fm
|
||||
.room_mgr
|
||||
.active_rooms()
|
||||
.iter()
|
||||
.any(|r| fm.resolve_global_room(r) == fm.resolve_global_room(&room));
|
||||
let has_remaining = !remaining_remote.is_empty() || local_active;
|
||||
|
||||
// Collect peer transports to send to (avoid holding lock across await)
|
||||
let peer_sends: Vec<_> = links.iter()
|
||||
let peer_sends: Vec<_> = links
|
||||
.iter()
|
||||
.filter(|(fp, _)| *fp != peer_fp)
|
||||
.map(|(_, link)| link.transport.clone())
|
||||
.collect();
|
||||
@@ -916,15 +1005,16 @@ async fn handle_signal(
|
||||
// Send updated participant list to other peers
|
||||
let mut updated_participants = remaining_remote.clone();
|
||||
if local_active {
|
||||
let mgr = fm.room_mgr.lock().await;
|
||||
for local_room in mgr.active_rooms() {
|
||||
for local_room in fm.room_mgr.active_rooms() {
|
||||
if fm.resolve_global_room(&local_room) == fm.resolve_global_room(&room) {
|
||||
updated_participants.extend(mgr.local_participant_list(&local_room));
|
||||
updated_participants
|
||||
.extend(fm.room_mgr.local_participant_list(&local_room));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
let msg = SignalMessage::GlobalRoomActive {
|
||||
version: default_signal_version(),
|
||||
room: room.clone(),
|
||||
participants: updated_participants,
|
||||
};
|
||||
@@ -933,27 +1023,32 @@ async fn handle_signal(
|
||||
}
|
||||
} else {
|
||||
// No participants left anywhere — propagate inactive
|
||||
let msg = SignalMessage::GlobalRoomInactive { room: room.clone() };
|
||||
let msg = SignalMessage::GlobalRoomInactive {
|
||||
version: default_signal_version(),
|
||||
room: room.clone(),
|
||||
};
|
||||
for transport in &peer_sends {
|
||||
let _ = transport.send_signal(&msg).await;
|
||||
}
|
||||
}
|
||||
|
||||
// Broadcast updated RoomUpdate to local clients (remote participant removed)
|
||||
let mgr = fm.room_mgr.lock().await;
|
||||
for local_room in mgr.active_rooms() {
|
||||
if fm.is_global_room(&local_room) && fm.resolve_global_room(&local_room) == fm.resolve_global_room(&room) {
|
||||
let mut all_participants = mgr.local_participant_list(&local_room);
|
||||
let active = fm.room_mgr.active_rooms();
|
||||
for local_room in &active {
|
||||
if fm.is_global_room(local_room)
|
||||
&& fm.resolve_global_room(local_room) == fm.resolve_global_room(&room)
|
||||
{
|
||||
let mut all_participants = fm.room_mgr.local_participant_list(local_room);
|
||||
all_participants.extend(remaining_remote.iter().cloned());
|
||||
// Deduplicate by fingerprint
|
||||
let mut seen = HashSet::new();
|
||||
all_participants.retain(|p| seen.insert(p.fingerprint.clone()));
|
||||
let update = SignalMessage::RoomUpdate {
|
||||
version: default_signal_version(),
|
||||
count: all_participants.len() as u32,
|
||||
participants: all_participants,
|
||||
};
|
||||
let senders = mgr.local_senders(&local_room);
|
||||
drop(mgr);
|
||||
let senders = fm.room_mgr.local_senders(local_room);
|
||||
room::broadcast_signal(&senders, &update).await;
|
||||
info!(room = %room, "broadcast updated presence (remote participant removed)");
|
||||
break;
|
||||
@@ -972,7 +1067,11 @@ async fn handle_signal(
|
||||
// Loop prevention: drop any forward whose origin matches
|
||||
// our own federation TLS fingerprint. With
|
||||
// broadcast-to-all-peers this prevents A→B→A echo loops.
|
||||
SignalMessage::FederatedSignalForward { inner, origin_relay_fp } => {
|
||||
SignalMessage::FederatedSignalForward {
|
||||
inner,
|
||||
origin_relay_fp,
|
||||
..
|
||||
} => {
|
||||
if origin_relay_fp == fm.local_tls_fp {
|
||||
tracing::debug!(
|
||||
peer = %peer_label,
|
||||
@@ -1016,12 +1115,10 @@ async fn handle_signal(
|
||||
}
|
||||
|
||||
/// Handle an incoming federation datagram (room-hash-tagged media).
|
||||
async fn handle_datagram(
|
||||
fm: &Arc<FederationManager>,
|
||||
source_peer_fp: &str,
|
||||
data: Bytes,
|
||||
) {
|
||||
if data.len() < 12 { return; } // 8-byte hash + min packet
|
||||
async fn handle_datagram(fm: &Arc<FederationManager>, source_peer_fp: &str, data: Bytes) {
|
||||
if data.len() < 12 {
|
||||
return;
|
||||
} // 8-byte hash + min packet
|
||||
|
||||
let mut rh = [0u8; 8];
|
||||
rh.copy_from_slice(&data[..8]);
|
||||
@@ -1030,7 +1127,8 @@ async fn handle_datagram(
|
||||
let pkt = match wzp_proto::MediaPacket::from_bytes(media_bytes.clone()) {
|
||||
Some(pkt) => pkt,
|
||||
None => {
|
||||
fm.event_log.emit(Event::new("federation_ingress_malformed").len(data.len()));
|
||||
fm.event_log
|
||||
.emit(Event::new("federation_ingress_malformed").len(data.len()));
|
||||
return;
|
||||
}
|
||||
};
|
||||
@@ -1038,13 +1136,22 @@ async fn handle_datagram(
|
||||
// Event log: federation ingress
|
||||
let peer_label = {
|
||||
let links = fm.peer_links.lock().await;
|
||||
links.get(source_peer_fp).map(|l| l.label.clone()).unwrap_or_default()
|
||||
links
|
||||
.get(source_peer_fp)
|
||||
.map(|l| l.label.clone())
|
||||
.unwrap_or_default()
|
||||
};
|
||||
fm.event_log.emit(Event::new("federation_ingress").packet(&pkt).peer(&peer_label));
|
||||
fm.event_log.emit(
|
||||
Event::new("federation_ingress")
|
||||
.packet(&pkt)
|
||||
.peer(&peer_label),
|
||||
);
|
||||
|
||||
// Count inbound federation packet + update last_seen
|
||||
fm.metrics.federation_packets_forwarded
|
||||
.with_label_values(&[source_peer_fp, "in"]).inc();
|
||||
fm.metrics
|
||||
.federation_packets_forwarded
|
||||
.with_label_values(&[source_peer_fp, "in"])
|
||||
.inc();
|
||||
{
|
||||
let mut links = fm.peer_links.lock().await;
|
||||
if let Some(link) = links.get_mut(source_peer_fp) {
|
||||
@@ -1065,38 +1172,53 @@ async fn handle_datagram(
|
||||
{
|
||||
let mut dedup = fm.dedup.lock().await;
|
||||
if dedup.is_dup(&rh, pkt.header.seq, payload_hash) {
|
||||
fm.event_log.emit(Event::new("dedup_drop").seq(pkt.header.seq).peer(&peer_label));
|
||||
fm.event_log.emit(
|
||||
Event::new("dedup_drop")
|
||||
.seq(pkt.header.seq)
|
||||
.peer(&peer_label),
|
||||
);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Find room by hash — check local rooms AND global room config
|
||||
// Find room by hash -- check local rooms AND global room config
|
||||
let room_name = {
|
||||
let mgr = fm.room_mgr.lock().await;
|
||||
let active = mgr.active_rooms();
|
||||
let active = fm.room_mgr.active_rooms();
|
||||
// First: check local rooms (has participants)
|
||||
active.iter().find(|r| room_hash(r) == rh).cloned()
|
||||
.or_else(|| active.iter().find(|r| fm.global_room_hash(r) == rh).cloned())
|
||||
active
|
||||
.iter()
|
||||
.find(|r| room_hash(r) == rh)
|
||||
.cloned()
|
||||
.or_else(|| {
|
||||
active
|
||||
.iter()
|
||||
.find(|r| fm.global_room_hash(r) == rh)
|
||||
.cloned()
|
||||
})
|
||||
// Second: check static global room config (hub relay may have no local participants)
|
||||
.or_else(|| {
|
||||
fm.global_rooms.iter().find(|name| room_hash(name) == rh).cloned()
|
||||
fm.global_rooms
|
||||
.iter()
|
||||
.find(|name| room_hash(name) == rh)
|
||||
.cloned()
|
||||
})
|
||||
};
|
||||
|
||||
let room_name = match room_name {
|
||||
Some(r) => r,
|
||||
None => {
|
||||
fm.event_log.emit(Event::new("room_not_found").seq(pkt.header.seq).peer(&peer_label));
|
||||
fm.event_log.emit(
|
||||
Event::new("room_not_found")
|
||||
.seq(pkt.header.seq)
|
||||
.peer(&peer_label),
|
||||
);
|
||||
// Phase 4.1 diagnostic: log the hash + active rooms
|
||||
// so we can diagnose cross-relay call-* media routing
|
||||
// failures. This fires when a peer relay sends media
|
||||
// for a room we don't have locally — could be a
|
||||
// timing issue (peer joined before us) or a hash
|
||||
// mismatch.
|
||||
let active = {
|
||||
let mgr = fm.room_mgr.lock().await;
|
||||
mgr.active_rooms()
|
||||
};
|
||||
let active = fm.room_mgr.active_rooms();
|
||||
warn!(
|
||||
room_hash = ?rh,
|
||||
active_rooms = ?active,
|
||||
@@ -1111,32 +1233,46 @@ async fn handle_datagram(
|
||||
// Rate limit per room
|
||||
if FEDERATION_RATE_LIMIT_PPS > 0 {
|
||||
let mut limiters = fm.rate_limiters.lock().await;
|
||||
let limiter = limiters.entry(room_name.clone())
|
||||
let limiter = limiters
|
||||
.entry(room_name.clone())
|
||||
.or_insert_with(|| RateLimiter::new(FEDERATION_RATE_LIMIT_PPS));
|
||||
if !limiter.allow() {
|
||||
fm.event_log.emit(Event::new("rate_limit_drop").room(&room_name).seq(pkt.header.seq));
|
||||
fm.event_log.emit(
|
||||
Event::new("rate_limit_drop")
|
||||
.room(&room_name)
|
||||
.seq(pkt.header.seq),
|
||||
);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Deliver to all local participants — forward the raw bytes as-is.
|
||||
// The original sender's MediaPacket is preserved exactly (no re-serialization).
|
||||
let locals = {
|
||||
let mgr = fm.room_mgr.lock().await;
|
||||
mgr.local_senders(&room_name)
|
||||
};
|
||||
let locals = fm.room_mgr.local_senders(&room_name);
|
||||
for sender in &locals {
|
||||
match sender {
|
||||
room::ParticipantSender::Quic(t) => {
|
||||
if let Err(e) = t.send_raw_datagram(&media_bytes) {
|
||||
fm.event_log.emit(Event::new("local_deliver_error").room(&room_name).seq(pkt.header.seq).reason(&e.to_string()));
|
||||
fm.event_log.emit(
|
||||
Event::new("local_deliver_error")
|
||||
.room(&room_name)
|
||||
.seq(pkt.header.seq)
|
||||
.reason(&e.to_string()),
|
||||
);
|
||||
warn!("federation local delivery error: {e}");
|
||||
}
|
||||
}
|
||||
room::ParticipantSender::WebSocket(_) => { let _ = sender.send_raw(&pkt.payload).await; }
|
||||
room::ParticipantSender::WebSocket(_) => {
|
||||
let _ = sender.send_raw(&pkt.payload).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
fm.event_log.emit(Event::new("local_deliver").room(&room_name).seq(pkt.header.seq).to_count(locals.len()));
|
||||
fm.event_log.emit(
|
||||
Event::new("local_deliver")
|
||||
.room(&room_name)
|
||||
.seq(pkt.header.seq)
|
||||
.to_count(locals.len()),
|
||||
);
|
||||
|
||||
// Multi-hop: forward to ALL other connected peers (not the source)
|
||||
// Don't filter by active_rooms — the receiving peer decides whether to deliver
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
//! recv `CallOffer` → verify → generate ephemeral → derive session → send `CallAnswer`.
|
||||
|
||||
use wzp_crypto::{CryptoSession, KeyExchange, WarzoneKeyExchange};
|
||||
use wzp_proto::{MediaTransport, QualityProfile, SignalMessage};
|
||||
use wzp_proto::{MediaTransport, QualityProfile, SignalMessage, default_signal_version};
|
||||
|
||||
/// Accept the relay (callee) side of the cryptographic handshake.
|
||||
///
|
||||
@@ -20,29 +20,71 @@ use wzp_proto::{MediaTransport, QualityProfile, SignalMessage};
|
||||
pub async fn accept_handshake(
|
||||
transport: &dyn MediaTransport,
|
||||
seed: &[u8; 32],
|
||||
) -> Result<(Box<dyn CryptoSession>, QualityProfile, String, Option<String>), anyhow::Error> {
|
||||
) -> Result<
|
||||
(
|
||||
Box<dyn CryptoSession>,
|
||||
QualityProfile,
|
||||
String,
|
||||
Option<String>,
|
||||
),
|
||||
anyhow::Error,
|
||||
> {
|
||||
// 1. Receive CallOffer
|
||||
let offer = transport
|
||||
.recv_signal()
|
||||
.await?
|
||||
.ok_or_else(|| anyhow::anyhow!("connection closed before receiving CallOffer"))?;
|
||||
|
||||
let (caller_identity_pub, caller_ephemeral_pub, caller_signature, supported_profiles, caller_alias) =
|
||||
match offer {
|
||||
SignalMessage::CallOffer {
|
||||
identity_pub,
|
||||
ephemeral_pub,
|
||||
signature,
|
||||
supported_profiles,
|
||||
alias,
|
||||
} => (identity_pub, ephemeral_pub, signature, supported_profiles, alias),
|
||||
other => {
|
||||
return Err(anyhow::anyhow!(
|
||||
"expected CallOffer, got {:?}",
|
||||
std::mem::discriminant(&other)
|
||||
))
|
||||
}
|
||||
let (
|
||||
caller_identity_pub,
|
||||
caller_ephemeral_pub,
|
||||
caller_signature,
|
||||
supported_profiles,
|
||||
caller_alias,
|
||||
protocol_version,
|
||||
caller_video_codecs,
|
||||
) = match offer {
|
||||
SignalMessage::CallOffer {
|
||||
identity_pub,
|
||||
ephemeral_pub,
|
||||
signature,
|
||||
supported_profiles,
|
||||
alias,
|
||||
protocol_version,
|
||||
supported_versions: _,
|
||||
video_codecs,
|
||||
..
|
||||
} => (
|
||||
identity_pub,
|
||||
ephemeral_pub,
|
||||
signature,
|
||||
supported_profiles,
|
||||
alias,
|
||||
protocol_version,
|
||||
video_codecs,
|
||||
),
|
||||
other => {
|
||||
return Err(anyhow::anyhow!(
|
||||
"expected CallOffer, got {:?}",
|
||||
std::mem::discriminant(&other)
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
// 1a. Protocol version check — we only speak v2.
|
||||
if protocol_version != 2 {
|
||||
let mismatch = SignalMessage::Hangup {
|
||||
version: default_signal_version(),
|
||||
reason: wzp_proto::HangupReason::ProtocolVersionMismatch {
|
||||
server_supported: vec![2],
|
||||
},
|
||||
call_id: None,
|
||||
};
|
||||
let _ = transport.send_signal(&mismatch).await;
|
||||
return Err(anyhow::anyhow!(
|
||||
"protocol version mismatch: client requested {protocol_version}, server supports [2]"
|
||||
));
|
||||
}
|
||||
|
||||
// 2. Verify caller's signature over (ephemeral_pub || "call-offer")
|
||||
let mut verify_data = Vec::with_capacity(32 + 10);
|
||||
@@ -69,23 +111,28 @@ pub async fn accept_handshake(
|
||||
// Choose the best supported profile (prefer GOOD > DEGRADED > CATASTROPHIC)
|
||||
let chosen_profile = choose_profile(&supported_profiles);
|
||||
|
||||
// Pick the first video codec the caller supports (relay forwards all video).
|
||||
let video_codec = caller_video_codecs.into_iter().next();
|
||||
|
||||
// 6. Send CallAnswer
|
||||
let answer = SignalMessage::CallAnswer {
|
||||
version: default_signal_version(),
|
||||
identity_pub,
|
||||
ephemeral_pub,
|
||||
signature,
|
||||
chosen_profile,
|
||||
video_codec,
|
||||
};
|
||||
transport.send_signal(&answer).await?;
|
||||
|
||||
// Derive caller fingerprint: SHA-256(Ed25519 pub)[:16], formatted as xxxx:xxxx:...
|
||||
// Must match the format used in signal registration and presence.
|
||||
let caller_fp = {
|
||||
use sha2::{Sha256, Digest};
|
||||
use sha2::{Digest, Sha256};
|
||||
let hash = Sha256::digest(&caller_identity_pub);
|
||||
let fp = wzp_crypto::Fingerprint([
|
||||
hash[0], hash[1], hash[2], hash[3], hash[4], hash[5], hash[6], hash[7],
|
||||
hash[8], hash[9], hash[10], hash[11], hash[12], hash[13], hash[14], hash[15],
|
||||
hash[0], hash[1], hash[2], hash[3], hash[4], hash[5], hash[6], hash[7], hash[8],
|
||||
hash[9], hash[10], hash[11], hash[12], hash[13], hash[14], hash[15],
|
||||
]);
|
||||
fp.to_string()
|
||||
};
|
||||
@@ -107,6 +154,7 @@ fn choose_profile(_supported: &[QualityProfile]) -> QualityProfile {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use wzp_proto::CodecId;
|
||||
|
||||
#[test]
|
||||
fn choose_profile_picks_highest_bitrate() {
|
||||
@@ -124,4 +172,35 @@ mod tests {
|
||||
let chosen = choose_profile(&[]);
|
||||
assert_eq!(chosen, QualityProfile::GOOD);
|
||||
}
|
||||
|
||||
// ── Video codec negotiation ───────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn video_codec_picks_first_offered() {
|
||||
let codecs = vec![CodecId::H264Baseline];
|
||||
let chosen: Option<CodecId> = codecs.into_iter().next();
|
||||
assert_eq!(chosen, Some(CodecId::H264Baseline));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn video_codec_none_when_no_codecs_offered() {
|
||||
let codecs: Vec<CodecId> = vec![];
|
||||
let chosen: Option<CodecId> = codecs.into_iter().next();
|
||||
assert_eq!(chosen, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn video_codec_single_codec_is_selected() {
|
||||
let codecs = vec![CodecId::H265Main];
|
||||
let chosen: Option<CodecId> = codecs.into_iter().next();
|
||||
assert_eq!(chosen, Some(CodecId::H265Main));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn video_codec_order_is_preserved() {
|
||||
// The relay must pick the FIRST codec as-offered, not sort or re-rank.
|
||||
let codecs = vec![CodecId::H264Baseline, CodecId::Av1Main];
|
||||
let chosen: Option<CodecId> = codecs.into_iter().next();
|
||||
assert_eq!(chosen, Some(CodecId::H264Baseline));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,22 +7,27 @@
|
||||
//! It operates on FEC-protected packets, managing loss recovery and adaptive
|
||||
//! quality transitions.
|
||||
|
||||
pub mod audio_scorer;
|
||||
pub mod auth;
|
||||
pub mod call_registry;
|
||||
pub mod config;
|
||||
pub mod conformance;
|
||||
pub mod event_log;
|
||||
pub mod federation;
|
||||
pub mod signal_hub;
|
||||
pub mod handshake;
|
||||
pub mod metrics;
|
||||
pub mod pipeline;
|
||||
pub mod presence;
|
||||
pub mod probe;
|
||||
pub mod relay_link;
|
||||
pub mod response_policy;
|
||||
pub mod room;
|
||||
pub mod route;
|
||||
pub mod session_mgr;
|
||||
pub mod signal_hub;
|
||||
pub mod trunk;
|
||||
pub mod verdict;
|
||||
pub mod video_scorer;
|
||||
pub mod ws;
|
||||
|
||||
pub use config::RelayConfig;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,11 +1,14 @@
|
||||
//! Prometheus metrics for the WZP relay daemon.
|
||||
|
||||
use prometheus::{
|
||||
Encoder, GaugeVec, Histogram, HistogramOpts, IntCounter, IntCounterVec, IntGauge, IntGaugeVec,
|
||||
Opts, Registry, TextEncoder,
|
||||
Encoder, GaugeVec, Histogram, HistogramOpts, HistogramVec, IntCounter, IntCounterVec, IntGauge,
|
||||
IntGaugeVec, Opts, Registry, TextEncoder,
|
||||
};
|
||||
use wzp_proto::packet::QualityReport;
|
||||
use std::sync::Arc;
|
||||
use wzp_proto::MediaHeader;
|
||||
use wzp_proto::packet::QualityReport;
|
||||
|
||||
use crate::conformance::Violation;
|
||||
|
||||
/// All relay-level Prometheus metrics.
|
||||
#[derive(Clone)]
|
||||
@@ -32,6 +35,9 @@ pub struct RelayMetrics {
|
||||
// Phase 4: loss-recovery breakdown per session.
|
||||
pub session_dred_reconstructions: IntCounterVec,
|
||||
pub session_classical_plc: IntCounterVec,
|
||||
pub conformance_violations: IntCounterVec,
|
||||
pub conformance_bytes: HistogramVec,
|
||||
pub conformance_iat_ms: HistogramVec,
|
||||
registry: Registry,
|
||||
}
|
||||
|
||||
@@ -40,21 +46,23 @@ impl RelayMetrics {
|
||||
pub fn new() -> Self {
|
||||
let registry = Registry::new();
|
||||
|
||||
let active_sessions = IntGauge::with_opts(
|
||||
Opts::new("wzp_relay_active_sessions", "Current active sessions"),
|
||||
)
|
||||
let active_sessions = IntGauge::with_opts(Opts::new(
|
||||
"wzp_relay_active_sessions",
|
||||
"Current active sessions",
|
||||
))
|
||||
.expect("metric");
|
||||
let active_rooms = IntGauge::with_opts(
|
||||
Opts::new("wzp_relay_active_rooms", "Current active rooms"),
|
||||
)
|
||||
let active_rooms =
|
||||
IntGauge::with_opts(Opts::new("wzp_relay_active_rooms", "Current active rooms"))
|
||||
.expect("metric");
|
||||
let packets_forwarded = IntCounter::with_opts(Opts::new(
|
||||
"wzp_relay_packets_forwarded_total",
|
||||
"Total packets forwarded",
|
||||
))
|
||||
.expect("metric");
|
||||
let packets_forwarded = IntCounter::with_opts(
|
||||
Opts::new("wzp_relay_packets_forwarded_total", "Total packets forwarded"),
|
||||
)
|
||||
.expect("metric");
|
||||
let bytes_forwarded = IntCounter::with_opts(
|
||||
Opts::new("wzp_relay_bytes_forwarded_total", "Total bytes forwarded"),
|
||||
)
|
||||
let bytes_forwarded = IntCounter::with_opts(Opts::new(
|
||||
"wzp_relay_bytes_forwarded_total",
|
||||
"Total bytes forwarded",
|
||||
))
|
||||
.expect("metric");
|
||||
let auth_attempts = IntCounterVec::new(
|
||||
Opts::new("wzp_relay_auth_attempts_total", "Auth validation attempts"),
|
||||
@@ -66,31 +74,51 @@ impl RelayMetrics {
|
||||
"wzp_relay_handshake_duration_seconds",
|
||||
"Crypto handshake time",
|
||||
)
|
||||
.buckets(vec![0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5]),
|
||||
.buckets(vec![
|
||||
0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5,
|
||||
]),
|
||||
)
|
||||
.expect("metric");
|
||||
|
||||
let federation_peer_status = IntGaugeVec::new(
|
||||
Opts::new("wzp_federation_peer_status", "Peer connection status (0=disconnected, 1=connected)"),
|
||||
Opts::new(
|
||||
"wzp_federation_peer_status",
|
||||
"Peer connection status (0=disconnected, 1=connected)",
|
||||
),
|
||||
&["peer"],
|
||||
).expect("metric");
|
||||
)
|
||||
.expect("metric");
|
||||
let federation_peer_rtt_ms = GaugeVec::new(
|
||||
Opts::new("wzp_federation_peer_rtt_ms", "QUIC RTT to federated peer in milliseconds"),
|
||||
Opts::new(
|
||||
"wzp_federation_peer_rtt_ms",
|
||||
"QUIC RTT to federated peer in milliseconds",
|
||||
),
|
||||
&["peer"],
|
||||
).expect("metric");
|
||||
)
|
||||
.expect("metric");
|
||||
let federation_packets_forwarded = IntCounterVec::new(
|
||||
Opts::new("wzp_federation_packets_forwarded_total", "Packets forwarded to/from federated peers"),
|
||||
Opts::new(
|
||||
"wzp_federation_packets_forwarded_total",
|
||||
"Packets forwarded to/from federated peers",
|
||||
),
|
||||
&["peer", "direction"],
|
||||
).expect("metric");
|
||||
let federation_packets_deduped = IntCounter::with_opts(
|
||||
Opts::new("wzp_federation_packets_deduped_total", "Duplicate federation packets dropped"),
|
||||
).expect("metric");
|
||||
let federation_packets_rate_limited = IntCounter::with_opts(
|
||||
Opts::new("wzp_federation_packets_rate_limited_total", "Federation packets dropped by rate limiter"),
|
||||
).expect("metric");
|
||||
let federation_active_rooms = IntGauge::with_opts(
|
||||
Opts::new("wzp_federation_active_rooms", "Number of federated rooms currently active"),
|
||||
).expect("metric");
|
||||
)
|
||||
.expect("metric");
|
||||
let federation_packets_deduped = IntCounter::with_opts(Opts::new(
|
||||
"wzp_federation_packets_deduped_total",
|
||||
"Duplicate federation packets dropped",
|
||||
))
|
||||
.expect("metric");
|
||||
let federation_packets_rate_limited = IntCounter::with_opts(Opts::new(
|
||||
"wzp_federation_packets_rate_limited_total",
|
||||
"Federation packets dropped by rate limiter",
|
||||
))
|
||||
.expect("metric");
|
||||
let federation_active_rooms = IntGauge::with_opts(Opts::new(
|
||||
"wzp_federation_active_rooms",
|
||||
"Number of federated rooms currently active",
|
||||
))
|
||||
.expect("metric");
|
||||
|
||||
let session_buffer_depth = IntGaugeVec::new(
|
||||
Opts::new(
|
||||
@@ -109,10 +137,7 @@ impl RelayMetrics {
|
||||
)
|
||||
.expect("metric");
|
||||
let session_rtt_ms = GaugeVec::new(
|
||||
Opts::new(
|
||||
"wzp_relay_session_rtt_ms",
|
||||
"Round-trip time per session",
|
||||
),
|
||||
Opts::new("wzp_relay_session_rtt_ms", "Round-trip time per session"),
|
||||
&["session_id"],
|
||||
)
|
||||
.expect("metric");
|
||||
@@ -149,26 +174,104 @@ impl RelayMetrics {
|
||||
&["session_id"],
|
||||
)
|
||||
.expect("metric");
|
||||
let conformance_violations = IntCounterVec::new(
|
||||
Opts::new(
|
||||
"wzp_relay_conformance_violations_total",
|
||||
"Conformance violations by tier, codec, media type and verdict",
|
||||
),
|
||||
&["tier", "codec_id", "media_type", "verdict"],
|
||||
)
|
||||
.expect("metric");
|
||||
let conformance_bytes = HistogramVec::new(
|
||||
HistogramOpts::new(
|
||||
"wzp_relay_conformance_bytes_per_session",
|
||||
"Packet size distribution observed by the conformance meter",
|
||||
)
|
||||
.buckets(vec![
|
||||
16.0, 32.0, 64.0, 128.0, 256.0, 512.0, 1024.0, 2048.0, 4096.0, 8192.0, 16384.0,
|
||||
32768.0, 65536.0,
|
||||
]),
|
||||
&["media_type"],
|
||||
)
|
||||
.expect("metric");
|
||||
let conformance_iat_ms = HistogramVec::new(
|
||||
HistogramOpts::new(
|
||||
"wzp_relay_conformance_iat_ms",
|
||||
"Inter-arrival time distribution in milliseconds",
|
||||
)
|
||||
.buckets(vec![
|
||||
1.0, 5.0, 10.0, 20.0, 30.0, 40.0, 60.0, 80.0, 100.0, 150.0, 200.0, 300.0, 500.0,
|
||||
]),
|
||||
&["media_type"],
|
||||
)
|
||||
.expect("metric");
|
||||
|
||||
registry.register(Box::new(active_sessions.clone())).expect("register");
|
||||
registry.register(Box::new(active_rooms.clone())).expect("register");
|
||||
registry.register(Box::new(packets_forwarded.clone())).expect("register");
|
||||
registry.register(Box::new(bytes_forwarded.clone())).expect("register");
|
||||
registry.register(Box::new(auth_attempts.clone())).expect("register");
|
||||
registry.register(Box::new(handshake_duration.clone())).expect("register");
|
||||
registry.register(Box::new(federation_peer_status.clone())).expect("register");
|
||||
registry.register(Box::new(federation_peer_rtt_ms.clone())).expect("register");
|
||||
registry.register(Box::new(federation_packets_forwarded.clone())).expect("register");
|
||||
registry.register(Box::new(federation_packets_deduped.clone())).expect("register");
|
||||
registry.register(Box::new(federation_packets_rate_limited.clone())).expect("register");
|
||||
registry.register(Box::new(federation_active_rooms.clone())).expect("register");
|
||||
registry.register(Box::new(session_buffer_depth.clone())).expect("register");
|
||||
registry.register(Box::new(session_loss_pct.clone())).expect("register");
|
||||
registry.register(Box::new(session_rtt_ms.clone())).expect("register");
|
||||
registry.register(Box::new(session_underruns.clone())).expect("register");
|
||||
registry.register(Box::new(session_overruns.clone())).expect("register");
|
||||
registry.register(Box::new(session_dred_reconstructions.clone())).expect("register");
|
||||
registry.register(Box::new(session_classical_plc.clone())).expect("register");
|
||||
registry
|
||||
.register(Box::new(active_sessions.clone()))
|
||||
.expect("register");
|
||||
registry
|
||||
.register(Box::new(active_rooms.clone()))
|
||||
.expect("register");
|
||||
registry
|
||||
.register(Box::new(packets_forwarded.clone()))
|
||||
.expect("register");
|
||||
registry
|
||||
.register(Box::new(bytes_forwarded.clone()))
|
||||
.expect("register");
|
||||
registry
|
||||
.register(Box::new(auth_attempts.clone()))
|
||||
.expect("register");
|
||||
registry
|
||||
.register(Box::new(handshake_duration.clone()))
|
||||
.expect("register");
|
||||
registry
|
||||
.register(Box::new(federation_peer_status.clone()))
|
||||
.expect("register");
|
||||
registry
|
||||
.register(Box::new(federation_peer_rtt_ms.clone()))
|
||||
.expect("register");
|
||||
registry
|
||||
.register(Box::new(federation_packets_forwarded.clone()))
|
||||
.expect("register");
|
||||
registry
|
||||
.register(Box::new(federation_packets_deduped.clone()))
|
||||
.expect("register");
|
||||
registry
|
||||
.register(Box::new(federation_packets_rate_limited.clone()))
|
||||
.expect("register");
|
||||
registry
|
||||
.register(Box::new(federation_active_rooms.clone()))
|
||||
.expect("register");
|
||||
registry
|
||||
.register(Box::new(session_buffer_depth.clone()))
|
||||
.expect("register");
|
||||
registry
|
||||
.register(Box::new(session_loss_pct.clone()))
|
||||
.expect("register");
|
||||
registry
|
||||
.register(Box::new(session_rtt_ms.clone()))
|
||||
.expect("register");
|
||||
registry
|
||||
.register(Box::new(session_underruns.clone()))
|
||||
.expect("register");
|
||||
registry
|
||||
.register(Box::new(session_overruns.clone()))
|
||||
.expect("register");
|
||||
registry
|
||||
.register(Box::new(session_dred_reconstructions.clone()))
|
||||
.expect("register");
|
||||
registry
|
||||
.register(Box::new(session_classical_plc.clone()))
|
||||
.expect("register");
|
||||
registry
|
||||
.register(Box::new(conformance_violations.clone()))
|
||||
.expect("register");
|
||||
registry
|
||||
.register(Box::new(conformance_bytes.clone()))
|
||||
.expect("register");
|
||||
registry
|
||||
.register(Box::new(conformance_iat_ms.clone()))
|
||||
.expect("register");
|
||||
|
||||
Self {
|
||||
active_sessions,
|
||||
@@ -190,6 +293,9 @@ impl RelayMetrics {
|
||||
session_overruns,
|
||||
session_dred_reconstructions,
|
||||
session_classical_plc,
|
||||
conformance_violations,
|
||||
conformance_bytes,
|
||||
conformance_iat_ms,
|
||||
registry,
|
||||
}
|
||||
}
|
||||
@@ -230,10 +336,7 @@ impl RelayMetrics {
|
||||
.with_label_values(&[session_id])
|
||||
.inc_by(underruns - cur_underruns as u64);
|
||||
}
|
||||
let cur_overruns = self
|
||||
.session_overruns
|
||||
.with_label_values(&[session_id])
|
||||
.get();
|
||||
let cur_overruns = self.session_overruns.with_label_values(&[session_id]).get();
|
||||
if overruns > cur_overruns as u64 {
|
||||
self.session_overruns
|
||||
.with_label_values(&[session_id])
|
||||
@@ -274,6 +377,45 @@ impl RelayMetrics {
|
||||
}
|
||||
}
|
||||
|
||||
/// Record conformance-related metrics for a single received packet.
|
||||
///
|
||||
/// * `header` — the media header (provides codec_id and media_type).
|
||||
/// * `payload_len` — payload length in bytes.
|
||||
/// * `iat_ms` — inter-arrival time since the previous packet.
|
||||
/// * `violation` — `Some(Violation)` if the packet triggered a conformance
|
||||
/// limit; `None` for clean packets.
|
||||
pub fn record_conformance(
|
||||
&self,
|
||||
header: &MediaHeader,
|
||||
payload_len: usize,
|
||||
iat_ms: u64,
|
||||
violation: Option<Violation>,
|
||||
) {
|
||||
let media_type = format!("{:?}", header.media_type);
|
||||
let bytes = (MediaHeader::WIRE_SIZE + payload_len) as f64;
|
||||
self.conformance_bytes
|
||||
.with_label_values(&[&media_type])
|
||||
.observe(bytes);
|
||||
self.conformance_iat_ms
|
||||
.with_label_values(&[&media_type])
|
||||
.observe(iat_ms as f64);
|
||||
|
||||
if let Some(v) = violation {
|
||||
let tier = match v {
|
||||
Violation::BitrateExceeded => "A",
|
||||
Violation::PacketRateExceeded => "B",
|
||||
Violation::TimestampDrift => "C",
|
||||
Violation::PayloadSizeExceeded => "D",
|
||||
Violation::RateCapExceeded => "E",
|
||||
};
|
||||
let codec_id = format!("{:?}", header.codec_id);
|
||||
let verdict = format!("{:?}", v);
|
||||
self.conformance_violations
|
||||
.with_label_values(&[tier, &codec_id, &media_type, &verdict])
|
||||
.inc();
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove all per-session label values for a disconnected session.
|
||||
pub fn remove_session_metrics(&self, session_id: &str) {
|
||||
let _ = self.session_buffer_depth.remove_label_values(&[session_id]);
|
||||
@@ -284,7 +426,9 @@ impl RelayMetrics {
|
||||
let _ = self
|
||||
.session_dred_reconstructions
|
||||
.remove_label_values(&[session_id]);
|
||||
let _ = self.session_classical_plc.remove_label_values(&[session_id]);
|
||||
let _ = self
|
||||
.session_classical_plc
|
||||
.remove_label_values(&[session_id]);
|
||||
}
|
||||
|
||||
/// Get a reference to the underlying Prometheus registry.
|
||||
@@ -298,7 +442,9 @@ impl RelayMetrics {
|
||||
let encoder = TextEncoder::new();
|
||||
let metric_families = self.registry.gather();
|
||||
let mut buffer = Vec::new();
|
||||
encoder.encode(&metric_families, &mut buffer).expect("encode");
|
||||
encoder
|
||||
.encode(&metric_families, &mut buffer)
|
||||
.expect("encode");
|
||||
String::from_utf8(buffer).expect("utf8")
|
||||
}
|
||||
}
|
||||
@@ -310,7 +456,7 @@ pub async fn serve_metrics(
|
||||
presence: Option<Arc<tokio::sync::Mutex<crate::presence::PresenceRegistry>>>,
|
||||
route_resolver: Option<Arc<crate::route::RouteResolver>>,
|
||||
) {
|
||||
use axum::{extract::Path, routing::get, Router};
|
||||
use axum::{Router, extract::Path, routing::get};
|
||||
|
||||
let metrics_clone = metrics.clone();
|
||||
let presence_all = presence.clone();
|
||||
@@ -454,8 +600,8 @@ mod tests {
|
||||
fn session_quality_update() {
|
||||
let m = RelayMetrics::new();
|
||||
let report = QualityReport {
|
||||
loss_pct: 128, // ~50%
|
||||
rtt_4ms: 25, // 100ms
|
||||
loss_pct: 128, // ~50%
|
||||
rtt_4ms: 25, // 100ms
|
||||
jitter_ms: 10,
|
||||
bitrate_cap_kbps: 200,
|
||||
};
|
||||
|
||||
@@ -11,11 +11,11 @@
|
||||
use tracing::{debug, info};
|
||||
|
||||
use wzp_fec::{RaptorQFecDecoder, RaptorQFecEncoder};
|
||||
use wzp_proto::QualityProfile;
|
||||
use wzp_proto::jitter::{JitterBuffer, PlayoutResult};
|
||||
use wzp_proto::packet::{MediaHeader, MediaPacket};
|
||||
use wzp_proto::quality::AdaptiveQualityController;
|
||||
use wzp_proto::traits::{FecDecoder, FecEncoder, QualityController};
|
||||
use wzp_proto::QualityProfile;
|
||||
|
||||
/// Configuration for a relay pipeline instance.
|
||||
pub struct PipelineConfig {
|
||||
@@ -51,7 +51,7 @@ pub struct RelayPipeline {
|
||||
/// Current quality profile.
|
||||
profile: QualityProfile,
|
||||
/// Outbound sequence counter.
|
||||
out_seq: u16,
|
||||
out_seq: u32,
|
||||
/// Packets processed count.
|
||||
stats: PipelineStats,
|
||||
}
|
||||
@@ -111,8 +111,8 @@ impl RelayPipeline {
|
||||
let header = &packet.header;
|
||||
let _ = self.fec_decoder.add_symbol(
|
||||
header.fec_block,
|
||||
header.fec_symbol,
|
||||
header.is_repair,
|
||||
header.fec_block >> 8,
|
||||
header.is_repair(),
|
||||
&packet.payload,
|
||||
);
|
||||
|
||||
@@ -128,22 +128,21 @@ impl RelayPipeline {
|
||||
for (i, frame) in frames.into_iter().enumerate() {
|
||||
let reconstructed = MediaPacket {
|
||||
header: MediaHeader {
|
||||
version: 0,
|
||||
is_repair: false,
|
||||
version: 2,
|
||||
flags: 0,
|
||||
media_type: wzp_proto::MediaType::Audio,
|
||||
codec_id: header.codec_id,
|
||||
has_quality_report: false,
|
||||
fec_ratio_encoded: header.fec_ratio_encoded,
|
||||
stream_id: 0,
|
||||
fec_ratio: header.fec_ratio,
|
||||
// Reconstruct seq from block + symbol index
|
||||
seq: (header.fec_block as u16)
|
||||
.wrapping_mul(self.profile.frames_per_block as u16)
|
||||
.wrapping_add(i as u16),
|
||||
timestamp: header
|
||||
.timestamp
|
||||
.wrapping_add((i as u32) * (header.codec_id.frame_duration_ms() as u32)),
|
||||
fec_block: header.fec_block,
|
||||
fec_symbol: i as u8,
|
||||
reserved: 0,
|
||||
csrc_count: 0,
|
||||
seq: (header.fec_block as u32)
|
||||
.wrapping_mul(self.profile.frames_per_block as u32)
|
||||
.wrapping_add(i as u32),
|
||||
timestamp: header.timestamp.wrapping_add(
|
||||
(i as u32) * (header.codec_id.frame_duration_ms() as u32),
|
||||
),
|
||||
fec_block: u16::from((header.fec_block & 0xFF) as u8)
|
||||
| (u16::from(i as u8) << 8),
|
||||
},
|
||||
payload: bytes::Bytes::from(frame),
|
||||
quality_report: None,
|
||||
@@ -191,19 +190,16 @@ impl RelayPipeline {
|
||||
for (sym_idx, repair_data) in repairs {
|
||||
let repair_packet = MediaPacket {
|
||||
header: MediaHeader {
|
||||
version: 0,
|
||||
is_repair: true,
|
||||
version: 2,
|
||||
flags: MediaHeader::FLAG_REPAIR,
|
||||
media_type: wzp_proto::MediaType::Audio,
|
||||
codec_id: packet.header.codec_id,
|
||||
has_quality_report: false,
|
||||
fec_ratio_encoded: MediaHeader::encode_fec_ratio(
|
||||
self.profile.fec_ratio,
|
||||
),
|
||||
stream_id: 0,
|
||||
fec_ratio: MediaHeader::encode_fec_ratio(self.profile.fec_ratio),
|
||||
seq: self.out_seq,
|
||||
timestamp: packet.header.timestamp,
|
||||
fec_block: self.fec_encoder.current_block_id(),
|
||||
fec_symbol: sym_idx,
|
||||
reserved: 0,
|
||||
csrc_count: 0,
|
||||
fec_block: u16::from(self.fec_encoder.current_block_id())
|
||||
| (u16::from(sym_idx) << 8),
|
||||
},
|
||||
payload: bytes::Bytes::from(repair_data),
|
||||
quality_report: None,
|
||||
@@ -232,23 +228,21 @@ impl RelayPipeline {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use wzp_proto::CodecId;
|
||||
use bytes::Bytes;
|
||||
use wzp_proto::CodecId;
|
||||
|
||||
fn make_media_packet(seq: u16, block: u8, symbol: u8) -> MediaPacket {
|
||||
fn make_media_packet(seq: u32, block: u8, symbol: u8) -> MediaPacket {
|
||||
MediaPacket {
|
||||
header: MediaHeader {
|
||||
version: 0,
|
||||
is_repair: false,
|
||||
version: 2,
|
||||
flags: 0,
|
||||
media_type: wzp_proto::MediaType::Audio,
|
||||
codec_id: CodecId::Opus24k,
|
||||
has_quality_report: false,
|
||||
fec_ratio_encoded: 0,
|
||||
stream_id: 0,
|
||||
fec_ratio: 0,
|
||||
seq,
|
||||
timestamp: seq as u32 * 20,
|
||||
fec_block: block,
|
||||
fec_symbol: symbol,
|
||||
reserved: 0,
|
||||
csrc_count: 0,
|
||||
timestamp: seq * 20,
|
||||
fec_block: u16::from(block) | (u16::from(symbol) << 8),
|
||||
},
|
||||
payload: Bytes::from(vec![seq as u8; 60]),
|
||||
quality_report: None,
|
||||
@@ -283,7 +277,7 @@ mod tests {
|
||||
|
||||
// Feed 5 packets (one full block)
|
||||
let mut total_out = 0;
|
||||
for i in 0..5u16 {
|
||||
for i in 0..5u32 {
|
||||
let pkt = make_media_packet(i, 0, i as u8);
|
||||
let out = pipeline.prepare_outbound(pkt);
|
||||
total_out += out.len();
|
||||
|
||||
@@ -63,6 +63,12 @@ pub struct PresenceRegistry {
|
||||
peers: HashMap<SocketAddr, PeerRelay>,
|
||||
}
|
||||
|
||||
impl Default for PresenceRegistry {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl PresenceRegistry {
|
||||
/// Create an empty registry.
|
||||
pub fn new() -> Self {
|
||||
@@ -74,13 +80,21 @@ impl PresenceRegistry {
|
||||
}
|
||||
|
||||
/// Register a fingerprint as locally connected (called after auth + handshake).
|
||||
pub fn register_local(&mut self, fingerprint: &str, alias: Option<String>, room: Option<String>) {
|
||||
self.local.insert(fingerprint.to_string(), LocalPresence {
|
||||
fingerprint: fingerprint.to_string(),
|
||||
alias,
|
||||
connected_at: Instant::now(),
|
||||
room,
|
||||
});
|
||||
pub fn register_local(
|
||||
&mut self,
|
||||
fingerprint: &str,
|
||||
alias: Option<String>,
|
||||
room: Option<String>,
|
||||
) {
|
||||
self.local.insert(
|
||||
fingerprint.to_string(),
|
||||
LocalPresence {
|
||||
fingerprint: fingerprint.to_string(),
|
||||
alias,
|
||||
connected_at: Instant::now(),
|
||||
room,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
/// Unregister a locally connected fingerprint (called on disconnect).
|
||||
@@ -98,11 +112,14 @@ impl PresenceRegistry {
|
||||
|
||||
// Insert new remote entries
|
||||
for fp in &fingerprints {
|
||||
self.remote.insert(fp.clone(), RemotePresence {
|
||||
fingerprint: fp.clone(),
|
||||
relay_addr: addr,
|
||||
last_seen: now,
|
||||
});
|
||||
self.remote.insert(
|
||||
fp.clone(),
|
||||
RemotePresence {
|
||||
fingerprint: fp.clone(),
|
||||
relay_addr: addr,
|
||||
last_seen: now,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
// Update the peer record
|
||||
@@ -156,7 +173,8 @@ impl PresenceRegistry {
|
||||
self.remote.retain(|_, rp| rp.last_seen > cutoff);
|
||||
|
||||
// Expire peer relay records and their fingerprint sets
|
||||
let stale_peers: Vec<SocketAddr> = self.peers
|
||||
let stale_peers: Vec<SocketAddr> = self
|
||||
.peers
|
||||
.iter()
|
||||
.filter(|(_, p)| p.last_update <= cutoff)
|
||||
.map(|(addr, _)| *addr)
|
||||
@@ -280,13 +298,15 @@ mod tests {
|
||||
let all = reg.all_known();
|
||||
assert_eq!(all.len(), 2);
|
||||
|
||||
let local_entries: Vec<_> = all.iter()
|
||||
let local_entries: Vec<_> = all
|
||||
.iter()
|
||||
.filter(|(_, loc)| *loc == PresenceLocation::Local)
|
||||
.collect();
|
||||
assert_eq!(local_entries.len(), 1);
|
||||
assert_eq!(local_entries[0].0, "local1");
|
||||
|
||||
let remote_entries: Vec<_> = all.iter()
|
||||
let remote_entries: Vec<_> = all
|
||||
.iter()
|
||||
.filter(|(_, loc)| matches!(loc, PresenceLocation::Remote(_)))
|
||||
.collect();
|
||||
assert_eq!(remote_entries.len(), 1);
|
||||
|
||||
@@ -13,7 +13,7 @@ use prometheus::{Gauge, IntGauge, Opts, Registry};
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
use wzp_proto::{MediaTransport, SignalMessage};
|
||||
use wzp_proto::{MediaTransport, SignalMessage, default_signal_version};
|
||||
|
||||
/// Configuration for a single probe target.
|
||||
#[derive(Clone, Debug)]
|
||||
@@ -43,8 +43,7 @@ impl ProbeMetrics {
|
||||
/// Register probe metrics with the given `target` label value.
|
||||
pub fn register(target: &str, registry: &Registry) -> Self {
|
||||
let rtt_ms = Gauge::with_opts(
|
||||
Opts::new("wzp_probe_rtt_ms", "RTT to peer relay in ms")
|
||||
.const_label("target", target),
|
||||
Opts::new("wzp_probe_rtt_ms", "RTT to peer relay in ms").const_label("target", target),
|
||||
)
|
||||
.expect("probe metric");
|
||||
|
||||
@@ -66,9 +65,15 @@ impl ProbeMetrics {
|
||||
)
|
||||
.expect("probe metric");
|
||||
|
||||
registry.register(Box::new(rtt_ms.clone())).expect("register");
|
||||
registry.register(Box::new(loss_pct.clone())).expect("register");
|
||||
registry.register(Box::new(jitter_ms.clone())).expect("register");
|
||||
registry
|
||||
.register(Box::new(rtt_ms.clone()))
|
||||
.expect("register");
|
||||
registry
|
||||
.register(Box::new(loss_pct.clone()))
|
||||
.expect("register");
|
||||
registry
|
||||
.register(Box::new(jitter_ms.clone()))
|
||||
.expect("register");
|
||||
registry.register(Box::new(up.clone())).expect("register");
|
||||
|
||||
Self {
|
||||
@@ -168,7 +173,11 @@ impl ProbeRunner {
|
||||
) -> Self {
|
||||
let target_str = config.target.to_string();
|
||||
let metrics = ProbeMetrics::register(&target_str, registry);
|
||||
Self { config, metrics, presence }
|
||||
Self {
|
||||
config,
|
||||
metrics,
|
||||
presence,
|
||||
}
|
||||
}
|
||||
|
||||
/// Run the probe forever. This function never returns under normal operation.
|
||||
@@ -198,13 +207,8 @@ impl ProbeRunner {
|
||||
let bind_addr: SocketAddr = "0.0.0.0:0".parse().unwrap();
|
||||
let endpoint = wzp_transport::create_endpoint(bind_addr, None)?;
|
||||
let client_cfg = wzp_transport::client_config();
|
||||
let conn = wzp_transport::connect(
|
||||
&endpoint,
|
||||
self.config.target,
|
||||
"_probe",
|
||||
client_cfg,
|
||||
)
|
||||
.await?;
|
||||
let conn =
|
||||
wzp_transport::connect(&endpoint, self.config.target, "_probe", client_cfg).await?;
|
||||
|
||||
let transport = Arc::new(wzp_transport::QuinnTransport::new(conn));
|
||||
self.metrics.up.set(1);
|
||||
@@ -225,7 +229,7 @@ impl ProbeRunner {
|
||||
let recv_handle = tokio::spawn(async move {
|
||||
loop {
|
||||
match recv_transport.recv_signal().await {
|
||||
Ok(Some(SignalMessage::Pong { timestamp_ms })) => {
|
||||
Ok(Some(SignalMessage::Pong { timestamp_ms, .. })) => {
|
||||
let now_ms = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
@@ -237,11 +241,16 @@ impl ProbeRunner {
|
||||
loss_gauge.set(w.loss_pct());
|
||||
jitter_gauge.set(w.jitter_ms());
|
||||
}
|
||||
Ok(Some(SignalMessage::PresenceUpdate { fingerprints, relay_addr })) => {
|
||||
Ok(Some(SignalMessage::PresenceUpdate {
|
||||
fingerprints,
|
||||
relay_addr,
|
||||
..
|
||||
})) => {
|
||||
if let Some(ref reg) = recv_presence {
|
||||
// Parse the relay_addr; fall back to the connection target
|
||||
let addr = relay_addr.parse().unwrap_or(recv_target);
|
||||
let fps: std::collections::HashSet<String> = fingerprints.into_iter().collect();
|
||||
let fps: std::collections::HashSet<String> =
|
||||
fingerprints.into_iter().collect();
|
||||
let mut r = reg.lock().await;
|
||||
r.update_peer(addr, fps);
|
||||
}
|
||||
@@ -285,7 +294,10 @@ impl ProbeRunner {
|
||||
}
|
||||
|
||||
if let Err(e) = transport
|
||||
.send_signal(&SignalMessage::Ping { timestamp_ms })
|
||||
.send_signal(&SignalMessage::Ping {
|
||||
version: default_signal_version(),
|
||||
timestamp_ms,
|
||||
})
|
||||
.await
|
||||
{
|
||||
error!(target = %self.config.target, "probe ping send error: {e}");
|
||||
@@ -302,6 +314,7 @@ impl ProbeRunner {
|
||||
r.local_fingerprints().into_iter().collect()
|
||||
};
|
||||
let msg = SignalMessage::PresenceUpdate {
|
||||
version: default_signal_version(),
|
||||
fingerprints: fps,
|
||||
relay_addr: self.config.target.to_string(),
|
||||
};
|
||||
@@ -374,10 +387,7 @@ pub fn mesh_summary(registry: &Registry) -> String {
|
||||
let name = family.get_name();
|
||||
for metric in family.get_metric() {
|
||||
// Find the "target" label
|
||||
let target_label = metric
|
||||
.get_label()
|
||||
.iter()
|
||||
.find(|l| l.get_name() == "target");
|
||||
let target_label = metric.get_label().iter().find(|l| l.get_name() == "target");
|
||||
let target = match target_label {
|
||||
Some(l) => l.get_value().to_string(),
|
||||
None => continue,
|
||||
@@ -420,13 +430,11 @@ pub fn mesh_summary(registry: &Registry) -> String {
|
||||
|
||||
/// Handle an incoming Ping signal by replying with a Pong carrying the same timestamp.
|
||||
/// Returns true if the message was a Ping and was handled, false otherwise.
|
||||
pub async fn handle_ping(
|
||||
transport: &wzp_transport::QuinnTransport,
|
||||
msg: &SignalMessage,
|
||||
) -> bool {
|
||||
if let SignalMessage::Ping { timestamp_ms } = msg {
|
||||
pub async fn handle_ping(transport: &wzp_transport::QuinnTransport, msg: &SignalMessage) -> bool {
|
||||
if let SignalMessage::Ping { timestamp_ms, .. } = msg {
|
||||
if let Err(e) = transport
|
||||
.send_signal(&SignalMessage::Pong {
|
||||
version: default_signal_version(),
|
||||
timestamp_ms: *timestamp_ms,
|
||||
})
|
||||
.await
|
||||
@@ -456,9 +464,18 @@ mod tests {
|
||||
encoder.encode(&families, &mut buf).unwrap();
|
||||
let output = String::from_utf8(buf).unwrap();
|
||||
|
||||
assert!(output.contains("wzp_probe_rtt_ms"), "missing wzp_probe_rtt_ms");
|
||||
assert!(output.contains("wzp_probe_loss_pct"), "missing wzp_probe_loss_pct");
|
||||
assert!(output.contains("wzp_probe_jitter_ms"), "missing wzp_probe_jitter_ms");
|
||||
assert!(
|
||||
output.contains("wzp_probe_rtt_ms"),
|
||||
"missing wzp_probe_rtt_ms"
|
||||
);
|
||||
assert!(
|
||||
output.contains("wzp_probe_loss_pct"),
|
||||
"missing wzp_probe_loss_pct"
|
||||
);
|
||||
assert!(
|
||||
output.contains("wzp_probe_jitter_ms"),
|
||||
"missing wzp_probe_jitter_ms"
|
||||
);
|
||||
assert!(output.contains("wzp_probe_up"), "missing wzp_probe_up");
|
||||
assert!(
|
||||
output.contains("target=\"127.0.0.1:4433\""),
|
||||
|
||||
@@ -40,10 +40,7 @@ impl RelayLink {
|
||||
/// should skip normal client auth/handshake for relay-SNI connections.
|
||||
pub async fn connect(target: SocketAddr) -> Result<Self, anyhow::Error> {
|
||||
// Create a client-only endpoint on an OS-assigned port.
|
||||
let endpoint = wzp_transport::create_endpoint(
|
||||
"0.0.0.0:0".parse().unwrap(),
|
||||
None,
|
||||
)?;
|
||||
let endpoint = wzp_transport::create_endpoint("0.0.0.0:0".parse().unwrap(), None)?;
|
||||
|
||||
let client_cfg = wzp_transport::client_config();
|
||||
let conn = wzp_transport::connect(&endpoint, target, "_relay", client_cfg).await?;
|
||||
@@ -336,10 +333,11 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn session_forward_signal_roundtrip() {
|
||||
use wzp_proto::SignalMessage;
|
||||
use wzp_proto::{SignalMessage, default_signal_version};
|
||||
|
||||
// SessionForward roundtrip
|
||||
let msg = SignalMessage::SessionForward {
|
||||
version: default_signal_version(),
|
||||
session_id: "abcd1234".to_string(),
|
||||
target_fingerprint: "deadbeef".to_string(),
|
||||
source_relay: "10.0.0.1:4433".to_string(),
|
||||
@@ -351,6 +349,7 @@ mod tests {
|
||||
session_id,
|
||||
target_fingerprint,
|
||||
source_relay,
|
||||
..
|
||||
} => {
|
||||
assert_eq!(session_id, "abcd1234");
|
||||
assert_eq!(target_fingerprint, "deadbeef");
|
||||
@@ -361,6 +360,7 @@ mod tests {
|
||||
|
||||
// SessionForwardAck roundtrip
|
||||
let ack = SignalMessage::SessionForwardAck {
|
||||
version: default_signal_version(),
|
||||
session_id: "abcd1234".to_string(),
|
||||
room_name: "relay-room-42".to_string(),
|
||||
};
|
||||
@@ -370,6 +370,7 @@ mod tests {
|
||||
SignalMessage::SessionForwardAck {
|
||||
session_id,
|
||||
room_name,
|
||||
..
|
||||
} => {
|
||||
assert_eq!(session_id, "abcd1234");
|
||||
assert_eq!(room_name, "relay-room-42");
|
||||
@@ -457,17 +458,15 @@ mod tests {
|
||||
|
||||
let pkt = MediaPacket {
|
||||
header: wzp_proto::packet::MediaHeader {
|
||||
version: 0,
|
||||
is_repair: false,
|
||||
version: 2,
|
||||
flags: 0,
|
||||
media_type: wzp_proto::MediaType::Audio,
|
||||
codec_id: wzp_proto::CodecId::Opus16k,
|
||||
has_quality_report: false,
|
||||
fec_ratio_encoded: 0,
|
||||
stream_id: 0,
|
||||
fec_ratio: 0,
|
||||
seq: 1,
|
||||
timestamp: 100,
|
||||
fec_block: 0,
|
||||
fec_symbol: 0,
|
||||
reserved: 0,
|
||||
csrc_count: 0,
|
||||
},
|
||||
payload: bytes::Bytes::from_static(b"test"),
|
||||
quality_report: None,
|
||||
|
||||
207
crates/wzp-relay/src/response_policy.rs
Normal file
207
crates/wzp-relay/src/response_policy.rs
Normal file
@@ -0,0 +1,207 @@
|
||||
//! Tier G response policy — maps conformance verdicts to enforcement actions.
|
||||
//!
|
||||
//! Actions:
|
||||
//! - `Legitimate` → no action
|
||||
//! - `Suspect` → tighten Tier E quota, emit metric
|
||||
//! - `Abusive` → typed Hangup + 1 h fingerprint cool-down
|
||||
//! - `RepeatAbusive` → relay-local block 24 h
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use wzp_proto::packet::{HangupReason, ViolationCode};
|
||||
|
||||
use crate::verdict::Verdict;
|
||||
|
||||
/// Enforcement action recommended by the response policy.
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub enum Action {
|
||||
/// Pass through unchanged.
|
||||
Allow,
|
||||
/// Throttle to tighter quota (Tier E).
|
||||
Throttle,
|
||||
/// Close the session with a typed Hangup signal.
|
||||
Close { reason: HangupReason },
|
||||
/// Block the fingerprint from joining any room for 24 h.
|
||||
Block,
|
||||
}
|
||||
|
||||
/// Tracks fingerprint-level abuse history and applies escalation.
|
||||
pub struct ResponsePolicy {
|
||||
/// `(fingerprint, violation_code)` → last abusive instant.
|
||||
cooldowns: HashMap<(String, ViolationCode), Instant>,
|
||||
/// Block duration for repeat abuse.
|
||||
block_duration: Duration,
|
||||
}
|
||||
|
||||
impl ResponsePolicy {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
cooldowns: HashMap::new(),
|
||||
block_duration: Duration::from_secs(86400), // 24 h
|
||||
}
|
||||
}
|
||||
|
||||
/// Evaluate a verdict and produce the corresponding [`Action`].
|
||||
///
|
||||
/// `fingerprint` is the participant's identity string (or IP as fallback).
|
||||
/// `code` is the specific violation type that triggered the verdict.
|
||||
pub fn evaluate(&mut self, fingerprint: &str, code: ViolationCode, verdict: Verdict) -> Action {
|
||||
match verdict {
|
||||
Verdict::Legitimate => Action::Allow,
|
||||
Verdict::Suspect => Action::Throttle,
|
||||
Verdict::Abusive => {
|
||||
let key = (fingerprint.to_string(), code);
|
||||
let now = Instant::now();
|
||||
|
||||
// Check if this fingerprint was already abusive recently.
|
||||
let is_repeat = self
|
||||
.cooldowns
|
||||
.get(&key)
|
||||
.map(|last| now.duration_since(*last) < self.block_duration)
|
||||
.unwrap_or(false);
|
||||
|
||||
if is_repeat {
|
||||
Action::Block
|
||||
} else {
|
||||
self.cooldowns.insert(key, now);
|
||||
Action::Close {
|
||||
reason: HangupReason::PolicyViolation {
|
||||
code,
|
||||
reason: format!("Tier G enforcement: {code:?}"),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if the fingerprint is currently blocked (repeat abuse).
|
||||
pub fn is_blocked(&self, fingerprint: &str) -> bool {
|
||||
let now = Instant::now();
|
||||
self.cooldowns.iter().any(|((fp, _), last)| {
|
||||
fp == fingerprint && now.duration_since(*last) < self.block_duration
|
||||
})
|
||||
}
|
||||
|
||||
/// Clean up expired cooldown entries.
|
||||
pub fn prune(&mut self) {
|
||||
let now = Instant::now();
|
||||
self.cooldowns
|
||||
.retain(|_, last| now.duration_since(*last) < self.block_duration);
|
||||
}
|
||||
|
||||
/// Number of tracked cooldown entries.
|
||||
pub fn len(&self) -> usize {
|
||||
self.cooldowns.len()
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.cooldowns.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ResponsePolicy {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn legitimate_allowed() {
|
||||
let mut policy = ResponsePolicy::new();
|
||||
assert_eq!(
|
||||
policy.evaluate("alice", ViolationCode::Bitrate, Verdict::Legitimate),
|
||||
Action::Allow
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn suspect_throttled() {
|
||||
let mut policy = ResponsePolicy::new();
|
||||
assert_eq!(
|
||||
policy.evaluate("alice", ViolationCode::Entropy, Verdict::Suspect),
|
||||
Action::Throttle
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn abusive_gets_close() {
|
||||
let mut policy = ResponsePolicy::new();
|
||||
let action = policy.evaluate("alice", ViolationCode::Bitrate, Verdict::Abusive);
|
||||
assert!(
|
||||
matches!(action, Action::Close { .. }),
|
||||
"first-time abuse should close session"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn repeat_abusive_gets_block() {
|
||||
let mut policy = ResponsePolicy::new();
|
||||
// First abuse
|
||||
let _ = policy.evaluate("alice", ViolationCode::Bitrate, Verdict::Abusive);
|
||||
// Second abuse within window → block
|
||||
let action = policy.evaluate("alice", ViolationCode::Bitrate, Verdict::Abusive);
|
||||
assert_eq!(action, Action::Block, "repeat abuse should block");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn different_violation_codes_are_independent() {
|
||||
let mut policy = ResponsePolicy::new();
|
||||
// Abuse on bitrate
|
||||
let _ = policy.evaluate("alice", ViolationCode::Bitrate, Verdict::Abusive);
|
||||
// Abuse on entropy is treated as first-time for that code
|
||||
let action = policy.evaluate("alice", ViolationCode::Entropy, Verdict::Abusive);
|
||||
assert!(
|
||||
matches!(action, Action::Close { .. }),
|
||||
"different violation code should not trigger repeat"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_blocked_true_after_repeat() {
|
||||
let mut policy = ResponsePolicy::new();
|
||||
let _ = policy.evaluate("alice", ViolationCode::Bitrate, Verdict::Abusive);
|
||||
let _ = policy.evaluate("alice", ViolationCode::Bitrate, Verdict::Abusive);
|
||||
assert!(policy.is_blocked("alice"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_blocked_false_for_legitimate() {
|
||||
let policy = ResponsePolicy::new();
|
||||
assert!(!policy.is_blocked("alice"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prune_removes_expired() {
|
||||
let mut policy = ResponsePolicy::new();
|
||||
let _ = policy.evaluate("alice", ViolationCode::Bitrate, Verdict::Abusive);
|
||||
assert_eq!(policy.len(), 1);
|
||||
// Manually expire by moving cooldown back
|
||||
policy.cooldowns.insert(
|
||||
("alice".to_string(), ViolationCode::Bitrate),
|
||||
Instant::now() - Duration::from_secs(90000),
|
||||
);
|
||||
policy.prune();
|
||||
assert!(policy.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn close_reason_contains_code() {
|
||||
let mut policy = ResponsePolicy::new();
|
||||
let action = policy.evaluate("alice", ViolationCode::Entropy, Verdict::Abusive);
|
||||
match action {
|
||||
Action::Close { reason } => match reason {
|
||||
HangupReason::PolicyViolation { code, .. } => {
|
||||
assert_eq!(code, ViolationCode::Entropy);
|
||||
}
|
||||
other => panic!("expected PolicyViolation, got {other:?}"),
|
||||
},
|
||||
other => panic!("expected Close, got {other:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -97,14 +97,13 @@ impl RouteResolver {
|
||||
}
|
||||
|
||||
/// Build a JSON-serializable route response for the HTTP API.
|
||||
pub fn route_json(
|
||||
&self,
|
||||
fingerprint: &str,
|
||||
route: &Route,
|
||||
) -> serde_json::Value {
|
||||
pub fn route_json(&self, fingerprint: &str, route: &Route) -> serde_json::Value {
|
||||
let (route_type, relay_chain) = match route {
|
||||
Route::Local => ("local", vec![self.local_addr.to_string()]),
|
||||
Route::DirectPeer(addr) => ("direct_peer", vec![self.local_addr.to_string(), addr.to_string()]),
|
||||
Route::DirectPeer(addr) => (
|
||||
"direct_peer",
|
||||
vec![self.local_addr.to_string(), addr.to_string()],
|
||||
),
|
||||
Route::Chain(chain) => {
|
||||
let mut addrs = vec![self.local_addr.to_string()];
|
||||
addrs.extend(chain.iter().map(|a| a.to_string()));
|
||||
@@ -184,7 +183,10 @@ mod tests {
|
||||
reg.update_peer(peer, fps);
|
||||
|
||||
// Local lookup works via multi-hop
|
||||
assert_eq!(resolver.resolve_multi_hop(®, "local_fp", 3), Route::Local);
|
||||
assert_eq!(
|
||||
resolver.resolve_multi_hop(®, "local_fp", 3),
|
||||
Route::Local
|
||||
);
|
||||
// Remote lookup works via multi-hop
|
||||
assert_eq!(
|
||||
resolver.resolve_multi_hop(®, "remote_fp", 3),
|
||||
@@ -199,9 +201,10 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn route_query_signal_roundtrip() {
|
||||
use wzp_proto::SignalMessage;
|
||||
use wzp_proto::{SignalMessage, default_signal_version};
|
||||
|
||||
let query = SignalMessage::RouteQuery {
|
||||
version: default_signal_version(),
|
||||
fingerprint: "aabbccdd".to_string(),
|
||||
ttl: 3,
|
||||
};
|
||||
@@ -209,11 +212,12 @@ mod tests {
|
||||
let decoded: SignalMessage = serde_json::from_str(&json).unwrap();
|
||||
assert!(matches!(
|
||||
decoded,
|
||||
SignalMessage::RouteQuery { ref fingerprint, ttl }
|
||||
SignalMessage::RouteQuery { ref fingerprint, ttl, ..}
|
||||
if fingerprint == "aabbccdd" && ttl == 3
|
||||
));
|
||||
|
||||
let response = SignalMessage::RouteResponse {
|
||||
version: default_signal_version(),
|
||||
fingerprint: "aabbccdd".to_string(),
|
||||
found: true,
|
||||
relay_chain: vec!["10.0.0.1:4433".to_string(), "10.0.0.2:4433".to_string()],
|
||||
@@ -222,7 +226,7 @@ mod tests {
|
||||
let decoded: SignalMessage = serde_json::from_str(&json).unwrap();
|
||||
assert!(matches!(
|
||||
decoded,
|
||||
SignalMessage::RouteResponse { ref fingerprint, found, ref relay_chain }
|
||||
SignalMessage::RouteResponse { ref fingerprint, found, ref relay_chain, ..}
|
||||
if fingerprint == "aabbccdd" && found && relay_chain.len() == 2
|
||||
));
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user