Compare commits
183 Commits
6f43415285
...
video-usab
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
12020b019c | ||
|
|
3ea25a0656 | ||
|
|
112472609e | ||
|
|
9a7745978b | ||
|
|
f85efb9576 | ||
|
|
31b2caa54d | ||
|
|
079e21e174 | ||
|
|
e676641538 | ||
|
|
9713efc404 | ||
|
|
8415804a1a | ||
|
|
f65b399a21 | ||
|
|
3437a6bd11 | ||
|
|
15eb00ed5e | ||
|
|
0c2297a2b7 | ||
|
|
a08a37b5eb | ||
|
|
f6ace54556 | ||
|
|
47baa1a765 | ||
|
|
ee654cd1ef | ||
|
|
d2046060b5 | ||
|
|
0b7bf1b385 | ||
|
|
e8f139588a | ||
|
|
0115b11de7 | ||
|
|
fa812a17d9 | ||
|
|
8d6b168f1b | ||
|
|
ca164ada5c | ||
|
|
2d58bae9ba | ||
|
|
e1ca6ca6e6 | ||
|
|
06d28a9280 | ||
|
|
d57ebe3d2c | ||
|
|
7eca79846f | ||
|
|
25b3278d31 | ||
|
|
cbc3a8d37e | ||
|
|
1329abbeba | ||
|
|
e8cab25eda | ||
|
|
c41ced53e1 | ||
|
|
7fd66be6c8 | ||
|
|
8002acaf09 | ||
|
|
06253fdeeb | ||
|
|
01f55caa96 | ||
|
|
0f93a2b745 | ||
|
|
2b93bd4b45 | ||
|
|
bc021517c0 | ||
|
|
739bdaf3ab | ||
|
|
bc1668ed96 | ||
|
|
77b036439b | ||
|
|
0ebc73ab13 | ||
|
|
394987a349 | ||
|
|
2aa6582585 | ||
|
|
ca987d547c | ||
|
|
5a13f12334 | ||
|
|
b0a3b1f18e | ||
|
|
32c07d1b61 | ||
|
|
5d05b021aa | ||
|
|
4ac62d99e0 | ||
|
|
4ebb2dac2d | ||
|
|
52a6f5e048 | ||
|
|
15af58a95d | ||
|
|
ed8a7ae5aa | ||
|
|
12b0d9738f | ||
|
|
f78794f4b6 | ||
|
|
f3e3ee5ed0 | ||
|
|
f28f39d814 | ||
|
|
1e729e4b1d | ||
|
|
086d0a4845 | ||
|
|
9334aa5ccd | ||
|
|
553c8a4ce1 | ||
|
|
8d8dddbd35 | ||
|
|
f16d650721 | ||
|
|
31f2fdef1e | ||
|
|
fc9908cd4c | ||
|
|
517d0ebfe0 | ||
|
|
cf4940417e | ||
|
|
ffded2a913 | ||
|
|
283edd38eb | ||
|
|
fdfaed5390 | ||
|
|
dbbab0decf | ||
|
|
5fda5ecc52 | ||
|
|
2bbb664df4 | ||
|
|
2f1a9f74d5 | ||
|
|
b197651557 | ||
|
|
9c41d1acdd | ||
|
|
e34c40dc0f | ||
|
|
c48cb6fbcb | ||
|
|
2e0bdc5904 | ||
|
|
276ecc660e | ||
|
|
001d94f9ae | ||
|
|
36b0421d68 | ||
|
|
828fbea2ea | ||
|
|
cc5aef2534 | ||
|
|
397f9d2141 | ||
|
|
410c2a4335 | ||
|
|
81042ac190 | ||
|
|
e177e63843 | ||
|
|
1f7d130de9 | ||
|
|
3356ba94c6 | ||
|
|
bb153a331d | ||
|
|
490d2d31c6 | ||
|
|
db69f7e9d1 | ||
|
|
f1b86e0fed | ||
|
|
8454835c18 | ||
|
|
017c371611 | ||
|
|
3220bd6151 | ||
|
|
e73f8a7150 | ||
|
|
1b4f7b0772 | ||
|
|
f3398adb95 | ||
|
|
54c1a35186 | ||
|
|
3de56cf1f9 | ||
|
|
fe1f9484bd | ||
|
|
0ef1f574ff | ||
|
|
b1c5837495 | ||
|
|
6f81487778 | ||
|
|
5cdb50160a | ||
|
|
30d26fc7f6 | ||
|
|
c93d302656 | ||
|
|
9680b6ff34 | ||
|
|
6b15b8f97c | ||
|
|
6385b93391 | ||
|
|
6eb94f079d | ||
|
|
5580b794a4 | ||
|
|
7c9ede9227 | ||
|
|
e8866c6632 | ||
|
|
8c6e88ea68 | ||
|
|
ffb92237be | ||
|
|
6af0539a72 | ||
|
|
217567383d | ||
|
|
98ed981805 | ||
|
|
01a3133544 | ||
|
|
25471c694f | ||
|
|
a058a83c91 | ||
|
|
9b8013ba7f | ||
|
|
defd8eab07 | ||
|
|
cc23e829b2 | ||
|
|
18c204c1ff | ||
|
|
1120c7b579 | ||
|
|
7e7391fdbb | ||
|
|
aa0362f318 | ||
|
|
bb23976076 | ||
|
|
18e5e75f33 | ||
|
|
488efcb614 | ||
|
|
8c360186df | ||
|
|
f06f9073ae | ||
|
|
6c49d7436f | ||
|
|
1de280fe04 | ||
|
|
bc6d327ebb | ||
|
|
c478224d67 | ||
|
|
16dcc75514 | ||
|
|
db5751985e | ||
|
|
c0dd6c06ff | ||
|
|
6805caae0e | ||
|
|
5a03da72d3 | ||
|
|
e3e63a40a0 | ||
|
|
7b4bce69d5 | ||
|
|
ec1bdf3cd5 | ||
|
|
ee14862376 | ||
|
|
f83361895e | ||
|
|
0857d190ed | ||
|
|
5d431c0721 | ||
|
|
8fcf1be341 | ||
|
|
9377a9009c | ||
|
|
4471797edf | ||
|
|
425c67a08a | ||
|
|
88ca3e099a | ||
|
|
1e82811cc1 | ||
|
|
81b5522942 | ||
|
|
d539a6dfb9 | ||
|
|
ba12aae439 | ||
|
|
fdb78e08bd | ||
|
|
3a51db998a | ||
|
|
a52b011fb5 | ||
|
|
2514151a89 | ||
|
|
f265fd772d | ||
|
|
9ae9441de4 | ||
|
|
d9e7e72978 | ||
|
|
8ff0c548a7 | ||
|
|
f17420aa98 | ||
|
|
d424515542 | ||
|
|
ea5fc17c34 | ||
|
|
1a7dd935ee | ||
|
|
a7c2261b70 | ||
|
|
eca0bb7531 | ||
|
|
d249b32ee5 | ||
|
|
22045bc5e6 | ||
|
|
766c9df442 |
5
.gitignore
vendored
5
.gitignore
vendored
@@ -12,6 +12,11 @@ npm-debug.log*
|
|||||||
yarn-debug.log*
|
yarn-debug.log*
|
||||||
yarn-error.log*
|
yarn-error.log*
|
||||||
dev-debug.log
|
dev-debug.log
|
||||||
|
|
||||||
|
# Debug frame dump artifacts
|
||||||
|
android-frame-dumps/
|
||||||
|
wzp-frame-dumps.tar
|
||||||
|
|
||||||
# Dependency directories
|
# Dependency directories
|
||||||
node_modules/
|
node_modules/
|
||||||
# Environment variables
|
# Environment variables
|
||||||
|
|||||||
14
.gitleaks.toml
Normal file
14
.gitleaks.toml
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
[extend]
|
||||||
|
useDefault = true
|
||||||
|
|
||||||
|
[[allowlists]]
|
||||||
|
description = "Pre-existing historical findings already on fj/main and github/main. The two PASTE_AUTH tokens in scripts/build.sh and scripts/build-linux-notify.sh are real — rotate if those endpoints still authenticate; this allowlist only silences the pre-push hook, it does not remove the exposure."
|
||||||
|
commits = [
|
||||||
|
# wzp-crypto module doc: false positive on "SHA-256(Ed25519 pub)[:16]"
|
||||||
|
"51e893590c1b9fa49e9f6ae5c96c26deb58f353b",
|
||||||
|
# build.sh PASTE_AUTH (paste.tbs.amn.gg)
|
||||||
|
"bd6733b2e5d76b5259020f1c30a5223a9773b6aa",
|
||||||
|
# build-linux-notify Authorization header (paste.dk.manko.yoga)
|
||||||
|
"6d776097c83bc6fbe3f3565e080513d8af93b550",
|
||||||
|
"7751439e2bca9eacf2c30929c8124a4eb6136df2",
|
||||||
|
]
|
||||||
1523
Cargo.lock
generated
1523
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -11,6 +11,7 @@ members = [
|
|||||||
"crates/wzp-web",
|
"crates/wzp-web",
|
||||||
"crates/wzp-android",
|
"crates/wzp-android",
|
||||||
"crates/wzp-native",
|
"crates/wzp-native",
|
||||||
|
"crates/wzp-video",
|
||||||
"desktop/src-tauri",
|
"desktop/src-tauri",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
1
android.sh
Normal file
1
android.sh
Normal file
@@ -0,0 +1 @@
|
|||||||
|
./scripts/android-build-async.sh --init
|
||||||
@@ -28,6 +28,7 @@ libc = "0.2"
|
|||||||
jni = { version = "0.21", default-features = false }
|
jni = { version = "0.21", default-features = false }
|
||||||
rand = { workspace = true }
|
rand = { workspace = true }
|
||||||
rustls = { version = "0.23", default-features = false, features = ["ring"] }
|
rustls = { version = "0.23", default-features = false, features = ["ring"] }
|
||||||
|
[target.'cfg(target_os = "android")'.dependencies]
|
||||||
tracing-android = "0.2"
|
tracing-android = "0.2"
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
|
|||||||
@@ -65,9 +65,8 @@ fn main() {
|
|||||||
} else {
|
} else {
|
||||||
"aarch64-linux-android"
|
"aarch64-linux-android"
|
||||||
};
|
};
|
||||||
let lib_dir = format!(
|
let lib_dir =
|
||||||
"{ndk}/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/lib/{arch}"
|
format!("{ndk}/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/lib/{arch}");
|
||||||
);
|
|
||||||
println!("cargo:rustc-link-search=native={lib_dir}");
|
println!("cargo:rustc-link-search=native={lib_dir}");
|
||||||
|
|
||||||
// Copy libc++_shared.so to the jniLibs directory
|
// Copy libc++_shared.so to the jniLibs directory
|
||||||
@@ -82,9 +81,7 @@ fn main() {
|
|||||||
};
|
};
|
||||||
// Try to copy to the Gradle jniLibs directory
|
// Try to copy to the Gradle jniLibs directory
|
||||||
let manifest = std::env::var("CARGO_MANIFEST_DIR").unwrap_or_default();
|
let manifest = std::env::var("CARGO_MANIFEST_DIR").unwrap_or_default();
|
||||||
let jni_dir = format!(
|
let jni_dir = format!("{manifest}/../../android/app/src/main/jniLibs/{jni_abi}");
|
||||||
"{manifest}/../../android/app/src/main/jniLibs/{jni_abi}"
|
|
||||||
);
|
|
||||||
if let Ok(_) = std::fs::create_dir_all(&jni_dir) {
|
if let Ok(_) = std::fs::create_dir_all(&jni_dir) {
|
||||||
let _ = std::fs::copy(&shared_so, format!("{jni_dir}/libc++_shared.so"));
|
let _ = std::fs::copy(&shared_so, format!("{jni_dir}/libc++_shared.so"));
|
||||||
println!("cargo:warning=Copied libc++_shared.so to {jni_dir}");
|
println!("cargo:warning=Copied libc++_shared.so to {jni_dir}");
|
||||||
@@ -127,7 +124,12 @@ fn fetch_oboe() -> Option<PathBuf> {
|
|||||||
let out_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap());
|
let out_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap());
|
||||||
let oboe_dir = out_dir.join("oboe");
|
let oboe_dir = out_dir.join("oboe");
|
||||||
|
|
||||||
if oboe_dir.join("include").join("oboe").join("Oboe.h").exists() {
|
if oboe_dir
|
||||||
|
.join("include")
|
||||||
|
.join("oboe")
|
||||||
|
.join("Oboe.h")
|
||||||
|
.exists()
|
||||||
|
{
|
||||||
return Some(oboe_dir);
|
return Some(oboe_dir);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -143,7 +145,12 @@ fn fetch_oboe() -> Option<PathBuf> {
|
|||||||
|
|
||||||
match status {
|
match status {
|
||||||
Ok(s) if s.success() => {
|
Ok(s) if s.success() => {
|
||||||
if oboe_dir.join("include").join("oboe").join("Oboe.h").exists() {
|
if oboe_dir
|
||||||
|
.join("include")
|
||||||
|
.join("oboe")
|
||||||
|
.join("Oboe.h")
|
||||||
|
.exists()
|
||||||
|
{
|
||||||
Some(oboe_dir)
|
Some(oboe_dir)
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
|
|||||||
@@ -326,7 +326,10 @@ pub fn pin_to_big_core() {
|
|||||||
&set,
|
&set,
|
||||||
);
|
);
|
||||||
if ret != 0 {
|
if ret != 0 {
|
||||||
warn!("sched_setaffinity failed: {}", std::io::Error::last_os_error());
|
warn!(
|
||||||
|
"sched_setaffinity failed: {}",
|
||||||
|
std::io::Error::last_os_error()
|
||||||
|
);
|
||||||
} else {
|
} else {
|
||||||
info!(start, num_cpus, "pinned to big cores");
|
info!(start, num_cpus, "pinned to big cores");
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -77,7 +77,8 @@ impl AudioRing {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
self.write_pos.store(w.wrapping_add(count), Ordering::Release);
|
self.write_pos
|
||||||
|
.store(w.wrapping_add(count), Ordering::Release);
|
||||||
count
|
count
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -112,7 +113,8 @@ impl AudioRing {
|
|||||||
out[i] = unsafe { *self.buf.as_ptr().add((r + i) & RING_MASK) };
|
out[i] = unsafe { *self.buf.as_ptr().add((r + i) & RING_MASK) };
|
||||||
}
|
}
|
||||||
|
|
||||||
self.read_pos.store(r.wrapping_add(count), Ordering::Release);
|
self.read_pos
|
||||||
|
.store(r.wrapping_add(count), Ordering::Release);
|
||||||
count
|
count
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,8 @@ use wzp_crypto::{KeyExchange, WarzoneKeyExchange};
|
|||||||
use wzp_fec::{RaptorQFecDecoder, RaptorQFecEncoder};
|
use wzp_fec::{RaptorQFecDecoder, RaptorQFecEncoder};
|
||||||
use wzp_proto::{
|
use wzp_proto::{
|
||||||
AdaptiveQualityController, AudioDecoder, AudioEncoder, CodecId, FecDecoder, FecEncoder,
|
AdaptiveQualityController, AudioDecoder, AudioEncoder, CodecId, FecDecoder, FecEncoder,
|
||||||
MediaHeader, MediaPacket, MediaTransport, QualityController, QualityProfile, SignalMessage,
|
MediaHeader, MediaPacket, MediaTransport, MediaType, QualityController, QualityProfile,
|
||||||
|
SignalMessage, default_signal_version,
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::audio_ring::AudioRing;
|
use crate::audio_ring::AudioRing;
|
||||||
@@ -46,7 +47,11 @@ const PROFILES: [QualityProfile; 6] = [
|
|||||||
];
|
];
|
||||||
|
|
||||||
fn profile_to_index(p: &QualityProfile) -> u8 {
|
fn profile_to_index(p: &QualityProfile) -> u8 {
|
||||||
PROFILES.iter().position(|pp| pp.codec == p.codec).map(|i| i as u8).unwrap_or(3)
|
PROFILES
|
||||||
|
.iter()
|
||||||
|
.position(|pp| pp.codec == p.codec)
|
||||||
|
.map(|i| i as u8)
|
||||||
|
.unwrap_or(3)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn index_to_profile(idx: u8) -> Option<QualityProfile> {
|
fn index_to_profile(idx: u8) -> Option<QualityProfile> {
|
||||||
@@ -149,9 +154,10 @@ impl WzpEngine {
|
|||||||
.enable_all()
|
.enable_all()
|
||||||
.build()?;
|
.build()?;
|
||||||
|
|
||||||
let relay_addr: SocketAddr = config.relay_addr.parse().map_err(|e| {
|
let relay_addr: SocketAddr = config
|
||||||
anyhow::anyhow!("invalid relay address '{}': {e}", config.relay_addr)
|
.relay_addr
|
||||||
})?;
|
.parse()
|
||||||
|
.map_err(|e| anyhow::anyhow!("invalid relay address '{}': {e}", config.relay_addr))?;
|
||||||
|
|
||||||
let room = config.room.clone();
|
let room = config.room.clone();
|
||||||
let identity_seed = config.identity_seed;
|
let identity_seed = config.identity_seed;
|
||||||
@@ -165,7 +171,16 @@ impl WzpEngine {
|
|||||||
|
|
||||||
let state_clone = state.clone();
|
let state_clone = state.clone();
|
||||||
runtime.block_on(async move {
|
runtime.block_on(async move {
|
||||||
if let Err(e) = run_call(relay_addr, &room, &identity_seed, profile, auto_profile, alias.as_deref(), state_clone).await
|
if let Err(e) = run_call(
|
||||||
|
relay_addr,
|
||||||
|
&room,
|
||||||
|
&identity_seed,
|
||||||
|
profile,
|
||||||
|
auto_profile,
|
||||||
|
alias.as_deref(),
|
||||||
|
state_clone,
|
||||||
|
)
|
||||||
|
.await
|
||||||
{
|
{
|
||||||
error!("call failed: {e}");
|
error!("call failed: {e}");
|
||||||
}
|
}
|
||||||
@@ -233,16 +248,21 @@ impl WzpEngine {
|
|||||||
let server_fp = conn
|
let server_fp = conn
|
||||||
.peer_identity()
|
.peer_identity()
|
||||||
.and_then(|id| id.downcast::<Vec<rustls::pki_types::CertificateDer>>().ok())
|
.and_then(|id| id.downcast::<Vec<rustls::pki_types::CertificateDer>>().ok())
|
||||||
.and_then(|certs| certs.first().map(|c| {
|
.and_then(|certs| {
|
||||||
|
certs.first().map(|c| {
|
||||||
use std::hash::{Hash, Hasher};
|
use std::hash::{Hash, Hasher};
|
||||||
let mut h = std::collections::hash_map::DefaultHasher::new();
|
let mut h = std::collections::hash_map::DefaultHasher::new();
|
||||||
c.as_ref().hash(&mut h);
|
c.as_ref().hash(&mut h);
|
||||||
format!("{:016x}", h.finish())
|
format!("{:016x}", h.finish())
|
||||||
}))
|
})
|
||||||
|
})
|
||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
conn.close(0u32.into(), b"ping");
|
conn.close(0u32.into(), b"ping");
|
||||||
|
|
||||||
Ok::<_, anyhow::Error>(format!(r#"{{"rtt_ms":{},"server_fingerprint":"{}"}}"#, rtt_ms, server_fp))
|
Ok::<_, anyhow::Error>(format!(
|
||||||
|
r#"{{"rtt_ms":{},"server_fingerprint":"{}"}}"#,
|
||||||
|
rtt_ms, server_fp
|
||||||
|
))
|
||||||
});
|
});
|
||||||
|
|
||||||
// Shutdown runtime cleanly with timeout
|
// Shutdown runtime cleanly with timeout
|
||||||
@@ -301,11 +321,12 @@ impl WzpEngine {
|
|||||||
|
|
||||||
// Auth if token provided
|
// Auth if token provided
|
||||||
if let Some(ref tok) = token {
|
if let Some(ref tok) = token {
|
||||||
let _ = transport.send_signal(&SignalMessage::AuthToken { token: tok.clone() }).await;
|
let _ = transport.send_signal(&SignalMessage::AuthToken { version: default_signal_version(), token: tok.clone() }).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register presence
|
// Register presence
|
||||||
let _ = transport.send_signal(&SignalMessage::RegisterPresence {
|
let _ = transport.send_signal(&SignalMessage::RegisterPresence {
|
||||||
|
version: default_signal_version(),
|
||||||
identity_pub,
|
identity_pub,
|
||||||
signature: vec![],
|
signature: vec![],
|
||||||
alias: alias.clone(),
|
alias: alias.clone(),
|
||||||
@@ -330,7 +351,7 @@ impl WzpEngine {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
match transport.recv_signal().await {
|
match transport.recv_signal().await {
|
||||||
Ok(Some(SignalMessage::CallRinging { call_id })) => {
|
Ok(Some(SignalMessage::CallRinging { call_id, ..})) => {
|
||||||
info!(call_id = %call_id, "signal: ringing");
|
info!(call_id = %call_id, "signal: ringing");
|
||||||
let mut stats = signal_state.stats.lock().unwrap();
|
let mut stats = signal_state.stats.lock().unwrap();
|
||||||
stats.state = crate::stats::CallState::Ringing;
|
stats.state = crate::stats::CallState::Ringing;
|
||||||
@@ -392,7 +413,11 @@ impl WzpEngine {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Answer an incoming direct call.
|
/// Answer an incoming direct call.
|
||||||
pub fn answer_call(&self, call_id: &str, mode: wzp_proto::CallAcceptMode) -> Result<(), anyhow::Error> {
|
pub fn answer_call(
|
||||||
|
&self,
|
||||||
|
call_id: &str,
|
||||||
|
mode: wzp_proto::CallAcceptMode,
|
||||||
|
) -> Result<(), anyhow::Error> {
|
||||||
let _ = self.state.command_tx.send(EngineCommand::AnswerCall {
|
let _ = self.state.command_tx.send(EngineCommand::AnswerCall {
|
||||||
call_id: call_id.to_string(),
|
call_id: call_id.to_string(),
|
||||||
accept_mode: mode,
|
accept_mode: mode,
|
||||||
@@ -412,7 +437,9 @@ impl WzpEngine {
|
|||||||
/// Stores the type atomically; the recv task polls it on each packet.
|
/// Stores the type atomically; the recv task polls it on each packet.
|
||||||
pub fn on_network_changed(&self, network_type: u8, bandwidth_kbps: u32) {
|
pub fn on_network_changed(&self, network_type: u8, bandwidth_kbps: u32) {
|
||||||
info!(network_type, bandwidth_kbps, "on_network_changed");
|
info!(network_type, bandwidth_kbps, "on_network_changed");
|
||||||
self.state.pending_network_type.store(network_type, Ordering::Release);
|
self.state
|
||||||
|
.pending_network_type
|
||||||
|
.store(network_type, Ordering::Release);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_stats(&self) -> CallStats {
|
pub fn get_stats(&self) -> CallStats {
|
||||||
@@ -496,6 +523,7 @@ async fn run_call(
|
|||||||
let signature = kx.sign(&sign_data);
|
let signature = kx.sign(&sign_data);
|
||||||
|
|
||||||
let offer = SignalMessage::CallOffer {
|
let offer = SignalMessage::CallOffer {
|
||||||
|
version: default_signal_version(),
|
||||||
identity_pub,
|
identity_pub,
|
||||||
ephemeral_pub,
|
ephemeral_pub,
|
||||||
signature,
|
signature,
|
||||||
@@ -508,6 +536,9 @@ async fn run_call(
|
|||||||
QualityProfile::CATASTROPHIC,
|
QualityProfile::CATASTROPHIC,
|
||||||
],
|
],
|
||||||
alias: alias.map(|s| s.to_string()),
|
alias: alias.map(|s| s.to_string()),
|
||||||
|
protocol_version: 2,
|
||||||
|
supported_versions: vec![2],
|
||||||
|
video_codecs: vec![CodecId::H264Baseline],
|
||||||
};
|
};
|
||||||
transport.send_signal(&offer).await?;
|
transport.send_signal(&offer).await?;
|
||||||
info!("CallOffer sent, waiting for CallAnswer...");
|
info!("CallOffer sent, waiting for CallAnswer...");
|
||||||
@@ -518,12 +549,16 @@ async fn run_call(
|
|||||||
.ok_or_else(|| anyhow::anyhow!("connection closed before CallAnswer"))?;
|
.ok_or_else(|| anyhow::anyhow!("connection closed before CallAnswer"))?;
|
||||||
|
|
||||||
let (relay_ephemeral_pub, chosen_profile) = match answer {
|
let (relay_ephemeral_pub, chosen_profile) = match answer {
|
||||||
SignalMessage::CallAnswer { ephemeral_pub, chosen_profile, .. } => (ephemeral_pub, chosen_profile),
|
SignalMessage::CallAnswer {
|
||||||
|
ephemeral_pub,
|
||||||
|
chosen_profile,
|
||||||
|
..
|
||||||
|
} => (ephemeral_pub, chosen_profile),
|
||||||
other => {
|
other => {
|
||||||
return Err(anyhow::anyhow!(
|
return Err(anyhow::anyhow!(
|
||||||
"expected CallAnswer, got {:?}",
|
"expected CallAnswer, got {:?}",
|
||||||
std::mem::discriminant(&other)
|
std::mem::discriminant(&other)
|
||||||
))
|
));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -574,7 +609,7 @@ async fn run_call(
|
|||||||
stats.auto_mode = auto_profile;
|
stats.auto_mode = auto_profile;
|
||||||
}
|
}
|
||||||
|
|
||||||
let seq = AtomicU16::new(0);
|
let seq = AtomicU32::new(0);
|
||||||
let ts = AtomicU32::new(0);
|
let ts = AtomicU32::new(0);
|
||||||
let transport_recv = transport.clone();
|
let transport_recv = transport.clone();
|
||||||
|
|
||||||
@@ -700,17 +735,15 @@ async fn run_call(
|
|||||||
|
|
||||||
let source_pkt = MediaPacket {
|
let source_pkt = MediaPacket {
|
||||||
header: MediaHeader {
|
header: MediaHeader {
|
||||||
version: 0,
|
version: MediaHeader::VERSION,
|
||||||
is_repair: false,
|
flags: 0,
|
||||||
|
media_type: MediaType::Audio,
|
||||||
codec_id: current_profile.codec,
|
codec_id: current_profile.codec,
|
||||||
has_quality_report: false,
|
stream_id: 0,
|
||||||
fec_ratio_encoded: hdr_fec_ratio,
|
fec_ratio: hdr_fec_ratio,
|
||||||
seq: s,
|
seq: s,
|
||||||
timestamp: t,
|
timestamp: t,
|
||||||
fec_block: hdr_fec_block,
|
fec_block: ((hdr_fec_symbol as u16) << 8) | (hdr_fec_block as u16),
|
||||||
fec_symbol: hdr_fec_symbol,
|
|
||||||
reserved: 0,
|
|
||||||
csrc_count: 0,
|
|
||||||
},
|
},
|
||||||
payload: Bytes::copy_from_slice(encoded),
|
payload: Bytes::copy_from_slice(encoded),
|
||||||
quality_report: None,
|
quality_report: None,
|
||||||
@@ -725,9 +758,7 @@ async fn run_call(
|
|||||||
if send_errors <= 3 || last_send_error_log.elapsed().as_secs() >= 1 {
|
if send_errors <= 3 || last_send_error_log.elapsed().as_secs() >= 1 {
|
||||||
warn!(
|
warn!(
|
||||||
seq = s,
|
seq = s,
|
||||||
send_errors,
|
send_errors, frames_dropped, "send_media error (dropping packet): {e}"
|
||||||
frames_dropped,
|
|
||||||
"send_media error (dropping packet): {e}"
|
|
||||||
);
|
);
|
||||||
last_send_error_log = Instant::now();
|
last_send_error_log = Instant::now();
|
||||||
}
|
}
|
||||||
@@ -756,19 +787,17 @@ async fn run_call(
|
|||||||
let rs = seq.fetch_add(1, Ordering::Relaxed);
|
let rs = seq.fetch_add(1, Ordering::Relaxed);
|
||||||
let repair_pkt = MediaPacket {
|
let repair_pkt = MediaPacket {
|
||||||
header: MediaHeader {
|
header: MediaHeader {
|
||||||
version: 0,
|
version: MediaHeader::VERSION,
|
||||||
is_repair: true,
|
flags: MediaHeader::FLAG_REPAIR,
|
||||||
|
media_type: MediaType::Audio,
|
||||||
codec_id: current_profile.codec,
|
codec_id: current_profile.codec,
|
||||||
has_quality_report: false,
|
stream_id: 0,
|
||||||
fec_ratio_encoded: MediaHeader::encode_fec_ratio(
|
fec_ratio: MediaHeader::encode_fec_ratio(
|
||||||
current_profile.fec_ratio,
|
current_profile.fec_ratio,
|
||||||
),
|
),
|
||||||
seq: rs,
|
seq: rs,
|
||||||
timestamp: t,
|
timestamp: t,
|
||||||
fec_block: block_id,
|
fec_block: (sym_idx << 8) | (block_id as u16),
|
||||||
fec_symbol: sym_idx,
|
|
||||||
reserved: 0,
|
|
||||||
csrc_count: 0,
|
|
||||||
},
|
},
|
||||||
payload: Bytes::from(repair_data),
|
payload: Bytes::from(repair_data),
|
||||||
quality_report: None,
|
quality_report: None,
|
||||||
@@ -820,7 +849,11 @@ async fn run_call(
|
|||||||
avg_total_us = avg(t_agc_us + t_opus_us + t_fec_us + t_send_us),
|
avg_total_us = avg(t_agc_us + t_opus_us + t_fec_us + t_send_us),
|
||||||
"send stats"
|
"send stats"
|
||||||
);
|
);
|
||||||
t_agc_us = 0; t_opus_us = 0; t_fec_us = 0; t_send_us = 0; t_frames = 0;
|
t_agc_us = 0;
|
||||||
|
t_opus_us = 0;
|
||||||
|
t_fec_us = 0;
|
||||||
|
t_send_us = 0;
|
||||||
|
t_frames = 0;
|
||||||
last_stats_log = Instant::now();
|
last_stats_log = Instant::now();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -849,14 +882,11 @@ async fn run_call(
|
|||||||
// when a packet arrives with seq > expected_seq, the frames in
|
// when a packet arrives with seq > expected_seq, the frames in
|
||||||
// between are missing and we attempt to reconstruct them via
|
// between are missing and we attempt to reconstruct them via
|
||||||
// DRED before decoding the newly-arrived packet.
|
// DRED before decoding the newly-arrived packet.
|
||||||
let mut dred_decoder =
|
let mut dred_decoder = DredDecoderHandle::new().expect("opus_dred_decoder_create failed");
|
||||||
DredDecoderHandle::new().expect("opus_dred_decoder_create failed");
|
let mut dred_parse_scratch = DredState::new().expect("opus_dred_alloc failed (scratch)");
|
||||||
let mut dred_parse_scratch =
|
let mut last_good_dred = DredState::new().expect("opus_dred_alloc failed (good state)");
|
||||||
DredState::new().expect("opus_dred_alloc failed (scratch)");
|
let mut last_good_dred_seq: Option<u32> = None;
|
||||||
let mut last_good_dred =
|
let mut expected_seq: Option<u32> = None;
|
||||||
DredState::new().expect("opus_dred_alloc failed (good state)");
|
|
||||||
let mut last_good_dred_seq: Option<u16> = None;
|
|
||||||
let mut expected_seq: Option<u16> = None;
|
|
||||||
let mut dred_reconstructions: u64 = 0;
|
let mut dred_reconstructions: u64 = 0;
|
||||||
let mut classical_plc_invocations: u64 = 0;
|
let mut classical_plc_invocations: u64 = 0;
|
||||||
|
|
||||||
@@ -877,14 +907,16 @@ async fn run_call(
|
|||||||
warn!(
|
warn!(
|
||||||
recv_gap_ms,
|
recv_gap_ms,
|
||||||
seq = pkt.header.seq,
|
seq = pkt.header.seq,
|
||||||
is_repair = pkt.header.is_repair,
|
is_repair = pkt.header.is_repair(),
|
||||||
"large recv gap — possible network stall"
|
"large recv gap — possible network stall"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for network transport change from ConnectivityManager
|
// Check for network transport change from ConnectivityManager
|
||||||
{
|
{
|
||||||
let net = state.pending_network_type.swap(PROFILE_NO_CHANGE, Ordering::Acquire);
|
let net = state
|
||||||
|
.pending_network_type
|
||||||
|
.swap(PROFILE_NO_CHANGE, Ordering::Acquire);
|
||||||
if net != PROFILE_NO_CHANGE {
|
if net != PROFILE_NO_CHANGE {
|
||||||
use wzp_proto::NetworkContext;
|
use wzp_proto::NetworkContext;
|
||||||
let ctx = match net {
|
let ctx = match net {
|
||||||
@@ -916,9 +948,9 @@ async fn run_call(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let is_repair = pkt.header.is_repair;
|
let is_repair = pkt.header.is_repair();
|
||||||
let pkt_block = pkt.header.fec_block;
|
let pkt_block = pkt.header.fec_block;
|
||||||
let pkt_symbol = pkt.header.fec_symbol;
|
let pkt_symbol = (pkt.header.fec_block >> 8) as u16;
|
||||||
let pkt_is_opus = pkt.header.codec_id.is_opus();
|
let pkt_is_opus = pkt.header.codec_id.is_opus();
|
||||||
|
|
||||||
// Phase 2: Opus packets bypass RaptorQ entirely — DRED
|
// Phase 2: Opus packets bypass RaptorQ entirely — DRED
|
||||||
@@ -927,12 +959,7 @@ async fn run_call(
|
|||||||
// would accumulate block_id=0 duplicates that never
|
// would accumulate block_id=0 duplicates that never
|
||||||
// decode. Codec2 packets still feed RaptorQ.
|
// decode. Codec2 packets still feed RaptorQ.
|
||||||
if !pkt_is_opus {
|
if !pkt_is_opus {
|
||||||
let _ = fec_dec.add_symbol(
|
let _ = fec_dec.add_symbol(pkt_block, pkt_symbol, is_repair, &pkt.payload);
|
||||||
pkt_block,
|
|
||||||
pkt_symbol,
|
|
||||||
is_repair,
|
|
||||||
&pkt.payload,
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Source packets: decode directly
|
// Source packets: decode directly
|
||||||
@@ -951,8 +978,12 @@ async fn run_call(
|
|||||||
fec_ratio: 0.5,
|
fec_ratio: 0.5,
|
||||||
frame_duration_ms: 20,
|
frame_duration_ms: 20,
|
||||||
frames_per_block: 5,
|
frames_per_block: 5,
|
||||||
|
..QualityProfile::GOOD
|
||||||
|
},
|
||||||
|
other => QualityProfile {
|
||||||
|
codec: other,
|
||||||
|
..QualityProfile::GOOD
|
||||||
},
|
},
|
||||||
other => QualityProfile { codec: other, ..QualityProfile::GOOD },
|
|
||||||
};
|
};
|
||||||
info!(from = ?decoder.codec_id(), to = ?pkt.header.codec_id, "recv: switching decoder");
|
info!(from = ?decoder.codec_id(), to = ?pkt.header.codec_id, "recv: switching decoder");
|
||||||
let _ = decoder.set_profile(switch_profile);
|
let _ = decoder.set_profile(switch_profile);
|
||||||
@@ -984,10 +1015,7 @@ async fn run_call(
|
|||||||
// Update DRED state from the current packet.
|
// Update DRED state from the current packet.
|
||||||
match dred_decoder.parse_into(&mut dred_parse_scratch, &pkt.payload) {
|
match dred_decoder.parse_into(&mut dred_parse_scratch, &pkt.payload) {
|
||||||
Ok(available) if available > 0 => {
|
Ok(available) if available > 0 => {
|
||||||
std::mem::swap(
|
std::mem::swap(&mut dred_parse_scratch, &mut last_good_dred);
|
||||||
&mut dred_parse_scratch,
|
|
||||||
&mut last_good_dred,
|
|
||||||
);
|
|
||||||
last_good_dred_seq = Some(pkt.header.seq);
|
last_good_dred_seq = Some(pkt.header.seq);
|
||||||
}
|
}
|
||||||
Ok(_) => {
|
Ok(_) => {
|
||||||
@@ -999,15 +1027,14 @@ async fn run_call(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Detect and fill gap from last-expected to this packet.
|
// Detect and fill gap from last-expected to this packet.
|
||||||
const MAX_GAP_FRAMES: u16 = 16;
|
const MAX_GAP_FRAMES: u32 = 16;
|
||||||
if let Some(expected) = expected_seq {
|
if let Some(expected) = expected_seq {
|
||||||
let gap = pkt.header.seq.wrapping_sub(expected);
|
let gap = pkt.header.seq.wrapping_sub(expected);
|
||||||
if gap > 0 && gap <= MAX_GAP_FRAMES {
|
if gap > 0 && gap <= MAX_GAP_FRAMES {
|
||||||
let current_profile_frame_samples =
|
let current_profile_frame_samples =
|
||||||
(48_000 * profile.frame_duration_ms as i32) / 1000;
|
(48_000 * profile.frame_duration_ms as i32) / 1000;
|
||||||
let available = last_good_dred.samples_available();
|
let available = last_good_dred.samples_available();
|
||||||
let pcm_slice_len =
|
let pcm_slice_len = current_profile_frame_samples as usize;
|
||||||
current_profile_frame_samples as usize;
|
|
||||||
|
|
||||||
for gap_idx in 0..gap {
|
for gap_idx in 0..gap {
|
||||||
let missing_seq = expected.wrapping_add(gap_idx);
|
let missing_seq = expected.wrapping_add(gap_idx);
|
||||||
@@ -1026,9 +1053,8 @@ async fn run_call(
|
|||||||
None => -1,
|
None => -1,
|
||||||
};
|
};
|
||||||
|
|
||||||
let reconstructed = if offset_samples > 0
|
let reconstructed =
|
||||||
&& offset_samples <= available
|
if offset_samples > 0 && offset_samples <= available {
|
||||||
{
|
|
||||||
decoder
|
decoder
|
||||||
.reconstruct_from_dred(
|
.reconstruct_from_dred(
|
||||||
&last_good_dred,
|
&last_good_dred,
|
||||||
@@ -1042,12 +1068,9 @@ async fn run_call(
|
|||||||
|
|
||||||
match reconstructed {
|
match reconstructed {
|
||||||
Some(samples) => {
|
Some(samples) => {
|
||||||
playout_agc.process_frame(
|
playout_agc
|
||||||
&mut decode_buf[..samples],
|
.process_frame(&mut decode_buf[..samples]);
|
||||||
);
|
state.playout_ring.write(&decode_buf[..samples]);
|
||||||
state
|
|
||||||
.playout_ring
|
|
||||||
.write(&decode_buf[..samples]);
|
|
||||||
dred_reconstructions += 1;
|
dred_reconstructions += 1;
|
||||||
frames_decoded += 1;
|
frames_decoded += 1;
|
||||||
}
|
}
|
||||||
@@ -1144,7 +1167,10 @@ async fn run_call(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(None) => {
|
Ok(None) => {
|
||||||
info!(frames_decoded, fec_recovered, "relay disconnected (stream ended)");
|
info!(
|
||||||
|
frames_decoded,
|
||||||
|
fec_recovered, "relay disconnected (stream ended)"
|
||||||
|
);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
@@ -1162,7 +1188,10 @@ async fn run_call(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
info!(frames_decoded, fec_recovered, recv_errors, "recv task ended");
|
info!(
|
||||||
|
frames_decoded,
|
||||||
|
fec_recovered, recv_errors, "recv task ended"
|
||||||
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
// Stats task — polls path quality + quinn RTT every 500ms
|
// Stats task — polls path quality + quinn RTT every 500ms
|
||||||
@@ -1195,7 +1224,11 @@ async fn run_call(
|
|||||||
let signal_task = async {
|
let signal_task = async {
|
||||||
loop {
|
loop {
|
||||||
match transport_signal.recv_signal().await {
|
match transport_signal.recv_signal().await {
|
||||||
Ok(Some(SignalMessage::RoomUpdate { count, participants })) => {
|
Ok(Some(SignalMessage::RoomUpdate {
|
||||||
|
count,
|
||||||
|
participants,
|
||||||
|
..
|
||||||
|
})) => {
|
||||||
info!(count, "RoomUpdate received");
|
info!(count, "RoomUpdate received");
|
||||||
let members: Vec<crate::stats::RoomMember> = participants
|
let members: Vec<crate::stats::RoomMember> = participants
|
||||||
.iter()
|
.iter()
|
||||||
@@ -1209,6 +1242,19 @@ async fn run_call(
|
|||||||
stats.room_participant_count = count;
|
stats.room_participant_count = count;
|
||||||
stats.room_participants = members;
|
stats.room_participants = members;
|
||||||
}
|
}
|
||||||
|
Ok(Some(SignalMessage::QualityDirective {
|
||||||
|
recommended_profile,
|
||||||
|
reason,
|
||||||
|
..
|
||||||
|
})) => {
|
||||||
|
let idx = profile_to_index(&recommended_profile);
|
||||||
|
info!(
|
||||||
|
codec = ?recommended_profile.codec,
|
||||||
|
reason = reason.as_deref().unwrap_or(""),
|
||||||
|
"relay quality directive: switching profile"
|
||||||
|
);
|
||||||
|
pending_profile_recv.store(idx, Ordering::Release);
|
||||||
|
}
|
||||||
Ok(Some(msg)) => {
|
Ok(Some(msg)) => {
|
||||||
info!("signal received: {:?}", std::mem::discriminant(&msg));
|
info!("signal received: {:?}", std::mem::discriminant(&msg));
|
||||||
}
|
}
|
||||||
@@ -1238,7 +1284,9 @@ async fn run_call(
|
|||||||
match tokio::time::timeout(
|
match tokio::time::timeout(
|
||||||
std::time::Duration::from_millis(500),
|
std::time::Duration::from_millis(500),
|
||||||
transport.connection().closed(),
|
transport.connection().closed(),
|
||||||
).await {
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
Ok(_) => info!("QUIC connection closed cleanly"),
|
Ok(_) => info!("QUIC connection closed cleanly"),
|
||||||
Err(_) => info!("QUIC close timed out (relay may not have ack'd)"),
|
Err(_) => info!("QUIC close timed out (relay may not have ack'd)"),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,9 +3,9 @@
|
|||||||
use std::panic;
|
use std::panic;
|
||||||
use std::sync::Once;
|
use std::sync::Once;
|
||||||
|
|
||||||
|
use jni::JNIEnv;
|
||||||
use jni::objects::{JClass, JObject, JString};
|
use jni::objects::{JClass, JObject, JString};
|
||||||
use jni::sys::{jboolean, jint, jlong, jstring};
|
use jni::sys::{jboolean, jint, jlong, jstring};
|
||||||
use jni::JNIEnv;
|
|
||||||
use tracing::{error, info};
|
use tracing::{error, info};
|
||||||
use wzp_proto::QualityProfile;
|
use wzp_proto::QualityProfile;
|
||||||
|
|
||||||
@@ -29,11 +29,13 @@ fn profile_from_int(value: jint) -> QualityProfile {
|
|||||||
0 => QualityProfile::GOOD, // Opus 24k
|
0 => QualityProfile::GOOD, // Opus 24k
|
||||||
1 => QualityProfile::DEGRADED, // Opus 6k
|
1 => QualityProfile::DEGRADED, // Opus 6k
|
||||||
2 => QualityProfile::CATASTROPHIC, // Codec2 1.2k
|
2 => QualityProfile::CATASTROPHIC, // Codec2 1.2k
|
||||||
3 => QualityProfile { // Codec2 3.2k
|
3 => QualityProfile {
|
||||||
|
// Codec2 3.2k
|
||||||
codec: wzp_proto::CodecId::Codec2_3200,
|
codec: wzp_proto::CodecId::Codec2_3200,
|
||||||
fec_ratio: 0.5,
|
fec_ratio: 0.5,
|
||||||
frame_duration_ms: 20,
|
frame_duration_ms: 20,
|
||||||
frames_per_block: 5,
|
frames_per_block: 5,
|
||||||
|
..QualityProfile::GOOD
|
||||||
},
|
},
|
||||||
4 => QualityProfile::STUDIO_32K, // Opus 32k
|
4 => QualityProfile::STUDIO_32K, // Opus 32k
|
||||||
5 => QualityProfile::STUDIO_48K, // Opus 48k
|
5 => QualityProfile::STUDIO_48K, // Opus 48k
|
||||||
@@ -48,6 +50,8 @@ static INIT_LOGGING: Once = Once::new();
|
|||||||
/// Safe to call multiple times — only the first call takes effect.
|
/// Safe to call multiple times — only the first call takes effect.
|
||||||
fn init_logging() {
|
fn init_logging() {
|
||||||
INIT_LOGGING.call_once(|| {
|
INIT_LOGGING.call_once(|| {
|
||||||
|
#[cfg(target_os = "android")]
|
||||||
|
{
|
||||||
// Wrap in catch_unwind — sharded_slab allocation inside
|
// Wrap in catch_unwind — sharded_slab allocation inside
|
||||||
// tracing_subscriber::registry() can crash on some Android
|
// tracing_subscriber::registry() can crash on some Android
|
||||||
// devices if scudo malloc fails during early initialization.
|
// devices if scudo malloc fails during early initialization.
|
||||||
@@ -67,6 +71,12 @@ fn init_logging() {
|
|||||||
.try_init();
|
.try_init();
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
}
|
||||||
|
#[cfg(not(target_os = "android"))]
|
||||||
|
{
|
||||||
|
// On non-Android targets tracing-android is unavailable.
|
||||||
|
let _ = tracing_subscriber::fmt::try_init();
|
||||||
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -101,11 +111,26 @@ pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeStartCall(
|
|||||||
profile_j: jint,
|
profile_j: jint,
|
||||||
) -> jint {
|
) -> jint {
|
||||||
let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
|
let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
|
||||||
let relay_addr: String = env.get_string(&relay_addr_j).map(|s| s.into()).unwrap_or_default();
|
let relay_addr: String = env
|
||||||
let room: String = env.get_string(&room_j).map(|s| s.into()).unwrap_or_default();
|
.get_string(&relay_addr_j)
|
||||||
let seed_hex: String = env.get_string(&seed_hex_j).map(|s| s.into()).unwrap_or_default();
|
.map(|s| s.into())
|
||||||
let token: String = env.get_string(&token_j).map(|s| s.into()).unwrap_or_default();
|
.unwrap_or_default();
|
||||||
let alias: String = env.get_string(&alias_j).map(|s| s.into()).unwrap_or_default();
|
let room: String = env
|
||||||
|
.get_string(&room_j)
|
||||||
|
.map(|s| s.into())
|
||||||
|
.unwrap_or_default();
|
||||||
|
let seed_hex: String = env
|
||||||
|
.get_string(&seed_hex_j)
|
||||||
|
.map(|s| s.into())
|
||||||
|
.unwrap_or_default();
|
||||||
|
let token: String = env
|
||||||
|
.get_string(&token_j)
|
||||||
|
.map(|s| s.into())
|
||||||
|
.unwrap_or_default();
|
||||||
|
let alias: String = env
|
||||||
|
.get_string(&alias_j)
|
||||||
|
.map(|s| s.into())
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
let h = unsafe { handle_ref(handle) };
|
let h = unsafe { handle_ref(handle) };
|
||||||
|
|
||||||
@@ -128,7 +153,11 @@ pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeStartCall(
|
|||||||
auto_profile: profile_j == PROFILE_AUTO,
|
auto_profile: profile_j == PROFILE_AUTO,
|
||||||
relay_addr,
|
relay_addr,
|
||||||
room,
|
room,
|
||||||
auth_token: if token.is_empty() { Vec::new() } else { token.into_bytes() },
|
auth_token: if token.is_empty() {
|
||||||
|
Vec::new()
|
||||||
|
} else {
|
||||||
|
token.into_bytes()
|
||||||
|
},
|
||||||
identity_seed,
|
identity_seed,
|
||||||
alias: if alias.is_empty() { None } else { Some(alias) },
|
alias: if alias.is_empty() { None } else { Some(alias) },
|
||||||
};
|
};
|
||||||
@@ -241,7 +270,8 @@ pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeOnNetworkChang
|
|||||||
) {
|
) {
|
||||||
let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| {
|
let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| {
|
||||||
let h = unsafe { handle_ref(handle) };
|
let h = unsafe { handle_ref(handle) };
|
||||||
h.engine.on_network_changed(network_type as u8, bandwidth_kbps as u32);
|
h.engine
|
||||||
|
.on_network_changed(network_type as u8, bandwidth_kbps as u32);
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -307,13 +337,14 @@ pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeWriteAudioDire
|
|||||||
) -> jint {
|
) -> jint {
|
||||||
let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
|
let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
|
||||||
let h = unsafe { handle_ref(handle) };
|
let h = unsafe { handle_ref(handle) };
|
||||||
let ptr = env.get_direct_buffer_address(&buffer).unwrap_or(std::ptr::null_mut());
|
let ptr = env
|
||||||
|
.get_direct_buffer_address(&buffer)
|
||||||
|
.unwrap_or(std::ptr::null_mut());
|
||||||
if ptr.is_null() || sample_count <= 0 {
|
if ptr.is_null() || sample_count <= 0 {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
let samples = unsafe {
|
let samples =
|
||||||
std::slice::from_raw_parts(ptr as *const i16, sample_count as usize)
|
unsafe { std::slice::from_raw_parts(ptr as *const i16, sample_count as usize) };
|
||||||
};
|
|
||||||
h.engine.write_audio(samples) as jint
|
h.engine.write_audio(samples) as jint
|
||||||
}));
|
}));
|
||||||
result.unwrap_or(0)
|
result.unwrap_or(0)
|
||||||
@@ -332,13 +363,14 @@ pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeReadAudioDirec
|
|||||||
) -> jint {
|
) -> jint {
|
||||||
let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
|
let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
|
||||||
let h = unsafe { handle_ref(handle) };
|
let h = unsafe { handle_ref(handle) };
|
||||||
let ptr = env.get_direct_buffer_address(&buffer).unwrap_or(std::ptr::null_mut());
|
let ptr = env
|
||||||
|
.get_direct_buffer_address(&buffer)
|
||||||
|
.unwrap_or(std::ptr::null_mut());
|
||||||
if ptr.is_null() || max_samples <= 0 {
|
if ptr.is_null() || max_samples <= 0 {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
let samples = unsafe {
|
let samples =
|
||||||
std::slice::from_raw_parts_mut(ptr as *mut i16, max_samples as usize)
|
unsafe { std::slice::from_raw_parts_mut(ptr as *mut i16, max_samples as usize) };
|
||||||
};
|
|
||||||
h.engine.read_audio(samples) as jint
|
h.engine.read_audio(samples) as jint
|
||||||
}));
|
}));
|
||||||
result.unwrap_or(0)
|
result.unwrap_or(0)
|
||||||
@@ -367,7 +399,10 @@ pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativePingRelay<'a>(
|
|||||||
) -> jstring {
|
) -> jstring {
|
||||||
let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
|
let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
|
||||||
let h = unsafe { handle_ref(handle) };
|
let h = unsafe { handle_ref(handle) };
|
||||||
let relay: String = env.get_string(&relay_j).map(|s| s.into()).unwrap_or_default();
|
let relay: String = env
|
||||||
|
.get_string(&relay_j)
|
||||||
|
.map(|s| s.into())
|
||||||
|
.unwrap_or_default();
|
||||||
match h.engine.ping_relay(&relay) {
|
match h.engine.ping_relay(&relay) {
|
||||||
Ok(json) => Some(json),
|
Ok(json) => Some(json),
|
||||||
Err(_) => None,
|
Err(_) => None,
|
||||||
@@ -399,10 +434,22 @@ pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeStartSignaling
|
|||||||
) -> jint {
|
) -> jint {
|
||||||
let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
|
let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
|
||||||
let h = unsafe { handle_ref(handle) };
|
let h = unsafe { handle_ref(handle) };
|
||||||
let relay_addr: String = env.get_string(&relay_addr_j).map(|s| s.into()).unwrap_or_default();
|
let relay_addr: String = env
|
||||||
let seed_hex: String = env.get_string(&seed_hex_j).map(|s| s.into()).unwrap_or_default();
|
.get_string(&relay_addr_j)
|
||||||
let token: String = env.get_string(&token_j).map(|s| s.into()).unwrap_or_default();
|
.map(|s| s.into())
|
||||||
let alias: String = env.get_string(&alias_j).map(|s| s.into()).unwrap_or_default();
|
.unwrap_or_default();
|
||||||
|
let seed_hex: String = env
|
||||||
|
.get_string(&seed_hex_j)
|
||||||
|
.map(|s| s.into())
|
||||||
|
.unwrap_or_default();
|
||||||
|
let token: String = env
|
||||||
|
.get_string(&token_j)
|
||||||
|
.map(|s| s.into())
|
||||||
|
.unwrap_or_default();
|
||||||
|
let alias: String = env
|
||||||
|
.get_string(&alias_j)
|
||||||
|
.map(|s| s.into())
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
h.engine.start_signaling(
|
h.engine.start_signaling(
|
||||||
&relay_addr,
|
&relay_addr,
|
||||||
@@ -414,8 +461,14 @@ pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeStartSignaling
|
|||||||
|
|
||||||
match result {
|
match result {
|
||||||
Ok(Ok(())) => 0,
|
Ok(Ok(())) => 0,
|
||||||
Ok(Err(e)) => { error!("start_signaling failed: {e}"); -1 }
|
Ok(Err(e)) => {
|
||||||
Err(_) => { error!("start_signaling panicked"); -1 }
|
error!("start_signaling failed: {e}");
|
||||||
|
-1
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
error!("start_signaling panicked");
|
||||||
|
-1
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -430,14 +483,23 @@ pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativePlaceCall<'a>(
|
|||||||
) -> jint {
|
) -> jint {
|
||||||
let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
|
let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
|
||||||
let h = unsafe { handle_ref(handle) };
|
let h = unsafe { handle_ref(handle) };
|
||||||
let target: String = env.get_string(&target_fp_j).map(|s| s.into()).unwrap_or_default();
|
let target: String = env
|
||||||
|
.get_string(&target_fp_j)
|
||||||
|
.map(|s| s.into())
|
||||||
|
.unwrap_or_default();
|
||||||
h.engine.place_call(&target)
|
h.engine.place_call(&target)
|
||||||
}));
|
}));
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
Ok(Ok(())) => 0,
|
Ok(Ok(())) => 0,
|
||||||
Ok(Err(e)) => { error!("place_call failed: {e}"); -1 }
|
Ok(Err(e)) => {
|
||||||
Err(_) => { error!("place_call panicked"); -1 }
|
error!("place_call failed: {e}");
|
||||||
|
-1
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
error!("place_call panicked");
|
||||||
|
-1
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -453,7 +515,10 @@ pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeAnswerCall<'a>
|
|||||||
) -> jint {
|
) -> jint {
|
||||||
let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
|
let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
|
||||||
let h = unsafe { handle_ref(handle) };
|
let h = unsafe { handle_ref(handle) };
|
||||||
let call_id: String = env.get_string(&call_id_j).map(|s| s.into()).unwrap_or_default();
|
let call_id: String = env
|
||||||
|
.get_string(&call_id_j)
|
||||||
|
.map(|s| s.into())
|
||||||
|
.unwrap_or_default();
|
||||||
let accept_mode = match mode {
|
let accept_mode = match mode {
|
||||||
0 => wzp_proto::CallAcceptMode::Reject,
|
0 => wzp_proto::CallAcceptMode::Reject,
|
||||||
1 => wzp_proto::CallAcceptMode::AcceptTrusted,
|
1 => wzp_proto::CallAcceptMode::AcceptTrusted,
|
||||||
@@ -464,7 +529,13 @@ pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeAnswerCall<'a>
|
|||||||
|
|
||||||
match result {
|
match result {
|
||||||
Ok(Ok(())) => 0,
|
Ok(Ok(())) => 0,
|
||||||
Ok(Err(e)) => { error!("answer_call failed: {e}"); -1 }
|
Ok(Err(e)) => {
|
||||||
Err(_) => { error!("answer_call panicked"); -1 }
|
error!("answer_call failed: {e}");
|
||||||
|
-1
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
error!("answer_call panicked");
|
||||||
|
-1
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,6 +26,6 @@ pub mod audio_android;
|
|||||||
pub mod audio_ring;
|
pub mod audio_ring;
|
||||||
pub mod commands;
|
pub mod commands;
|
||||||
pub mod engine;
|
pub mod engine;
|
||||||
|
pub mod jni_bridge;
|
||||||
pub mod pipeline;
|
pub mod pipeline;
|
||||||
pub mod stats;
|
pub mod stats;
|
||||||
pub mod jni_bridge;
|
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ use wzp_codec::{AdaptiveDecoder, AdaptiveEncoder, AutoGainControl, EchoCanceller
|
|||||||
use wzp_fec::{RaptorQFecDecoder, RaptorQFecEncoder};
|
use wzp_fec::{RaptorQFecDecoder, RaptorQFecEncoder};
|
||||||
use wzp_proto::jitter::{JitterBuffer, PlayoutResult};
|
use wzp_proto::jitter::{JitterBuffer, PlayoutResult};
|
||||||
use wzp_proto::quality::AdaptiveQualityController;
|
use wzp_proto::quality::AdaptiveQualityController;
|
||||||
use wzp_proto::traits::{AudioDecoder, AudioEncoder, FecDecoder, FecEncoder};
|
|
||||||
use wzp_proto::traits::QualityController;
|
use wzp_proto::traits::QualityController;
|
||||||
|
use wzp_proto::traits::{AudioDecoder, AudioEncoder, FecDecoder, FecEncoder};
|
||||||
use wzp_proto::{MediaPacket, QualityProfile};
|
use wzp_proto::{MediaPacket, QualityProfile};
|
||||||
|
|
||||||
use crate::audio_android::FRAME_SAMPLES;
|
use crate::audio_android::FRAME_SAMPLES;
|
||||||
@@ -58,14 +58,12 @@ pub struct Pipeline {
|
|||||||
impl Pipeline {
|
impl Pipeline {
|
||||||
/// Create a new pipeline configured for the given quality profile.
|
/// Create a new pipeline configured for the given quality profile.
|
||||||
pub fn new(profile: QualityProfile) -> Result<Self, anyhow::Error> {
|
pub fn new(profile: QualityProfile) -> Result<Self, anyhow::Error> {
|
||||||
let encoder = AdaptiveEncoder::new(profile)
|
let encoder =
|
||||||
.map_err(|e| anyhow::anyhow!("encoder init: {e}"))?;
|
AdaptiveEncoder::new(profile).map_err(|e| anyhow::anyhow!("encoder init: {e}"))?;
|
||||||
let decoder = AdaptiveDecoder::new(profile)
|
let decoder =
|
||||||
.map_err(|e| anyhow::anyhow!("decoder init: {e}"))?;
|
AdaptiveDecoder::new(profile).map_err(|e| anyhow::anyhow!("decoder init: {e}"))?;
|
||||||
let fec_encoder =
|
let fec_encoder = RaptorQFecEncoder::with_defaults(profile.frames_per_block as usize);
|
||||||
RaptorQFecEncoder::with_defaults(profile.frames_per_block as usize);
|
let fec_decoder = RaptorQFecDecoder::with_defaults(profile.frames_per_block as usize);
|
||||||
let fec_decoder =
|
|
||||||
RaptorQFecDecoder::with_defaults(profile.frames_per_block as usize);
|
|
||||||
let jitter_buffer = JitterBuffer::new(10, 250, 3);
|
let jitter_buffer = JitterBuffer::new(10, 250, 3);
|
||||||
let quality_ctrl = AdaptiveQualityController::new();
|
let quality_ctrl = AdaptiveQualityController::new();
|
||||||
|
|
||||||
@@ -136,11 +134,11 @@ impl Pipeline {
|
|||||||
pub fn feed_packet(&mut self, packet: MediaPacket) {
|
pub fn feed_packet(&mut self, packet: MediaPacket) {
|
||||||
// Feed FEC symbols if present
|
// Feed FEC symbols if present
|
||||||
let header = &packet.header;
|
let header = &packet.header;
|
||||||
if header.fec_block != 0 || header.fec_symbol != 0 {
|
if header.fec_block != 0 {
|
||||||
let is_repair = header.is_repair;
|
let is_repair = header.is_repair();
|
||||||
if let Err(e) = self.fec_decoder.add_symbol(
|
if let Err(e) = self.fec_decoder.add_symbol(
|
||||||
header.fec_block,
|
header.fec_block,
|
||||||
header.fec_symbol,
|
header.fec_block >> 8,
|
||||||
is_repair,
|
is_repair,
|
||||||
&packet.payload,
|
&packet.payload,
|
||||||
) {
|
) {
|
||||||
@@ -211,10 +209,7 @@ impl Pipeline {
|
|||||||
///
|
///
|
||||||
/// Returns a new profile if a tier transition occurred.
|
/// Returns a new profile if a tier transition occurred.
|
||||||
#[allow(unused)]
|
#[allow(unused)]
|
||||||
pub fn observe_quality(
|
pub fn observe_quality(&mut self, report: &wzp_proto::QualityReport) -> Option<QualityProfile> {
|
||||||
&mut self,
|
|
||||||
report: &wzp_proto::QualityReport,
|
|
||||||
) -> Option<QualityProfile> {
|
|
||||||
let new_profile = self.quality_ctrl.observe(report);
|
let new_profile = self.quality_ctrl.observe(report);
|
||||||
if let Some(ref profile) = new_profile {
|
if let Some(ref profile) = new_profile {
|
||||||
if let Err(e) = self.encoder.set_profile(*profile) {
|
if let Err(e) = self.encoder.set_profile(*profile) {
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ wzp-codec = { workspace = true }
|
|||||||
wzp-fec = { workspace = true }
|
wzp-fec = { workspace = true }
|
||||||
wzp-crypto = { workspace = true }
|
wzp-crypto = { workspace = true }
|
||||||
wzp-transport = { workspace = true }
|
wzp-transport = { workspace = true }
|
||||||
|
wzp-video = { path = "../wzp-video" }
|
||||||
tokio = { workspace = true }
|
tokio = { workspace = true }
|
||||||
tracing = { workspace = true }
|
tracing = { workspace = true }
|
||||||
tracing-subscriber = { workspace = true }
|
tracing-subscriber = { workspace = true }
|
||||||
@@ -21,6 +22,9 @@ anyhow = "1"
|
|||||||
serde = { workspace = true }
|
serde = { workspace = true }
|
||||||
serde_json = "1"
|
serde_json = "1"
|
||||||
chrono = "0.4"
|
chrono = "0.4"
|
||||||
|
clap = { version = "4", features = ["derive"] }
|
||||||
|
ratatui = "0.29"
|
||||||
|
crossterm = "0.28"
|
||||||
rustls = { version = "0.23", default-features = false, features = ["ring", "std"] }
|
rustls = { version = "0.23", default-features = false, features = ["ring", "std"] }
|
||||||
cpal = { version = "0.15", optional = true }
|
cpal = { version = "0.15", optional = true }
|
||||||
libc = "0.2"
|
libc = "0.2"
|
||||||
@@ -30,6 +34,8 @@ libc = "0.2"
|
|||||||
# through the WAN reflex addr (which many consumer NATs, including
|
# through the WAN reflex addr (which many consumer NATs, including
|
||||||
# MikroTik's default masquerade, don't support).
|
# MikroTik's default masquerade, don't support).
|
||||||
if-addrs = "0.13"
|
if-addrs = "0.13"
|
||||||
|
rand = { workspace = true }
|
||||||
|
socket2 = "0.5"
|
||||||
|
|
||||||
# coreaudio-rs is Apple-framework-only; gate it to macOS so enabling
|
# coreaudio-rs is Apple-framework-only; gate it to macOS so enabling
|
||||||
# the `vpio` feature from a non-macOS target builds cleanly instead of
|
# the `vpio` feature from a non-macOS target builds cleanly instead of
|
||||||
@@ -99,6 +105,10 @@ linux-aec = ["dep:webrtc-audio-processing"]
|
|||||||
name = "wzp-client"
|
name = "wzp-client"
|
||||||
path = "src/cli.rs"
|
path = "src/cli.rs"
|
||||||
|
|
||||||
|
[[bin]]
|
||||||
|
name = "wzp-analyzer"
|
||||||
|
path = "src/analyzer.rs"
|
||||||
|
|
||||||
[[bin]]
|
[[bin]]
|
||||||
name = "wzp-bench"
|
name = "wzp-bench"
|
||||||
path = "src/bench_cli.rs"
|
path = "src/bench_cli.rs"
|
||||||
|
|||||||
1309
crates/wzp-client/src/analyzer.rs
Normal file
1309
crates/wzp-client/src/analyzer.rs
Normal file
File diff suppressed because it is too large
Load Diff
@@ -6,10 +6,10 @@
|
|||||||
//! Audio callbacks are **lock-free**: they read/write directly to an `AudioRing`
|
//! Audio callbacks are **lock-free**: they read/write directly to an `AudioRing`
|
||||||
//! (atomic SPSC ring buffer). No Mutex, no channel, no allocation on the hot path.
|
//! (atomic SPSC ring buffer). No Mutex, no channel, no allocation on the hot path.
|
||||||
|
|
||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
|
|
||||||
use anyhow::{anyhow, Context};
|
use anyhow::{Context, anyhow};
|
||||||
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
|
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
|
||||||
use cpal::{SampleFormat, SampleRate, StreamConfig};
|
use cpal::{SampleFormat, SampleRate, StreamConfig};
|
||||||
use tracing::{info, warn};
|
use tracing::{info, warn};
|
||||||
@@ -78,7 +78,10 @@ impl AudioCapture {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if !logged.swap(true, Ordering::Relaxed) {
|
if !logged.swap(true, Ordering::Relaxed) {
|
||||||
eprintln!("[audio] capture callback: {} f32 samples", data.len());
|
eprintln!(
|
||||||
|
"[audio] capture callback: {} f32 samples",
|
||||||
|
data.len()
|
||||||
|
);
|
||||||
}
|
}
|
||||||
let mut tmp = [0i16; FRAME_SAMPLES];
|
let mut tmp = [0i16; FRAME_SAMPLES];
|
||||||
for chunk in data.chunks(FRAME_SAMPLES) {
|
for chunk in data.chunks(FRAME_SAMPLES) {
|
||||||
@@ -103,7 +106,10 @@ impl AudioCapture {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if !logged.swap(true, Ordering::Relaxed) {
|
if !logged.swap(true, Ordering::Relaxed) {
|
||||||
eprintln!("[audio] capture callback: {} i16 samples", data.len());
|
eprintln!(
|
||||||
|
"[audio] capture callback: {} i16 samples",
|
||||||
|
data.len()
|
||||||
|
);
|
||||||
}
|
}
|
||||||
ring.write(data);
|
ring.write(data);
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -54,13 +54,13 @@
|
|||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
use std::sync::{Arc, Mutex, OnceLock};
|
use std::sync::{Arc, Mutex, OnceLock};
|
||||||
|
|
||||||
use anyhow::{anyhow, Context};
|
use anyhow::{Context, anyhow};
|
||||||
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
|
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
|
||||||
use cpal::{SampleFormat, SampleRate, StreamConfig};
|
use cpal::{SampleFormat, SampleRate, StreamConfig};
|
||||||
use tracing::{info, warn};
|
use tracing::{info, warn};
|
||||||
use webrtc_audio_processing::{
|
use webrtc_audio_processing::{
|
||||||
Config, EchoCancellation, EchoCancellationSuppressionLevel, InitializationConfig,
|
Config, EchoCancellation, EchoCancellationSuppressionLevel, InitializationConfig,
|
||||||
NoiseSuppression, NoiseSuppressionLevel, Processor, NUM_SAMPLES_PER_FRAME,
|
NUM_SAMPLES_PER_FRAME, NoiseSuppression, NoiseSuppressionLevel, Processor,
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::audio_ring::AudioRing;
|
use crate::audio_ring::AudioRing;
|
||||||
@@ -97,8 +97,8 @@ fn get_or_init_processor() -> anyhow::Result<Arc<Mutex<Processor>>> {
|
|||||||
num_render_channels: APM_NUM_CHANNELS as i32,
|
num_render_channels: APM_NUM_CHANNELS as i32,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
let mut processor = Processor::new(&init_config)
|
let mut processor =
|
||||||
.map_err(|e| anyhow!("webrtc APM init failed: {e:?}"))?;
|
Processor::new(&init_config).map_err(|e| anyhow!("webrtc APM init failed: {e:?}"))?;
|
||||||
|
|
||||||
let config = Config {
|
let config = Config {
|
||||||
echo_cancellation: Some(EchoCancellation {
|
echo_cancellation: Some(EchoCancellation {
|
||||||
|
|||||||
@@ -5,8 +5,8 @@
|
|||||||
//! to the speaker, so it can cancel the echo from the mic signal internally.
|
//! to the speaker, so it can cancel the echo from the mic signal internally.
|
||||||
//! This is the same engine FaceTime and other Apple apps use.
|
//! This is the same engine FaceTime and other Apple apps use.
|
||||||
|
|
||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||||
|
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use coreaudio::audio_unit::audio_format::LinearPcmFlags;
|
use coreaudio::audio_unit::audio_format::LinearPcmFlags;
|
||||||
@@ -28,6 +28,60 @@ pub struct VpioAudio {
|
|||||||
playout_ring: Arc<AudioRing>,
|
playout_ring: Arc<AudioRing>,
|
||||||
_audio_unit: AudioUnit,
|
_audio_unit: AudioUnit,
|
||||||
running: Arc<AtomicBool>,
|
running: Arc<AtomicBool>,
|
||||||
|
stats: Arc<VpioStats>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Render/capture counters for diagnosing macOS VoiceProcessingIO.
|
||||||
|
///
|
||||||
|
/// These are atomics because CoreAudio callbacks run on realtime audio
|
||||||
|
/// threads. The Tauri engine polls snapshots from a normal async task and
|
||||||
|
/// emits them to the call debug log.
|
||||||
|
#[derive(Default)]
|
||||||
|
pub struct VpioStats {
|
||||||
|
capture_callbacks: AtomicU64,
|
||||||
|
capture_samples: AtomicU64,
|
||||||
|
render_callbacks: AtomicU64,
|
||||||
|
render_requested_samples: AtomicU64,
|
||||||
|
render_read_samples: AtomicU64,
|
||||||
|
render_underrun_callbacks: AtomicU64,
|
||||||
|
render_nonzero_callbacks: AtomicU64,
|
||||||
|
render_last_requested: AtomicU64,
|
||||||
|
render_last_read: AtomicU64,
|
||||||
|
render_last_rms: AtomicU64,
|
||||||
|
render_last_ring_available: AtomicU64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Debug)]
|
||||||
|
pub struct VpioStatsSnapshot {
|
||||||
|
pub capture_callbacks: u64,
|
||||||
|
pub capture_samples: u64,
|
||||||
|
pub render_callbacks: u64,
|
||||||
|
pub render_requested_samples: u64,
|
||||||
|
pub render_read_samples: u64,
|
||||||
|
pub render_underrun_callbacks: u64,
|
||||||
|
pub render_nonzero_callbacks: u64,
|
||||||
|
pub render_last_requested: u64,
|
||||||
|
pub render_last_read: u64,
|
||||||
|
pub render_last_rms: u64,
|
||||||
|
pub render_last_ring_available: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl VpioStats {
|
||||||
|
pub fn snapshot(&self) -> VpioStatsSnapshot {
|
||||||
|
VpioStatsSnapshot {
|
||||||
|
capture_callbacks: self.capture_callbacks.load(Ordering::Relaxed),
|
||||||
|
capture_samples: self.capture_samples.load(Ordering::Relaxed),
|
||||||
|
render_callbacks: self.render_callbacks.load(Ordering::Relaxed),
|
||||||
|
render_requested_samples: self.render_requested_samples.load(Ordering::Relaxed),
|
||||||
|
render_read_samples: self.render_read_samples.load(Ordering::Relaxed),
|
||||||
|
render_underrun_callbacks: self.render_underrun_callbacks.load(Ordering::Relaxed),
|
||||||
|
render_nonzero_callbacks: self.render_nonzero_callbacks.load(Ordering::Relaxed),
|
||||||
|
render_last_requested: self.render_last_requested.load(Ordering::Relaxed),
|
||||||
|
render_last_read: self.render_last_read.load(Ordering::Relaxed),
|
||||||
|
render_last_rms: self.render_last_rms.load(Ordering::Relaxed),
|
||||||
|
render_last_ring_available: self.render_last_ring_available.load(Ordering::Relaxed),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl VpioAudio {
|
impl VpioAudio {
|
||||||
@@ -36,6 +90,7 @@ impl VpioAudio {
|
|||||||
let capture_ring = Arc::new(AudioRing::new());
|
let capture_ring = Arc::new(AudioRing::new());
|
||||||
let playout_ring = Arc::new(AudioRing::new());
|
let playout_ring = Arc::new(AudioRing::new());
|
||||||
let running = Arc::new(AtomicBool::new(true));
|
let running = Arc::new(AtomicBool::new(true));
|
||||||
|
let stats = Arc::new(VpioStats::default());
|
||||||
|
|
||||||
let mut au = AudioUnit::new(IOType::VoiceProcessingIO)
|
let mut au = AudioUnit::new(IOType::VoiceProcessingIO)
|
||||||
.context("failed to create VoiceProcessingIO audio unit")?;
|
.context("failed to create VoiceProcessingIO audio unit")?;
|
||||||
@@ -98,6 +153,7 @@ impl VpioAudio {
|
|||||||
// Set up input callback (mic capture with AEC applied)
|
// Set up input callback (mic capture with AEC applied)
|
||||||
let cap_ring = capture_ring.clone();
|
let cap_ring = capture_ring.clone();
|
||||||
let cap_running = running.clone();
|
let cap_running = running.clone();
|
||||||
|
let cap_stats = stats.clone();
|
||||||
let logged = Arc::new(AtomicBool::new(false));
|
let logged = Arc::new(AtomicBool::new(false));
|
||||||
au.set_input_callback(
|
au.set_input_callback(
|
||||||
move |args: render_callback::Args<data::NonInterleaved<f32>>| {
|
move |args: render_callback::Args<data::NonInterleaved<f32>>| {
|
||||||
@@ -106,6 +162,10 @@ impl VpioAudio {
|
|||||||
}
|
}
|
||||||
let mut buffers = args.data.channels();
|
let mut buffers = args.data.channels();
|
||||||
if let Some(ch) = buffers.next() {
|
if let Some(ch) = buffers.next() {
|
||||||
|
cap_stats.capture_callbacks.fetch_add(1, Ordering::Relaxed);
|
||||||
|
cap_stats
|
||||||
|
.capture_samples
|
||||||
|
.fetch_add(ch.len() as u64, Ordering::Relaxed);
|
||||||
if !logged.swap(true, Ordering::Relaxed) {
|
if !logged.swap(true, Ordering::Relaxed) {
|
||||||
eprintln!("[vpio] capture callback: {} f32 samples", ch.len());
|
eprintln!("[vpio] capture callback: {} f32 samples", ch.len());
|
||||||
}
|
}
|
||||||
@@ -125,28 +185,80 @@ impl VpioAudio {
|
|||||||
|
|
||||||
// Set up output callback (speaker playback — AEC uses this as reference)
|
// Set up output callback (speaker playback — AEC uses this as reference)
|
||||||
let play_ring = playout_ring.clone();
|
let play_ring = playout_ring.clone();
|
||||||
|
let render_stats = stats.clone();
|
||||||
|
let logged_render = Arc::new(AtomicBool::new(false));
|
||||||
au.set_render_callback(
|
au.set_render_callback(
|
||||||
move |mut args: render_callback::Args<data::NonInterleaved<f32>>| {
|
move |mut args: render_callback::Args<data::NonInterleaved<f32>>| {
|
||||||
let mut buffers = args.data.channels_mut();
|
let mut buffers = args.data.channels_mut();
|
||||||
if let Some(ch) = buffers.next() {
|
if let Some(ch) = buffers.next() {
|
||||||
|
render_stats
|
||||||
|
.render_callbacks
|
||||||
|
.fetch_add(1, Ordering::Relaxed);
|
||||||
|
render_stats
|
||||||
|
.render_requested_samples
|
||||||
|
.fetch_add(ch.len() as u64, Ordering::Relaxed);
|
||||||
|
render_stats
|
||||||
|
.render_last_requested
|
||||||
|
.store(ch.len() as u64, Ordering::Relaxed);
|
||||||
let mut tmp = [0i16; FRAME_SAMPLES];
|
let mut tmp = [0i16; FRAME_SAMPLES];
|
||||||
|
let mut total_read = 0usize;
|
||||||
|
let mut sum_sq = 0u64;
|
||||||
|
let ring_available = play_ring.available();
|
||||||
for chunk in ch.chunks_mut(FRAME_SAMPLES) {
|
for chunk in ch.chunks_mut(FRAME_SAMPLES) {
|
||||||
let n = chunk.len();
|
let n = chunk.len();
|
||||||
let read = play_ring.read(&mut tmp[..n]);
|
let read = play_ring.read(&mut tmp[..n]);
|
||||||
|
total_read += read;
|
||||||
for i in 0..read {
|
for i in 0..read {
|
||||||
|
let s = tmp[i] as i64;
|
||||||
|
sum_sq = sum_sq.saturating_add((s * s) as u64);
|
||||||
chunk[i] = tmp[i] as f32 / i16::MAX as f32;
|
chunk[i] = tmp[i] as f32 / i16::MAX as f32;
|
||||||
}
|
}
|
||||||
for i in read..n {
|
for i in read..n {
|
||||||
chunk[i] = 0.0;
|
chunk[i] = 0.0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
render_stats
|
||||||
|
.render_read_samples
|
||||||
|
.fetch_add(total_read as u64, Ordering::Relaxed);
|
||||||
|
render_stats
|
||||||
|
.render_last_read
|
||||||
|
.store(total_read as u64, Ordering::Relaxed);
|
||||||
|
render_stats
|
||||||
|
.render_last_ring_available
|
||||||
|
.store(ring_available as u64, Ordering::Relaxed);
|
||||||
|
if total_read == 0 {
|
||||||
|
render_stats
|
||||||
|
.render_underrun_callbacks
|
||||||
|
.fetch_add(1, Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
let rms = if total_read > 0 {
|
||||||
|
((sum_sq as f64 / total_read as f64).sqrt()) as u64
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
};
|
||||||
|
render_stats.render_last_rms.store(rms, Ordering::Relaxed);
|
||||||
|
if rms > 0 {
|
||||||
|
render_stats
|
||||||
|
.render_nonzero_callbacks
|
||||||
|
.fetch_add(1, Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
if !logged_render.swap(true, Ordering::Relaxed) {
|
||||||
|
eprintln!(
|
||||||
|
"[vpio] render callback: {} f32 samples, ring_available={}, ring_read={}, rms={}",
|
||||||
|
ch.len(),
|
||||||
|
ring_available,
|
||||||
|
total_read,
|
||||||
|
rms
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.context("failed to set render callback")?;
|
.context("failed to set render callback")?;
|
||||||
|
|
||||||
au.initialize().context("failed to initialize VoiceProcessingIO")?;
|
au.initialize()
|
||||||
|
.context("failed to initialize VoiceProcessingIO")?;
|
||||||
au.start().context("failed to start VoiceProcessingIO")?;
|
au.start().context("failed to start VoiceProcessingIO")?;
|
||||||
|
|
||||||
info!("VoiceProcessingIO started (OS-level AEC enabled)");
|
info!("VoiceProcessingIO started (OS-level AEC enabled)");
|
||||||
@@ -156,6 +268,7 @@ impl VpioAudio {
|
|||||||
playout_ring,
|
playout_ring,
|
||||||
_audio_unit: au,
|
_audio_unit: au,
|
||||||
running,
|
running,
|
||||||
|
stats,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -167,6 +280,10 @@ impl VpioAudio {
|
|||||||
&self.playout_ring
|
&self.playout_ring
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn stats(&self) -> Arc<VpioStats> {
|
||||||
|
self.stats.clone()
|
||||||
|
}
|
||||||
|
|
||||||
pub fn stop(&self) {
|
pub fn stop(&self) {
|
||||||
self.running.store(false, Ordering::Relaxed);
|
self.running.store(false, Ordering::Relaxed);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,24 +15,24 @@
|
|||||||
//! `wzp-client`'s lib.rs can transparently re-export either one as
|
//! `wzp-client`'s lib.rs can transparently re-export either one as
|
||||||
//! `AudioCapture`.
|
//! `AudioCapture`.
|
||||||
|
|
||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
|
|
||||||
use anyhow::{anyhow, Context};
|
use anyhow::{Context, anyhow};
|
||||||
use tracing::{info, warn};
|
use tracing::{info, warn};
|
||||||
use windows::core::{Interface, GUID};
|
use windows::Win32::Foundation::{BOOL, CloseHandle, WAIT_OBJECT_0};
|
||||||
use windows::Win32::Foundation::{CloseHandle, BOOL, WAIT_OBJECT_0};
|
|
||||||
use windows::Win32::Media::Audio::{
|
use windows::Win32::Media::Audio::{
|
||||||
eCapture, eCommunications, AudioCategory_Communications, AudioClientProperties,
|
|
||||||
IAudioCaptureClient, IAudioClient, IAudioClient2, IMMDeviceEnumerator, MMDeviceEnumerator,
|
|
||||||
AUDCLNT_SHAREMODE_SHARED, AUDCLNT_STREAMFLAGS_AUTOCONVERTPCM,
|
AUDCLNT_SHAREMODE_SHARED, AUDCLNT_STREAMFLAGS_AUTOCONVERTPCM,
|
||||||
AUDCLNT_STREAMFLAGS_EVENTCALLBACK, AUDCLNT_STREAMFLAGS_SRC_DEFAULT_QUALITY, WAVEFORMATEX,
|
AUDCLNT_STREAMFLAGS_EVENTCALLBACK, AUDCLNT_STREAMFLAGS_SRC_DEFAULT_QUALITY,
|
||||||
WAVE_FORMAT_PCM,
|
AudioCategory_Communications, AudioClientProperties, IAudioCaptureClient, IAudioClient,
|
||||||
|
IAudioClient2, IMMDeviceEnumerator, MMDeviceEnumerator, WAVE_FORMAT_PCM, WAVEFORMATEX,
|
||||||
|
eCapture, eCommunications,
|
||||||
};
|
};
|
||||||
use windows::Win32::System::Com::{
|
use windows::Win32::System::Com::{
|
||||||
CoCreateInstance, CoInitializeEx, CoUninitialize, CLSCTX_ALL, COINIT_MULTITHREADED,
|
CLSCTX_ALL, COINIT_MULTITHREADED, CoCreateInstance, CoInitializeEx, CoUninitialize,
|
||||||
};
|
};
|
||||||
use windows::Win32::System::Threading::{CreateEventW, WaitForSingleObject, INFINITE};
|
use windows::Win32::System::Threading::{CreateEventW, INFINITE, WaitForSingleObject};
|
||||||
|
use windows::core::{GUID, Interface};
|
||||||
|
|
||||||
use crate::audio_ring::AudioRing;
|
use crate::audio_ring::AudioRing;
|
||||||
|
|
||||||
@@ -138,8 +138,7 @@ unsafe fn capture_thread_main(
|
|||||||
}
|
}
|
||||||
let _com_guard = ComGuard;
|
let _com_guard = ComGuard;
|
||||||
|
|
||||||
let enumerator: IMMDeviceEnumerator =
|
let enumerator: IMMDeviceEnumerator = CoCreateInstance(&MMDeviceEnumerator, None, CLSCTX_ALL)
|
||||||
CoCreateInstance(&MMDeviceEnumerator, None, CLSCTX_ALL)
|
|
||||||
.context("CoCreateInstance(MMDeviceEnumerator) failed")?;
|
.context("CoCreateInstance(MMDeviceEnumerator) failed")?;
|
||||||
|
|
||||||
// eCommunications role (not eConsole) — this picks the device the user
|
// eCommunications role (not eConsole) — this picks the device the user
|
||||||
@@ -206,12 +205,13 @@ unsafe fn capture_thread_main(
|
|||||||
&wave_format,
|
&wave_format,
|
||||||
Some(&GUID::zeroed()),
|
Some(&GUID::zeroed()),
|
||||||
)
|
)
|
||||||
.context("IAudioClient::Initialize failed — Windows rejected communications-mode 48k mono i16")?;
|
.context(
|
||||||
|
"IAudioClient::Initialize failed — Windows rejected communications-mode 48k mono i16",
|
||||||
|
)?;
|
||||||
|
|
||||||
// Event-driven capture: Windows signals this handle each time a new
|
// Event-driven capture: Windows signals this handle each time a new
|
||||||
// audio packet is available. We wait on it from the loop below.
|
// audio packet is available. We wait on it from the loop below.
|
||||||
let event = CreateEventW(None, false, false, None)
|
let event = CreateEventW(None, false, false, None).context("CreateEventW failed")?;
|
||||||
.context("CreateEventW failed")?;
|
|
||||||
audio_client
|
audio_client
|
||||||
.SetEventHandle(event)
|
.SetEventHandle(event)
|
||||||
.context("SetEventHandle failed")?;
|
.context("SetEventHandle failed")?;
|
||||||
@@ -285,10 +285,8 @@ unsafe fn capture_thread_main(
|
|||||||
// Because we asked for 48 kHz mono i16, each frame is
|
// Because we asked for 48 kHz mono i16, each frame is
|
||||||
// exactly one i16. Windows's AUTOCONVERTPCM handles the
|
// exactly one i16. Windows's AUTOCONVERTPCM handles the
|
||||||
// conversion from whatever the engine mix format is.
|
// conversion from whatever the engine mix format is.
|
||||||
let samples = std::slice::from_raw_parts(
|
let samples =
|
||||||
buffer_ptr as *const i16,
|
std::slice::from_raw_parts(buffer_ptr as *const i16, num_frames as usize);
|
||||||
num_frames as usize,
|
|
||||||
);
|
|
||||||
ring.write(samples);
|
ring.write(samples);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ use std::time::{Duration, Instant};
|
|||||||
|
|
||||||
use wzp_crypto::ChaChaSession;
|
use wzp_crypto::ChaChaSession;
|
||||||
use wzp_fec::{RaptorQFecDecoder, RaptorQFecEncoder};
|
use wzp_fec::{RaptorQFecDecoder, RaptorQFecEncoder};
|
||||||
use wzp_proto::traits::{CryptoSession, FecDecoder, FecEncoder};
|
|
||||||
use wzp_proto::QualityProfile;
|
use wzp_proto::QualityProfile;
|
||||||
|
use wzp_proto::traits::{CryptoSession, FecDecoder, FecEncoder};
|
||||||
|
|
||||||
use crate::call::{CallConfig, CallDecoder, CallEncoder};
|
use crate::call::{CallConfig, CallDecoder, CallEncoder};
|
||||||
|
|
||||||
@@ -151,7 +151,7 @@ pub fn bench_fec_recovery(loss_pct: f32) -> FecResult {
|
|||||||
let mut total_repair_bytes = 0usize;
|
let mut total_repair_bytes = 0usize;
|
||||||
|
|
||||||
for block_idx in 0..num_blocks {
|
for block_idx in 0..num_blocks {
|
||||||
let block_id = (block_idx % 256) as u8;
|
let block_id = (block_idx % 65536) as u16;
|
||||||
|
|
||||||
// Create fresh encoder and decoder for each block
|
// Create fresh encoder and decoder for each block
|
||||||
let mut fec_enc = RaptorQFecEncoder::new(frames_per_block, 256);
|
let mut fec_enc = RaptorQFecEncoder::new(frames_per_block, 256);
|
||||||
@@ -170,7 +170,7 @@ pub fn bench_fec_recovery(loss_pct: f32) -> FecResult {
|
|||||||
|
|
||||||
// Collect all symbols: source + repair
|
// Collect all symbols: source + repair
|
||||||
struct Symbol {
|
struct Symbol {
|
||||||
index: u8,
|
index: u16,
|
||||||
is_repair: bool,
|
is_repair: bool,
|
||||||
data: Vec<u8>,
|
data: Vec<u8>,
|
||||||
}
|
}
|
||||||
@@ -180,7 +180,7 @@ pub fn bench_fec_recovery(loss_pct: f32) -> FecResult {
|
|||||||
// For add_symbol we need to provide the raw data; the decoder pads internally
|
// For add_symbol we need to provide the raw data; the decoder pads internally
|
||||||
total_source_bytes += sym.len();
|
total_source_bytes += sym.len();
|
||||||
all_symbols.push(Symbol {
|
all_symbols.push(Symbol {
|
||||||
index: i as u8,
|
index: i as u16,
|
||||||
is_repair: false,
|
is_repair: false,
|
||||||
data: sym.clone(),
|
data: sym.clone(),
|
||||||
});
|
});
|
||||||
@@ -201,9 +201,13 @@ pub fn bench_fec_recovery(loss_pct: f32) -> FecResult {
|
|||||||
// Deterministic shuffle for reproducibility using a simple seed
|
// Deterministic shuffle for reproducibility using a simple seed
|
||||||
// We use a basic Fisher-Yates with a fixed-per-block seed
|
// We use a basic Fisher-Yates with a fixed-per-block seed
|
||||||
let mut indices: Vec<usize> = (0..all_symbols.len()).collect();
|
let mut indices: Vec<usize> = (0..all_symbols.len()).collect();
|
||||||
let mut seed = (block_idx as u64).wrapping_mul(6364136223846793005).wrapping_add(1);
|
let mut seed = (block_idx as u64)
|
||||||
|
.wrapping_mul(6364136223846793005)
|
||||||
|
.wrapping_add(1);
|
||||||
for i in (1..indices.len()).rev() {
|
for i in (1..indices.len()).rev() {
|
||||||
seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
|
seed = seed
|
||||||
|
.wrapping_mul(6364136223846793005)
|
||||||
|
.wrapping_add(1442695040888963407);
|
||||||
let j = (seed >> 33) as usize % (i + 1);
|
let j = (seed >> 33) as usize % (i + 1);
|
||||||
indices.swap(i, j);
|
indices.swap(i, j);
|
||||||
}
|
}
|
||||||
@@ -259,17 +263,36 @@ pub fn bench_encrypt_decrypt() -> CryptoResult {
|
|||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let header = b"bench-header";
|
// Build valid v2 MediaHeader bytes — encrypt/decrypt now derive nonces from
|
||||||
|
// header.seq and require a parseable MediaHeader (WIRE_SIZE bytes minimum).
|
||||||
|
use wzp_proto::packet::MediaHeader;
|
||||||
|
use wzp_proto::{CodecId, MediaType};
|
||||||
let mut total_bytes: usize = 0;
|
let mut total_bytes: usize = 0;
|
||||||
|
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
for payload in &payloads {
|
for (i, payload) in payloads.iter().enumerate() {
|
||||||
|
let hdr = MediaHeader {
|
||||||
|
version: 2,
|
||||||
|
flags: 0,
|
||||||
|
media_type: MediaType::Audio,
|
||||||
|
codec_id: CodecId::Opus24k,
|
||||||
|
stream_id: 0,
|
||||||
|
fec_ratio: 0,
|
||||||
|
seq: i as u32,
|
||||||
|
timestamp: (i as u32).wrapping_mul(20),
|
||||||
|
fec_block: 0,
|
||||||
|
};
|
||||||
|
let mut header_bytes = Vec::with_capacity(MediaHeader::WIRE_SIZE);
|
||||||
|
hdr.write_to(&mut header_bytes);
|
||||||
|
|
||||||
let mut ciphertext = Vec::with_capacity(payload.len() + 16);
|
let mut ciphertext = Vec::with_capacity(payload.len() + 16);
|
||||||
encryptor.encrypt(header, payload, &mut ciphertext).unwrap();
|
encryptor
|
||||||
|
.encrypt(&header_bytes, payload, &mut ciphertext)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let mut plaintext = Vec::with_capacity(payload.len());
|
let mut plaintext = Vec::with_capacity(payload.len());
|
||||||
decryptor
|
decryptor
|
||||||
.decrypt(header, &ciphertext, &mut plaintext)
|
.decrypt(&header_bytes, &ciphertext, &mut plaintext)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
total_bytes += payload.len();
|
total_bytes += payload.len();
|
||||||
|
|||||||
@@ -24,8 +24,14 @@ fn run_codec() {
|
|||||||
print_header("Codec Roundtrip (Opus 24kbps)");
|
print_header("Codec Roundtrip (Opus 24kbps)");
|
||||||
let r = bench::bench_codec_roundtrip();
|
let r = bench::bench_codec_roundtrip();
|
||||||
print_row("Frames", &format!("{}", r.frames));
|
print_row("Frames", &format!("{}", r.frames));
|
||||||
print_row("Encode total", &format!("{:.2} ms", r.total_encode.as_secs_f64() * 1000.0));
|
print_row(
|
||||||
print_row("Decode total", &format!("{:.2} ms", r.total_decode.as_secs_f64() * 1000.0));
|
"Encode total",
|
||||||
|
&format!("{:.2} ms", r.total_encode.as_secs_f64() * 1000.0),
|
||||||
|
);
|
||||||
|
print_row(
|
||||||
|
"Decode total",
|
||||||
|
&format!("{:.2} ms", r.total_decode.as_secs_f64() * 1000.0),
|
||||||
|
);
|
||||||
print_row("Avg encode", &format!("{:.1} us", r.avg_encode_us));
|
print_row("Avg encode", &format!("{:.1} us", r.avg_encode_us));
|
||||||
print_row("Avg decode", &format!("{:.1} us", r.avg_decode_us));
|
print_row("Avg decode", &format!("{:.1} us", r.avg_decode_us));
|
||||||
print_row("Throughput", &format!("{:.0} frames/sec", r.frames_per_sec));
|
print_row("Throughput", &format!("{:.0} frames/sec", r.frames_per_sec));
|
||||||
@@ -41,7 +47,10 @@ fn run_fec(loss_pct: f32) {
|
|||||||
print_row("Recovery rate", &format!("{:.1}%", r.recovery_rate_pct));
|
print_row("Recovery rate", &format!("{:.1}%", r.recovery_rate_pct));
|
||||||
print_row("Source bytes", &format!("{}", r.total_source_bytes));
|
print_row("Source bytes", &format!("{}", r.total_source_bytes));
|
||||||
print_row("Repair (overhead) bytes", &format!("{}", r.overhead_bytes));
|
print_row("Repair (overhead) bytes", &format!("{}", r.overhead_bytes));
|
||||||
print_row("Total time", &format!("{:.2} ms", r.total_time.as_secs_f64() * 1000.0));
|
print_row(
|
||||||
|
"Total time",
|
||||||
|
&format!("{:.2} ms", r.total_time.as_secs_f64() * 1000.0),
|
||||||
|
);
|
||||||
print_footer();
|
print_footer();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -49,7 +58,10 @@ fn run_crypto() {
|
|||||||
print_header("Crypto (ChaCha20-Poly1305)");
|
print_header("Crypto (ChaCha20-Poly1305)");
|
||||||
let r = bench::bench_encrypt_decrypt();
|
let r = bench::bench_encrypt_decrypt();
|
||||||
print_row("Packets", &format!("{}", r.packets));
|
print_row("Packets", &format!("{}", r.packets));
|
||||||
print_row("Total time", &format!("{:.2} ms", r.total_time.as_secs_f64() * 1000.0));
|
print_row(
|
||||||
|
"Total time",
|
||||||
|
&format!("{:.2} ms", r.total_time.as_secs_f64() * 1000.0),
|
||||||
|
);
|
||||||
print_row("Throughput", &format!("{:.0} pkt/sec", r.packets_per_sec));
|
print_row("Throughput", &format!("{:.0} pkt/sec", r.packets_per_sec));
|
||||||
print_row("Bandwidth", &format!("{:.2} MB/sec", r.megabytes_per_sec));
|
print_row("Bandwidth", &format!("{:.2} MB/sec", r.megabytes_per_sec));
|
||||||
print_row("Avg latency", &format!("{:.2} us", r.avg_latency_us));
|
print_row("Avg latency", &format!("{:.2} us", r.avg_latency_us));
|
||||||
@@ -60,9 +72,18 @@ fn run_pipeline() {
|
|||||||
print_header("Full Pipeline (E2E)");
|
print_header("Full Pipeline (E2E)");
|
||||||
let r = bench::bench_full_pipeline();
|
let r = bench::bench_full_pipeline();
|
||||||
print_row("Frames", &format!("{}", r.frames));
|
print_row("Frames", &format!("{}", r.frames));
|
||||||
print_row("Encode pipeline", &format!("{:.2} ms", r.total_encode_pipeline.as_secs_f64() * 1000.0));
|
print_row(
|
||||||
print_row("Decode pipeline", &format!("{:.2} ms", r.total_decode_pipeline.as_secs_f64() * 1000.0));
|
"Encode pipeline",
|
||||||
print_row("Avg E2E latency", &format!("{:.1} us/frame", r.avg_e2e_latency_us));
|
&format!("{:.2} ms", r.total_encode_pipeline.as_secs_f64() * 1000.0),
|
||||||
|
);
|
||||||
|
print_row(
|
||||||
|
"Decode pipeline",
|
||||||
|
&format!("{:.2} ms", r.total_decode_pipeline.as_secs_f64() * 1000.0),
|
||||||
|
);
|
||||||
|
print_row(
|
||||||
|
"Avg E2E latency",
|
||||||
|
&format!("{:.1} us/frame", r.avg_e2e_latency_us),
|
||||||
|
);
|
||||||
print_row("PCM in", &format!("{} bytes", r.pcm_bytes_in));
|
print_row("PCM in", &format!("{} bytes", r.pcm_bytes_in));
|
||||||
print_row("Wire out", &format!("{} bytes", r.wire_bytes_out));
|
print_row("Wire out", &format!("{} bytes", r.wire_bytes_out));
|
||||||
print_row("Overhead ratio", &format!("{:.3}x", r.overhead_ratio));
|
print_row("Overhead ratio", &format!("{:.3}x", r.overhead_ratio));
|
||||||
|
|||||||
347
crates/wzp-client/src/birthday.rs
Normal file
347
crates/wzp-client/src/birthday.rs
Normal file
@@ -0,0 +1,347 @@
|
|||||||
|
//! Birthday attack for hard NAT traversal.
|
||||||
|
//!
|
||||||
|
//! When both peers are behind symmetric NATs with random port
|
||||||
|
//! allocation, standard hole-punching fails because neither side
|
||||||
|
//! can predict the other's external port. This module implements
|
||||||
|
//! the birthday-paradox approach:
|
||||||
|
//!
|
||||||
|
//! 1. **Acceptor** opens N sockets, STUN-probes each to learn
|
||||||
|
//! their external ports, reports them to the Dialer.
|
||||||
|
//! 2. **Dialer** sprays QUIC connect attempts to the Acceptor's
|
||||||
|
//! reported ports + random ports on the Acceptor's IP.
|
||||||
|
//! 3. Birthday paradox: with N=64 ports and M=256 probes across
|
||||||
|
//! 65536 ports, collision probability is high.
|
||||||
|
//!
|
||||||
|
//! In practice, the Acceptor's STUN-probed ports are known
|
||||||
|
//! exactly (not random), so the Dialer targets them first —
|
||||||
|
//! making this more like "spray-and-pray with a hit list" than
|
||||||
|
//! a pure birthday attack.
|
||||||
|
|
||||||
|
use std::net::{Ipv4Addr, SocketAddr};
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
|
use crate::stun;
|
||||||
|
|
||||||
|
/// Configuration for the birthday attack.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct BirthdayConfig {
|
||||||
|
/// Number of sockets the Acceptor opens (default: 32).
|
||||||
|
/// Each socket gets STUN-probed to learn its external port.
|
||||||
|
/// More = higher chance of collision, but more resource usage.
|
||||||
|
pub acceptor_ports: u16,
|
||||||
|
/// Number of QUIC connect attempts the Dialer makes (default: 128).
|
||||||
|
/// Spread across the Acceptor's known ports + random ports.
|
||||||
|
pub dialer_probes: u16,
|
||||||
|
/// Rate limit: ms between consecutive probes (default: 20ms = 50/s).
|
||||||
|
pub probe_interval_ms: u16,
|
||||||
|
/// Overall timeout for the birthday attack phase.
|
||||||
|
pub timeout: Duration,
|
||||||
|
/// STUN config for probing external ports.
|
||||||
|
pub stun_config: stun::StunConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for BirthdayConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
acceptor_ports: 32,
|
||||||
|
dialer_probes: 128,
|
||||||
|
probe_interval_ms: 20,
|
||||||
|
timeout: Duration::from_secs(8),
|
||||||
|
stun_config: stun::StunConfig {
|
||||||
|
servers: vec!["stun.l.google.com:19302".into()],
|
||||||
|
timeout: Duration::from_secs(2),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Result of the Acceptor's port-opening phase.
|
||||||
|
#[derive(Debug, Clone, serde::Serialize)]
|
||||||
|
pub struct AcceptorPorts {
|
||||||
|
/// External IP (from STUN).
|
||||||
|
pub external_ip: Option<Ipv4Addr>,
|
||||||
|
/// List of (local_port, external_port) for each opened socket.
|
||||||
|
pub ports: Vec<PortMapping>,
|
||||||
|
/// How many sockets we attempted to open.
|
||||||
|
pub attempted: u16,
|
||||||
|
/// How many STUN probes succeeded.
|
||||||
|
pub succeeded: u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A single socket's local↔external port mapping.
|
||||||
|
#[derive(Debug, Clone, serde::Serialize)]
|
||||||
|
pub struct PortMapping {
|
||||||
|
pub local_port: u16,
|
||||||
|
pub external_port: u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Open N sockets and STUN-probe each to discover external ports.
|
||||||
|
///
|
||||||
|
/// Returns the set of known external ports that the Dialer should
|
||||||
|
/// target. Each socket stays open (bound) so the NAT mapping
|
||||||
|
/// remains active until the returned `PortGuard` is dropped.
|
||||||
|
///
|
||||||
|
/// The sockets are returned so the caller can keep them alive
|
||||||
|
/// during the attack. Dropping them closes the NAT pinholes.
|
||||||
|
pub async fn open_acceptor_ports(
|
||||||
|
config: &BirthdayConfig,
|
||||||
|
) -> (AcceptorPorts, Vec<tokio::net::UdpSocket>) {
|
||||||
|
let mut sockets = Vec::new();
|
||||||
|
let mut mappings = Vec::new();
|
||||||
|
let mut external_ip: Option<Ipv4Addr> = None;
|
||||||
|
let mut succeeded: u16 = 0;
|
||||||
|
|
||||||
|
let stun_server = match config.stun_config.servers.first() {
|
||||||
|
Some(s) => match stun::resolve_stun_server(s).await {
|
||||||
|
Ok(a) => Some(a),
|
||||||
|
Err(_) => None,
|
||||||
|
},
|
||||||
|
None => None,
|
||||||
|
};
|
||||||
|
|
||||||
|
for _ in 0..config.acceptor_ports {
|
||||||
|
// Bind to random port
|
||||||
|
let sock = match tokio::net::UdpSocket::bind("0.0.0.0:0").await {
|
||||||
|
Ok(s) => s,
|
||||||
|
Err(_) => continue,
|
||||||
|
};
|
||||||
|
let local_port = match sock.local_addr() {
|
||||||
|
Ok(a) => a.port(),
|
||||||
|
Err(_) => continue,
|
||||||
|
};
|
||||||
|
|
||||||
|
// STUN probe to learn external port
|
||||||
|
if let Some(stun_addr) = stun_server {
|
||||||
|
match stun::stun_reflect(&sock, stun_addr, config.stun_config.timeout).await {
|
||||||
|
Ok(ext_addr) => {
|
||||||
|
if external_ip.is_none() {
|
||||||
|
if let std::net::IpAddr::V4(ip) = ext_addr.ip() {
|
||||||
|
external_ip = Some(ip);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mappings.push(PortMapping {
|
||||||
|
local_port,
|
||||||
|
external_port: ext_addr.port(),
|
||||||
|
});
|
||||||
|
succeeded += 1;
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::debug!(local_port, error = %e, "birthday: STUN probe failed for socket");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sockets.push(sock);
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
attempted = config.acceptor_ports,
|
||||||
|
succeeded,
|
||||||
|
external_ip = ?external_ip,
|
||||||
|
"birthday: acceptor ports opened"
|
||||||
|
);
|
||||||
|
|
||||||
|
let result = AcceptorPorts {
|
||||||
|
external_ip,
|
||||||
|
ports: mappings,
|
||||||
|
attempted: config.acceptor_ports,
|
||||||
|
succeeded,
|
||||||
|
};
|
||||||
|
|
||||||
|
(result, sockets)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generate the list of target addresses for the Dialer to spray.
|
||||||
|
///
|
||||||
|
/// Priority order:
|
||||||
|
/// 1. Acceptor's known external ports (from STUN probes) — highest hit rate
|
||||||
|
/// 2. Random ports on the Acceptor's IP — birthday paradox fill
|
||||||
|
pub fn generate_dialer_targets(
|
||||||
|
acceptor_ip: Ipv4Addr,
|
||||||
|
known_ports: &[u16],
|
||||||
|
total_probes: u16,
|
||||||
|
) -> Vec<SocketAddr> {
|
||||||
|
let mut targets = Vec::with_capacity(total_probes as usize);
|
||||||
|
|
||||||
|
// First: all known ports (guaranteed targets)
|
||||||
|
for &port in known_ports {
|
||||||
|
targets.push(SocketAddr::new(std::net::IpAddr::V4(acceptor_ip), port));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fill remaining with random ports (birthday attack)
|
||||||
|
let remaining = total_probes.saturating_sub(known_ports.len() as u16);
|
||||||
|
if remaining > 0 {
|
||||||
|
use rand::Rng;
|
||||||
|
let mut rng = rand::thread_rng();
|
||||||
|
for _ in 0..remaining {
|
||||||
|
let port = rng.gen_range(1024..=65535u16);
|
||||||
|
let addr = SocketAddr::new(std::net::IpAddr::V4(acceptor_ip), port);
|
||||||
|
if !targets.contains(&addr) {
|
||||||
|
targets.push(addr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
targets
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run the Dialer side of the birthday attack.
|
||||||
|
///
|
||||||
|
/// Sprays QUIC connection attempts at the target addresses.
|
||||||
|
/// Returns the first successful connection, or None on timeout.
|
||||||
|
pub async fn spray_dialer(
|
||||||
|
endpoint: &wzp_transport::Endpoint,
|
||||||
|
targets: &[SocketAddr],
|
||||||
|
call_sni: &str,
|
||||||
|
probe_interval: Duration,
|
||||||
|
timeout: Duration,
|
||||||
|
) -> Option<wzp_transport::QuinnTransport> {
|
||||||
|
let start = Instant::now();
|
||||||
|
let mut set = tokio::task::JoinSet::new();
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
target_count = targets.len(),
|
||||||
|
interval_ms = probe_interval.as_millis(),
|
||||||
|
timeout_s = timeout.as_secs(),
|
||||||
|
"birthday: dialer starting spray"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Spray connects with rate limiting
|
||||||
|
for (idx, &target) in targets.iter().enumerate() {
|
||||||
|
if start.elapsed() >= timeout {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let ep = endpoint.clone();
|
||||||
|
let sni = call_sni.to_string();
|
||||||
|
let client_cfg = wzp_transport::client_config();
|
||||||
|
set.spawn(async move {
|
||||||
|
let result = wzp_transport::connect(&ep, target, &sni, client_cfg).await;
|
||||||
|
(idx, target, result)
|
||||||
|
});
|
||||||
|
|
||||||
|
// Rate limit — don't blast the NAT
|
||||||
|
if idx < targets.len() - 1 {
|
||||||
|
tokio::time::sleep(probe_interval).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
spawned = set.len(),
|
||||||
|
elapsed_ms = start.elapsed().as_millis(),
|
||||||
|
"birthday: all probes spawned, waiting for first success"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Wait for first success or all failures
|
||||||
|
let deadline = start + timeout;
|
||||||
|
while let Some(join_res) = tokio::select! {
|
||||||
|
r = set.join_next() => r,
|
||||||
|
_ = tokio::time::sleep_until(tokio::time::Instant::from_std(deadline)) => None,
|
||||||
|
} {
|
||||||
|
match join_res {
|
||||||
|
Ok((idx, target, Ok(conn))) => {
|
||||||
|
tracing::info!(
|
||||||
|
idx,
|
||||||
|
%target,
|
||||||
|
remote = %conn.remote_address(),
|
||||||
|
elapsed_ms = start.elapsed().as_millis(),
|
||||||
|
"birthday: HIT! QUIC handshake succeeded"
|
||||||
|
);
|
||||||
|
set.abort_all();
|
||||||
|
return Some(wzp_transport::QuinnTransport::new(conn));
|
||||||
|
}
|
||||||
|
Ok((idx, target, Err(e))) => {
|
||||||
|
tracing::debug!(
|
||||||
|
idx,
|
||||||
|
%target,
|
||||||
|
error = %e,
|
||||||
|
"birthday: probe failed"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Err(_) => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
elapsed_ms = start.elapsed().as_millis(),
|
||||||
|
"birthday: all probes failed or timed out"
|
||||||
|
);
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Tests ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn generate_targets_known_ports_first() {
|
||||||
|
let ip = Ipv4Addr::new(203, 0, 113, 5);
|
||||||
|
let known = vec![10000, 10001, 10002];
|
||||||
|
let targets = generate_dialer_targets(ip, &known, 10);
|
||||||
|
|
||||||
|
// Known ports should be first
|
||||||
|
assert_eq!(targets[0].port(), 10000);
|
||||||
|
assert_eq!(targets[1].port(), 10001);
|
||||||
|
assert_eq!(targets[2].port(), 10002);
|
||||||
|
// Rest are random
|
||||||
|
assert!(targets.len() <= 10);
|
||||||
|
// All target the right IP
|
||||||
|
assert!(targets.iter().all(|a| a.ip() == std::net::IpAddr::V4(ip)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn generate_targets_no_known_all_random() {
|
||||||
|
let ip = Ipv4Addr::new(10, 0, 0, 1);
|
||||||
|
let targets = generate_dialer_targets(ip, &[], 50);
|
||||||
|
assert!(!targets.is_empty());
|
||||||
|
assert!(targets.len() <= 50);
|
||||||
|
// All ports in valid range
|
||||||
|
assert!(targets.iter().all(|a| a.port() >= 1024));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn generate_targets_more_known_than_total() {
|
||||||
|
let ip = Ipv4Addr::new(10, 0, 0, 1);
|
||||||
|
let known: Vec<u16> = (10000..10100).collect();
|
||||||
|
let targets = generate_dialer_targets(ip, &known, 50);
|
||||||
|
// All 100 known ports included even though total=50
|
||||||
|
assert_eq!(targets.len(), 100);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn generate_targets_dedup() {
|
||||||
|
let ip = Ipv4Addr::new(10, 0, 0, 1);
|
||||||
|
let targets = generate_dialer_targets(ip, &[], 100);
|
||||||
|
// No duplicates
|
||||||
|
let mut sorted = targets.clone();
|
||||||
|
sorted.sort();
|
||||||
|
sorted.dedup();
|
||||||
|
assert_eq!(sorted.len(), targets.len());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn default_config() {
|
||||||
|
let cfg = BirthdayConfig::default();
|
||||||
|
assert_eq!(cfg.acceptor_ports, 32);
|
||||||
|
assert_eq!(cfg.dialer_probes, 128);
|
||||||
|
assert!(cfg.timeout.as_secs() > 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn acceptor_ports_serializes() {
|
||||||
|
let result = AcceptorPorts {
|
||||||
|
external_ip: Some(Ipv4Addr::new(203, 0, 113, 5)),
|
||||||
|
ports: vec![PortMapping {
|
||||||
|
local_port: 12345,
|
||||||
|
external_port: 54321,
|
||||||
|
}],
|
||||||
|
attempted: 32,
|
||||||
|
succeeded: 1,
|
||||||
|
};
|
||||||
|
let json = serde_json::to_string(&result).unwrap();
|
||||||
|
assert!(json.contains("54321"));
|
||||||
|
assert!(json.contains("203.0.113.5"));
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -13,11 +13,11 @@ use wzp_codec::{
|
|||||||
};
|
};
|
||||||
use wzp_fec::{RaptorQFecDecoder, RaptorQFecEncoder};
|
use wzp_fec::{RaptorQFecDecoder, RaptorQFecEncoder};
|
||||||
use wzp_proto::jitter::{JitterBuffer, PlayoutResult};
|
use wzp_proto::jitter::{JitterBuffer, PlayoutResult};
|
||||||
|
use wzp_proto::packet::QualityReport;
|
||||||
use wzp_proto::packet::{MediaHeader, MediaPacket, MiniFrameContext};
|
use wzp_proto::packet::{MediaHeader, MediaPacket, MiniFrameContext};
|
||||||
use wzp_proto::quality::AdaptiveQualityController;
|
use wzp_proto::quality::AdaptiveQualityController;
|
||||||
use wzp_proto::traits::{AudioDecoder, AudioEncoder, FecDecoder, FecEncoder};
|
use wzp_proto::traits::{AudioDecoder, AudioEncoder, FecDecoder, FecEncoder};
|
||||||
use wzp_proto::packet::QualityReport;
|
use wzp_proto::{CodecId, MediaType, QualityProfile};
|
||||||
use wzp_proto::{CodecId, QualityProfile};
|
|
||||||
|
|
||||||
/// Configuration for a call session.
|
/// Configuration for a call session.
|
||||||
pub struct CallConfig {
|
pub struct CallConfig {
|
||||||
@@ -205,7 +205,7 @@ pub struct CallEncoder {
|
|||||||
/// Current profile.
|
/// Current profile.
|
||||||
profile: QualityProfile,
|
profile: QualityProfile,
|
||||||
/// Outbound sequence counter.
|
/// Outbound sequence counter.
|
||||||
seq: u16,
|
seq: u32,
|
||||||
/// Current FEC block.
|
/// Current FEC block.
|
||||||
block_id: u8,
|
block_id: u8,
|
||||||
/// Frame index within current block.
|
/// Frame index within current block.
|
||||||
@@ -234,6 +234,8 @@ pub struct CallEncoder {
|
|||||||
mini_frames_enabled: bool,
|
mini_frames_enabled: bool,
|
||||||
/// Frames encoded since the last full header was emitted.
|
/// Frames encoded since the last full header was emitted.
|
||||||
frames_since_full: u32,
|
frames_since_full: u32,
|
||||||
|
/// Pending quality report to attach to the next source packet.
|
||||||
|
pending_quality_report: Option<QualityReport>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CallEncoder {
|
impl CallEncoder {
|
||||||
@@ -264,6 +266,7 @@ impl CallEncoder {
|
|||||||
mini_context: MiniFrameContext::default(),
|
mini_context: MiniFrameContext::default(),
|
||||||
mini_frames_enabled: config.mini_frames_enabled,
|
mini_frames_enabled: config.mini_frames_enabled,
|
||||||
frames_since_full: 0,
|
frames_since_full: 0,
|
||||||
|
pending_quality_report: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -315,17 +318,15 @@ impl CallEncoder {
|
|||||||
if self.cn_counter % 10 == 0 {
|
if self.cn_counter % 10 == 0 {
|
||||||
let cn_pkt = MediaPacket {
|
let cn_pkt = MediaPacket {
|
||||||
header: MediaHeader {
|
header: MediaHeader {
|
||||||
version: 0,
|
version: 2,
|
||||||
is_repair: false,
|
flags: 0,
|
||||||
|
media_type: MediaType::Audio,
|
||||||
codec_id: CodecId::ComfortNoise,
|
codec_id: CodecId::ComfortNoise,
|
||||||
has_quality_report: false,
|
stream_id: 0,
|
||||||
fec_ratio_encoded: 0,
|
fec_ratio: 0,
|
||||||
seq: self.seq,
|
seq: self.seq,
|
||||||
timestamp: self.timestamp_ms,
|
timestamp: self.timestamp_ms,
|
||||||
fec_block: self.block_id,
|
fec_block: u16::from(self.block_id),
|
||||||
fec_symbol: 0,
|
|
||||||
reserved: 0,
|
|
||||||
csrc_count: 0,
|
|
||||||
},
|
},
|
||||||
payload: Bytes::from(vec![self.cn_level as u8]),
|
payload: Bytes::from(vec![self.cn_level as u8]),
|
||||||
quality_report: None,
|
quality_report: None,
|
||||||
@@ -351,33 +352,34 @@ impl CallEncoder {
|
|||||||
// can cleanly identify "no RaptorQ block to assemble" and new
|
// can cleanly identify "no RaptorQ block to assemble" and new
|
||||||
// receivers can short-circuit their FEC ingest path.
|
// receivers can short-circuit their FEC ingest path.
|
||||||
let is_opus = self.profile.codec.is_opus();
|
let is_opus = self.profile.codec.is_opus();
|
||||||
let (fec_block, fec_symbol, fec_ratio_encoded) = if is_opus {
|
let (fec_block, fec_ratio) = if is_opus {
|
||||||
(0u8, 0u8, 0u8)
|
(0u16, 0u8)
|
||||||
} else {
|
} else {
|
||||||
(
|
(
|
||||||
self.block_id,
|
u16::from(self.block_id) | (u16::from(self.frame_in_block) << 8),
|
||||||
self.frame_in_block,
|
|
||||||
MediaHeader::encode_fec_ratio(self.profile.fec_ratio),
|
MediaHeader::encode_fec_ratio(self.profile.fec_ratio),
|
||||||
)
|
)
|
||||||
};
|
};
|
||||||
|
|
||||||
// Build source media packet
|
// Build source media packet
|
||||||
|
let mut flags = 0u8;
|
||||||
|
if self.pending_quality_report.is_some() {
|
||||||
|
flags |= MediaHeader::FLAG_QUALITY;
|
||||||
|
}
|
||||||
let source_pkt = MediaPacket {
|
let source_pkt = MediaPacket {
|
||||||
header: MediaHeader {
|
header: MediaHeader {
|
||||||
version: 0,
|
version: 2,
|
||||||
is_repair: false,
|
flags,
|
||||||
|
media_type: MediaType::Audio,
|
||||||
codec_id: self.profile.codec,
|
codec_id: self.profile.codec,
|
||||||
has_quality_report: false,
|
stream_id: 0,
|
||||||
fec_ratio_encoded,
|
fec_ratio,
|
||||||
seq: self.seq,
|
seq: self.seq,
|
||||||
timestamp: self.timestamp_ms,
|
timestamp: self.timestamp_ms,
|
||||||
fec_block,
|
fec_block,
|
||||||
fec_symbol,
|
|
||||||
reserved: 0,
|
|
||||||
csrc_count: 0,
|
|
||||||
},
|
},
|
||||||
payload: Bytes::from(encoded.clone()),
|
payload: Bytes::from(encoded.clone()),
|
||||||
quality_report: None,
|
quality_report: self.pending_quality_report.take(),
|
||||||
};
|
};
|
||||||
|
|
||||||
self.seq = self.seq.wrapping_add(1);
|
self.seq = self.seq.wrapping_add(1);
|
||||||
@@ -399,19 +401,15 @@ impl CallEncoder {
|
|||||||
for (sym_idx, repair_data) in repairs {
|
for (sym_idx, repair_data) in repairs {
|
||||||
output.push(MediaPacket {
|
output.push(MediaPacket {
|
||||||
header: MediaHeader {
|
header: MediaHeader {
|
||||||
version: 0,
|
version: 2,
|
||||||
is_repair: true,
|
flags: MediaHeader::FLAG_REPAIR,
|
||||||
|
media_type: MediaType::Audio,
|
||||||
codec_id: self.profile.codec,
|
codec_id: self.profile.codec,
|
||||||
has_quality_report: false,
|
stream_id: 0,
|
||||||
fec_ratio_encoded: MediaHeader::encode_fec_ratio(
|
fec_ratio: MediaHeader::encode_fec_ratio(self.profile.fec_ratio),
|
||||||
self.profile.fec_ratio,
|
|
||||||
),
|
|
||||||
seq: self.seq,
|
seq: self.seq,
|
||||||
timestamp: self.timestamp_ms,
|
timestamp: self.timestamp_ms,
|
||||||
fec_block: self.block_id,
|
fec_block: u16::from(self.block_id) | (sym_idx << 8),
|
||||||
fec_symbol: sym_idx,
|
|
||||||
reserved: 0,
|
|
||||||
csrc_count: 0,
|
|
||||||
},
|
},
|
||||||
payload: Bytes::from(repair_data),
|
payload: Bytes::from(repair_data),
|
||||||
quality_report: None,
|
quality_report: None,
|
||||||
@@ -445,6 +443,22 @@ impl CallEncoder {
|
|||||||
self.aec.feed_farend(farend);
|
self.aec.feed_farend(farend);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Apply DRED tuning output to the encoder.
|
||||||
|
///
|
||||||
|
/// Called by the send loop after `DredTuner::update()` returns `Some`.
|
||||||
|
/// No-op when the active codec is Codec2 (DRED is Opus-only).
|
||||||
|
pub fn apply_dred_tuning(&mut self, tuning: wzp_proto::DredTuning) {
|
||||||
|
self.audio_enc.set_dred_duration(tuning.dred_frames);
|
||||||
|
self.audio_enc.set_expected_loss(tuning.expected_loss_pct);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Queue a quality report for attachment to the next source packet.
|
||||||
|
/// Used by the send task to embed locally-observed path quality so
|
||||||
|
/// the peer can drive adaptive quality switching.
|
||||||
|
pub fn set_pending_quality_report(&mut self, report: QualityReport) {
|
||||||
|
self.pending_quality_report = Some(report);
|
||||||
|
}
|
||||||
|
|
||||||
/// Enable or disable acoustic echo cancellation.
|
/// Enable or disable acoustic echo cancellation.
|
||||||
pub fn set_aec_enabled(&mut self, enabled: bool) {
|
pub fn set_aec_enabled(&mut self, enabled: bool) {
|
||||||
self.aec.set_enabled(enabled);
|
self.aec.set_enabled(enabled);
|
||||||
@@ -489,7 +503,7 @@ pub struct CallDecoder {
|
|||||||
last_good_dred: DredState,
|
last_good_dred: DredState,
|
||||||
/// Sequence number of the packet that produced `last_good_dred`. `None`
|
/// Sequence number of the packet that produced `last_good_dred`. `None`
|
||||||
/// if no packet has yielded DRED state yet (cold start or legacy sender).
|
/// if no packet has yielded DRED state yet (cold start or legacy sender).
|
||||||
last_good_dred_seq: Option<u16>,
|
last_good_dred_seq: Option<u32>,
|
||||||
/// Phase 4 telemetry counter: gaps recovered via DRED reconstruction.
|
/// Phase 4 telemetry counter: gaps recovered via DRED reconstruction.
|
||||||
pub dred_reconstructions: u64,
|
pub dred_reconstructions: u64,
|
||||||
/// Phase 4 telemetry counter: gaps filled via classical Opus PLC
|
/// Phase 4 telemetry counter: gaps filled via classical Opus PLC
|
||||||
@@ -552,8 +566,8 @@ impl CallDecoder {
|
|||||||
if !packet.header.codec_id.is_opus() {
|
if !packet.header.codec_id.is_opus() {
|
||||||
let _ = self.fec_dec.add_symbol(
|
let _ = self.fec_dec.add_symbol(
|
||||||
packet.header.fec_block,
|
packet.header.fec_block,
|
||||||
packet.header.fec_symbol,
|
packet.header.fec_block >> 8,
|
||||||
packet.header.is_repair,
|
packet.header.is_repair(),
|
||||||
&packet.payload,
|
&packet.payload,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@@ -563,7 +577,7 @@ impl CallDecoder {
|
|||||||
// swap with the cached `last_good_dred` so later gap reconstruction
|
// swap with the cached `last_good_dred` so later gap reconstruction
|
||||||
// has fresh neural redundancy to draw from. Parsing happens before
|
// has fresh neural redundancy to draw from. Parsing happens before
|
||||||
// the jitter push because the jitter buffer consumes the packet.
|
// the jitter push because the jitter buffer consumes the packet.
|
||||||
if packet.header.codec_id.is_opus() && !packet.header.is_repair {
|
if packet.header.codec_id.is_opus() && !packet.header.is_repair() {
|
||||||
match self
|
match self
|
||||||
.dred_decoder
|
.dred_decoder
|
||||||
.parse_into(&mut self.dred_parse_scratch, &packet.payload)
|
.parse_into(&mut self.dred_parse_scratch, &packet.payload)
|
||||||
@@ -592,7 +606,7 @@ impl CallDecoder {
|
|||||||
// Source packets (Opus or Codec2) go to the jitter buffer for decode.
|
// Source packets (Opus or Codec2) go to the jitter buffer for decode.
|
||||||
// Repair packets never reach the jitter buffer; for Codec2 they're
|
// Repair packets never reach the jitter buffer; for Codec2 they're
|
||||||
// used by the FEC decoder above, for Opus they're dropped here.
|
// used by the FEC decoder above, for Opus they're dropped here.
|
||||||
if !packet.header.is_repair {
|
if !packet.header.is_repair() {
|
||||||
self.jitter.push(packet);
|
self.jitter.push(packet);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -627,6 +641,7 @@ impl CallDecoder {
|
|||||||
fec_ratio: 0.3,
|
fec_ratio: 0.3,
|
||||||
frame_duration_ms: 20,
|
frame_duration_ms: 20,
|
||||||
frames_per_block: 5,
|
frames_per_block: 5,
|
||||||
|
..QualityProfile::GOOD
|
||||||
},
|
},
|
||||||
CodecId::Opus6k => QualityProfile::DEGRADED,
|
CodecId::Opus6k => QualityProfile::DEGRADED,
|
||||||
CodecId::Opus32k => QualityProfile::STUDIO_32K,
|
CodecId::Opus32k => QualityProfile::STUDIO_32K,
|
||||||
@@ -637,9 +652,13 @@ impl CallDecoder {
|
|||||||
fec_ratio: 0.5,
|
fec_ratio: 0.5,
|
||||||
frame_duration_ms: 20,
|
frame_duration_ms: 20,
|
||||||
frames_per_block: 5,
|
frames_per_block: 5,
|
||||||
|
..QualityProfile::GOOD
|
||||||
},
|
},
|
||||||
CodecId::Codec2_1200 => QualityProfile::CATASTROPHIC,
|
CodecId::Codec2_1200 => QualityProfile::CATASTROPHIC,
|
||||||
CodecId::ComfortNoise => QualityProfile::GOOD,
|
CodecId::ComfortNoise => QualityProfile::GOOD,
|
||||||
|
CodecId::H264Baseline | CodecId::H265Main | CodecId::Av1Main => {
|
||||||
|
panic!("video codec passed to audio decoder")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -692,12 +711,12 @@ impl CallDecoder {
|
|||||||
if let Some(last_seq) = self.last_good_dred_seq {
|
if let Some(last_seq) = self.last_good_dred_seq {
|
||||||
// How many frames ahead of the missing seq is the
|
// How many frames ahead of the missing seq is the
|
||||||
// last-good packet? Use wrapping arithmetic for the
|
// last-good packet? Use wrapping arithmetic for the
|
||||||
// u16 seq space.
|
// u32 seq space.
|
||||||
let seq_delta = last_seq.wrapping_sub(seq);
|
let seq_delta = last_seq.wrapping_sub(seq);
|
||||||
// Reject stale or backward state. u16 wraparound
|
// Reject stale or backward state. u32 wraparound
|
||||||
// would make a "seq went backward" delta very large;
|
// would make a "seq went backward" delta very large;
|
||||||
// cap at a sane forward-looking window.
|
// cap at a sane forward-looking window.
|
||||||
const MAX_SEQ_DELTA: u16 = 128;
|
const MAX_SEQ_DELTA: u32 = 128;
|
||||||
if seq_delta > 0 && seq_delta <= MAX_SEQ_DELTA {
|
if seq_delta > 0 && seq_delta <= MAX_SEQ_DELTA {
|
||||||
let frame_samples =
|
let frame_samples =
|
||||||
(48_000 * self.profile.frame_duration_ms as i32) / 1000;
|
(48_000 * self.profile.frame_duration_ms as i32) / 1000;
|
||||||
@@ -766,7 +785,7 @@ impl CallDecoder {
|
|||||||
/// Phase 3b introspection: sequence number of the most recently parsed
|
/// Phase 3b introspection: sequence number of the most recently parsed
|
||||||
/// valid DRED state, or `None` if no Opus packet has yielded DRED data
|
/// valid DRED state, or `None` if no Opus packet has yielded DRED data
|
||||||
/// yet. Used by tests to debug reconstruction eligibility.
|
/// yet. Used by tests to debug reconstruction eligibility.
|
||||||
pub fn last_good_dred_seq(&self) -> Option<u16> {
|
pub fn last_good_dred_seq(&self) -> Option<u32> {
|
||||||
self.last_good_dred_seq
|
self.last_good_dred_seq
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -833,7 +852,7 @@ mod tests {
|
|||||||
let packets = enc.encode_frame(&pcm).unwrap();
|
let packets = enc.encode_frame(&pcm).unwrap();
|
||||||
assert!(!packets.is_empty());
|
assert!(!packets.is_empty());
|
||||||
assert_eq!(packets[0].header.seq, 0);
|
assert_eq!(packets[0].header.seq, 0);
|
||||||
assert!(!packets[0].header.is_repair);
|
assert!(!packets[0].header.is_repair());
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Phase 2: Opus packets have zero FEC header fields — no block, no
|
/// Phase 2: Opus packets have zero FEC header fields — no block, no
|
||||||
@@ -856,10 +875,9 @@ mod tests {
|
|||||||
assert_eq!(packets.len(), 1, "Opus must emit exactly 1 source packet");
|
assert_eq!(packets.len(), 1, "Opus must emit exactly 1 source packet");
|
||||||
let hdr = &packets[0].header;
|
let hdr = &packets[0].header;
|
||||||
assert!(hdr.codec_id.is_opus());
|
assert!(hdr.codec_id.is_opus());
|
||||||
assert!(!hdr.is_repair);
|
assert!(!hdr.is_repair());
|
||||||
assert_eq!(hdr.fec_block, 0, "Opus fec_block must be 0");
|
assert_eq!(hdr.fec_block, 0, "Opus fec_block must be 0");
|
||||||
assert_eq!(hdr.fec_symbol, 0, "Opus fec_symbol must be 0");
|
assert_eq!(hdr.fec_ratio, 0, "Opus fec_ratio must be 0");
|
||||||
assert_eq!(hdr.fec_ratio_encoded, 0, "Opus fec_ratio_encoded must be 0");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Phase 2: Opus never emits repair packets, regardless of how many
|
/// Phase 2: Opus never emits repair packets, regardless of how many
|
||||||
@@ -883,7 +901,7 @@ mod tests {
|
|||||||
for _ in 0..20 {
|
for _ in 0..20 {
|
||||||
let packets = enc.encode_frame(&pcm).unwrap();
|
let packets = enc.encode_frame(&pcm).unwrap();
|
||||||
total_packets += packets.len();
|
total_packets += packets.len();
|
||||||
repair_count += packets.iter().filter(|p| p.header.is_repair).count();
|
repair_count += packets.iter().filter(|p| p.header.is_repair()).count();
|
||||||
}
|
}
|
||||||
assert_eq!(repair_count, 0, "Opus must emit zero repair packets");
|
assert_eq!(repair_count, 0, "Opus must emit zero repair packets");
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@@ -915,7 +933,7 @@ mod tests {
|
|||||||
for _ in 0..16 {
|
for _ in 0..16 {
|
||||||
let packets = enc.encode_frame(&pcm).unwrap();
|
let packets = enc.encode_frame(&pcm).unwrap();
|
||||||
for p in &packets {
|
for p in &packets {
|
||||||
if p.header.is_repair {
|
if p.header.is_repair() {
|
||||||
repair_count += 1;
|
repair_count += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -934,17 +952,15 @@ mod tests {
|
|||||||
|
|
||||||
let pkt = MediaPacket {
|
let pkt = MediaPacket {
|
||||||
header: MediaHeader {
|
header: MediaHeader {
|
||||||
version: 0,
|
version: 2,
|
||||||
is_repair: false,
|
flags: 0,
|
||||||
|
media_type: MediaType::Audio,
|
||||||
codec_id: CodecId::Opus24k,
|
codec_id: CodecId::Opus24k,
|
||||||
has_quality_report: false,
|
stream_id: 0,
|
||||||
fec_ratio_encoded: 0,
|
fec_ratio: 0,
|
||||||
seq: 0,
|
seq: 0,
|
||||||
timestamp: 0,
|
timestamp: 0,
|
||||||
fec_block: 0,
|
fec_block: 0,
|
||||||
fec_symbol: 0,
|
|
||||||
reserved: 0,
|
|
||||||
csrc_count: 0,
|
|
||||||
},
|
},
|
||||||
payload: Bytes::from(vec![0u8; 60]),
|
payload: Bytes::from(vec![0u8; 60]),
|
||||||
quality_report: None,
|
quality_report: None,
|
||||||
@@ -1006,17 +1022,15 @@ mod tests {
|
|||||||
encoded.truncate(n);
|
encoded.truncate(n);
|
||||||
let pkt = MediaPacket {
|
let pkt = MediaPacket {
|
||||||
header: MediaHeader {
|
header: MediaHeader {
|
||||||
version: 0,
|
version: 2,
|
||||||
is_repair: false,
|
flags: 0,
|
||||||
|
media_type: MediaType::Audio,
|
||||||
codec_id: CodecId::Opus24k,
|
codec_id: CodecId::Opus24k,
|
||||||
has_quality_report: false,
|
stream_id: 0,
|
||||||
fec_ratio_encoded: 0,
|
fec_ratio: 0,
|
||||||
seq: i,
|
seq: i as u32,
|
||||||
timestamp: (i as u32) * 20,
|
timestamp: (i as u32) * 20,
|
||||||
fec_block: 0,
|
fec_block: 0,
|
||||||
fec_symbol: 0,
|
|
||||||
reserved: 0,
|
|
||||||
csrc_count: 0,
|
|
||||||
},
|
},
|
||||||
payload: Bytes::from(encoded),
|
payload: Bytes::from(encoded),
|
||||||
quality_report: None,
|
quality_report: None,
|
||||||
@@ -1086,9 +1100,7 @@ mod tests {
|
|||||||
|
|
||||||
let dred_delta = dec.dred_reconstructions - baseline_dred;
|
let dred_delta = dec.dred_reconstructions - baseline_dred;
|
||||||
let plc_delta = dec.classical_plc_invocations - baseline_plc;
|
let plc_delta = dec.classical_plc_invocations - baseline_plc;
|
||||||
eprintln!(
|
eprintln!("[phase3b probe] post-drain: dred_delta={dred_delta} plc_delta={plc_delta}");
|
||||||
"[phase3b probe] post-drain: dred_delta={dred_delta} plc_delta={plc_delta}"
|
|
||||||
);
|
|
||||||
assert!(
|
assert!(
|
||||||
dred_delta >= 1,
|
dred_delta >= 1,
|
||||||
"expected ≥1 DRED reconstruction on single-packet loss, \
|
"expected ≥1 DRED reconstruction on single-packet loss, \
|
||||||
@@ -1149,7 +1161,7 @@ mod tests {
|
|||||||
let packets = enc.encode_frame(&pcm).unwrap();
|
let packets = enc.encode_frame(&pcm).unwrap();
|
||||||
for pkt in packets {
|
for pkt in packets {
|
||||||
// Drop every 5th source packet to simulate loss.
|
// Drop every 5th source packet to simulate loss.
|
||||||
if !pkt.header.is_repair && i % 5 == 3 {
|
if !pkt.header.is_repair() && i % 5 == 3 {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
dec.ingest(pkt);
|
dec.ingest(pkt);
|
||||||
@@ -1303,20 +1315,18 @@ mod tests {
|
|||||||
|
|
||||||
// ---- JitterStats telemetry tests ----
|
// ---- JitterStats telemetry tests ----
|
||||||
|
|
||||||
fn make_test_packet(seq: u16) -> MediaPacket {
|
fn make_test_packet(seq: u32) -> MediaPacket {
|
||||||
MediaPacket {
|
MediaPacket {
|
||||||
header: MediaHeader {
|
header: MediaHeader {
|
||||||
version: 0,
|
version: 2,
|
||||||
is_repair: false,
|
flags: 0,
|
||||||
|
media_type: MediaType::Audio,
|
||||||
codec_id: CodecId::Opus24k,
|
codec_id: CodecId::Opus24k,
|
||||||
has_quality_report: false,
|
stream_id: 0,
|
||||||
fec_ratio_encoded: 0,
|
fec_ratio: 0,
|
||||||
seq,
|
seq,
|
||||||
timestamp: seq as u32 * 20,
|
timestamp: seq * 20,
|
||||||
fec_block: 0,
|
fec_block: 0,
|
||||||
fec_symbol: seq as u8,
|
|
||||||
reserved: 0,
|
|
||||||
csrc_count: 0,
|
|
||||||
},
|
},
|
||||||
payload: Bytes::from(vec![0u8; 60]),
|
payload: Bytes::from(vec![0u8; 60]),
|
||||||
quality_report: None,
|
quality_report: None,
|
||||||
@@ -1328,7 +1338,7 @@ mod tests {
|
|||||||
let config = CallConfig::default();
|
let config = CallConfig::default();
|
||||||
let mut dec = CallDecoder::new(&config);
|
let mut dec = CallDecoder::new(&config);
|
||||||
|
|
||||||
for i in 0..5u16 {
|
for i in 0..5u32 {
|
||||||
dec.ingest(make_test_packet(i));
|
dec.ingest(make_test_packet(i));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1358,7 +1368,7 @@ mod tests {
|
|||||||
let mut dec = CallDecoder::new(&config);
|
let mut dec = CallDecoder::new(&config);
|
||||||
|
|
||||||
// Generate some stats: ingest packets and trigger underruns on empty buffer
|
// Generate some stats: ingest packets and trigger underruns on empty buffer
|
||||||
for i in 0..3u16 {
|
for i in 0..3u32 {
|
||||||
dec.ingest(make_test_packet(i));
|
dec.ingest(make_test_packet(i));
|
||||||
}
|
}
|
||||||
// Also call decode on empty decoder to get underruns
|
// Also call decode on empty decoder to get underruns
|
||||||
@@ -1437,9 +1447,225 @@ mod tests {
|
|||||||
cn_packets >= 1,
|
cn_packets >= 1,
|
||||||
"should have at least one CN packet, got {cn_packets}"
|
"should have at least one CN packet, got {cn_packets}"
|
||||||
);
|
);
|
||||||
|
assert!(enc.frames_suppressed > 0, "frames_suppressed should be > 0");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- DredTuner integration tests ----
|
||||||
|
|
||||||
|
/// End-to-end test: DredTuner reacts to simulated network degradation
|
||||||
|
/// and adjusts the encoder's DRED parameters via `apply_dred_tuning`.
|
||||||
|
#[test]
|
||||||
|
fn dred_tuner_adjusts_encoder_on_loss() {
|
||||||
|
use wzp_proto::DredTuner;
|
||||||
|
|
||||||
|
let mut enc = CallEncoder::new(&CallConfig {
|
||||||
|
profile: QualityProfile::GOOD,
|
||||||
|
suppression_enabled: false,
|
||||||
|
..Default::default()
|
||||||
|
});
|
||||||
|
let mut tuner = DredTuner::new(QualityProfile::GOOD.codec);
|
||||||
|
|
||||||
|
// Baseline: good network → baseline DRED (20 frames = 200 ms).
|
||||||
|
let baseline = tuner.current();
|
||||||
|
assert_eq!(baseline.dred_frames, 20);
|
||||||
|
|
||||||
|
// Warm up the tuner — first few updates may return Some as the
|
||||||
|
// EWMA initializes and expected_loss settles from the initial 15%.
|
||||||
|
for _ in 0..10 {
|
||||||
|
tuner.update(0.0, 50, 5);
|
||||||
|
}
|
||||||
|
// After settling, the tuning should be at baseline.
|
||||||
|
assert_eq!(tuner.current().dred_frames, 20);
|
||||||
|
|
||||||
|
// Simulate network degradation: 30% loss, 300ms RTT.
|
||||||
|
// The tuner should increase DRED frames above baseline.
|
||||||
|
let tuning = tuner.update(30.0, 300, 15);
|
||||||
|
assert!(tuning.is_some(), "loss spike should trigger tuning change");
|
||||||
|
let t = tuning.unwrap();
|
||||||
assert!(
|
assert!(
|
||||||
enc.frames_suppressed > 0,
|
t.dred_frames > 20,
|
||||||
"frames_suppressed should be > 0"
|
"30% loss should increase DRED above baseline 20, got {}",
|
||||||
|
t.dred_frames
|
||||||
|
);
|
||||||
|
|
||||||
|
// Apply to encoder — should not panic.
|
||||||
|
enc.apply_dred_tuning(t);
|
||||||
|
|
||||||
|
// Verify the encoder still works after tuning.
|
||||||
|
let pcm = voice_frame_20ms(0);
|
||||||
|
let packets = enc.encode_frame(&pcm).unwrap();
|
||||||
|
assert!(
|
||||||
|
!packets.is_empty(),
|
||||||
|
"encoder must still produce packets after DRED tuning"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// DredTuner jitter spike triggers pre-emptive DRED boost to ceiling.
|
||||||
|
#[test]
|
||||||
|
fn dred_tuner_spike_boosts_to_ceiling() {
|
||||||
|
use wzp_proto::DredTuner;
|
||||||
|
|
||||||
|
let mut tuner = DredTuner::new(CodecId::Opus24k);
|
||||||
|
|
||||||
|
// Establish low-jitter baseline.
|
||||||
|
for _ in 0..20 {
|
||||||
|
tuner.update(0.0, 50, 5);
|
||||||
|
}
|
||||||
|
assert!(!tuner.spike_boost_active());
|
||||||
|
|
||||||
|
// Jitter spikes to 40ms (8x baseline of ~5ms).
|
||||||
|
let tuning = tuner.update(0.0, 50, 40);
|
||||||
|
assert!(
|
||||||
|
tuner.spike_boost_active(),
|
||||||
|
"jitter spike should activate boost"
|
||||||
|
);
|
||||||
|
assert!(tuning.is_some());
|
||||||
|
// Ceiling for Opus24k is 50 frames = 500 ms.
|
||||||
|
assert_eq!(
|
||||||
|
tuning.unwrap().dred_frames,
|
||||||
|
50,
|
||||||
|
"spike should push to ceiling"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// DredTuner is a no-op for Codec2 profiles.
|
||||||
|
#[test]
|
||||||
|
fn dred_tuner_noop_for_codec2() {
|
||||||
|
use wzp_proto::DredTuner;
|
||||||
|
|
||||||
|
let mut tuner = DredTuner::new(CodecId::Codec2_1200);
|
||||||
|
|
||||||
|
// Even extreme conditions produce no tuning output.
|
||||||
|
assert!(tuner.update(50.0, 800, 100).is_none());
|
||||||
|
assert_eq!(tuner.current().dred_frames, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// DredTuner + CallEncoder: full cycle through profile switch.
|
||||||
|
#[test]
|
||||||
|
fn dred_tuner_handles_profile_switch() {
|
||||||
|
use wzp_proto::DredTuner;
|
||||||
|
|
||||||
|
let mut enc = CallEncoder::new(&CallConfig {
|
||||||
|
profile: QualityProfile::GOOD,
|
||||||
|
suppression_enabled: false,
|
||||||
|
..Default::default()
|
||||||
|
});
|
||||||
|
let mut tuner = DredTuner::new(QualityProfile::GOOD.codec);
|
||||||
|
|
||||||
|
// Apply initial tuning on good network.
|
||||||
|
if let Some(t) = tuner.update(0.0, 50, 5) {
|
||||||
|
enc.apply_dred_tuning(t);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Switch to degraded profile.
|
||||||
|
enc.set_profile(QualityProfile::DEGRADED).unwrap();
|
||||||
|
tuner.set_codec(QualityProfile::DEGRADED.codec);
|
||||||
|
|
||||||
|
// Opus6k baseline is 50 frames (500 ms), ceiling is 104 (1040 ms).
|
||||||
|
let baseline = tuner.current();
|
||||||
|
// After set_codec, the cached tuning should reflect old state;
|
||||||
|
// a fresh update gives the new codec's mapping.
|
||||||
|
let tuning = tuner.update(20.0, 200, 10);
|
||||||
|
assert!(tuning.is_some());
|
||||||
|
let t = tuning.unwrap();
|
||||||
|
assert!(
|
||||||
|
t.dred_frames >= 50,
|
||||||
|
"Opus6k with 20% loss should be at least baseline 50, got {}",
|
||||||
|
t.dred_frames
|
||||||
|
);
|
||||||
|
|
||||||
|
enc.apply_dred_tuning(t);
|
||||||
|
|
||||||
|
// Encode a 40ms frame (Opus6k uses 40ms frames = 1920 samples).
|
||||||
|
let pcm: Vec<i16> = (0..1920)
|
||||||
|
.map(|i| ((i as f32 * 0.1).sin() * 10_000.0) as i16)
|
||||||
|
.collect();
|
||||||
|
let packets = enc.encode_frame(&pcm).unwrap();
|
||||||
|
assert!(!packets.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn encoder_attaches_quality_report() {
|
||||||
|
let mut enc = CallEncoder::new(&CallConfig {
|
||||||
|
profile: QualityProfile::GOOD,
|
||||||
|
suppression_enabled: false,
|
||||||
|
..Default::default()
|
||||||
|
});
|
||||||
|
|
||||||
|
// Set a quality report
|
||||||
|
enc.set_pending_quality_report(QualityReport::from_path_stats(5.0, 80, 10));
|
||||||
|
|
||||||
|
// Encode a frame — should have quality_report attached
|
||||||
|
let pcm = voice_frame_20ms(0);
|
||||||
|
let packets = enc.encode_frame(&pcm).unwrap();
|
||||||
|
assert!(!packets.is_empty());
|
||||||
|
assert!(
|
||||||
|
packets[0].header.has_quality(),
|
||||||
|
"first packet should have quality report"
|
||||||
|
);
|
||||||
|
assert!(packets[0].quality_report.is_some());
|
||||||
|
|
||||||
|
// Next frame should NOT have quality_report (it was consumed)
|
||||||
|
let packets2 = enc.encode_frame(&voice_frame_20ms(960)).unwrap();
|
||||||
|
assert!(
|
||||||
|
!packets2[0].header.has_quality(),
|
||||||
|
"second packet should not have quality report"
|
||||||
|
);
|
||||||
|
assert!(packets2[0].quality_report.is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn quality_report_aead_tamper_fails_decrypt() {
|
||||||
|
use wzp_crypto::ChaChaSession;
|
||||||
|
use wzp_proto::CryptoSession;
|
||||||
|
|
||||||
|
// Build a packet with a QualityReport trailer.
|
||||||
|
let pkt = MediaPacket {
|
||||||
|
header: MediaHeader {
|
||||||
|
version: 2,
|
||||||
|
flags: MediaHeader::FLAG_QUALITY,
|
||||||
|
media_type: MediaType::Audio,
|
||||||
|
codec_id: CodecId::Opus24k,
|
||||||
|
stream_id: 0,
|
||||||
|
fec_ratio: 10,
|
||||||
|
seq: 42,
|
||||||
|
timestamp: 1000,
|
||||||
|
fec_block: 0,
|
||||||
|
},
|
||||||
|
payload: Bytes::from(vec![0xAB; 60]),
|
||||||
|
quality_report: Some(QualityReport::from_path_stats(5.0, 80, 10)),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Serialize: header || payload || quality_report
|
||||||
|
let wire = pkt.to_bytes();
|
||||||
|
assert_eq!(
|
||||||
|
wire.len(),
|
||||||
|
MediaHeader::WIRE_SIZE + pkt.payload.len() + QualityReport::WIRE_SIZE
|
||||||
|
);
|
||||||
|
|
||||||
|
let header_bytes = &wire[..MediaHeader::WIRE_SIZE];
|
||||||
|
let plaintext = &wire[MediaHeader::WIRE_SIZE..];
|
||||||
|
|
||||||
|
// Encrypt with ChaCha20-Poly1305 (header as AAD, payload+QR as plaintext).
|
||||||
|
let mut alice = ChaChaSession::new([0xAA; 32]);
|
||||||
|
let mut bob = ChaChaSession::new([0xAA; 32]);
|
||||||
|
let mut ciphertext = Vec::new();
|
||||||
|
alice
|
||||||
|
.encrypt(header_bytes, plaintext, &mut ciphertext)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Tamper with a byte in the QualityReport region (last 4 bytes of plaintext
|
||||||
|
// → last 4 bytes of ciphertext for ChaCha20 stream cipher).
|
||||||
|
let qr_offset_in_plaintext = plaintext.len() - QualityReport::WIRE_SIZE;
|
||||||
|
let tamper_idx = qr_offset_in_plaintext;
|
||||||
|
ciphertext[tamper_idx] ^= 0xFF;
|
||||||
|
|
||||||
|
// Decryption must fail because the AEAD tag no longer matches.
|
||||||
|
let mut decrypted = Vec::new();
|
||||||
|
let result = bob.decrypt(header_bytes, &ciphertext, &mut decrypted);
|
||||||
|
assert!(
|
||||||
|
result.is_err(),
|
||||||
|
"tampering with QualityReport inside AEAD payload must cause decryption failure"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ use std::sync::Arc;
|
|||||||
use tracing::{error, info};
|
use tracing::{error, info};
|
||||||
|
|
||||||
use wzp_client::call::{CallConfig, CallDecoder, CallEncoder};
|
use wzp_client::call::{CallConfig, CallDecoder, CallEncoder};
|
||||||
use wzp_proto::MediaTransport;
|
use wzp_proto::{MediaTransport, default_signal_version};
|
||||||
|
|
||||||
const FRAME_SAMPLES: usize = 960; // 20ms @ 48kHz
|
const FRAME_SAMPLES: usize = 960; // 20ms @ 48kHz
|
||||||
|
|
||||||
@@ -52,6 +52,8 @@ struct CliArgs {
|
|||||||
signal: bool,
|
signal: bool,
|
||||||
/// Place a direct call to a fingerprint (requires --signal).
|
/// Place a direct call to a fingerprint (requires --signal).
|
||||||
call_target: Option<String>,
|
call_target: Option<String>,
|
||||||
|
/// Run network diagnostic (STUN, port mapping, relay latencies).
|
||||||
|
netcheck: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CliArgs {
|
impl CliArgs {
|
||||||
@@ -97,6 +99,7 @@ fn parse_args() -> CliArgs {
|
|||||||
let mut relay_str = None;
|
let mut relay_str = None;
|
||||||
let mut signal = false;
|
let mut signal = false;
|
||||||
let mut call_target = None;
|
let mut call_target = None;
|
||||||
|
let mut netcheck = false;
|
||||||
|
|
||||||
let mut i = 1;
|
let mut i = 1;
|
||||||
while i < args.len() {
|
while i < args.len() {
|
||||||
@@ -105,7 +108,11 @@ fn parse_args() -> CliArgs {
|
|||||||
"--signal" => signal = true,
|
"--signal" => signal = true,
|
||||||
"--call" => {
|
"--call" => {
|
||||||
i += 1;
|
i += 1;
|
||||||
call_target = Some(args.get(i).expect("--call requires a fingerprint").to_string());
|
call_target = Some(
|
||||||
|
args.get(i)
|
||||||
|
.expect("--call requires a fingerprint")
|
||||||
|
.to_string(),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
"--send-tone" => {
|
"--send-tone" => {
|
||||||
i += 1;
|
i += 1;
|
||||||
@@ -182,7 +189,12 @@ fn parse_args() -> CliArgs {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
"--sweep" => sweep = true,
|
"--sweep" => sweep = true,
|
||||||
"--version-check" => { version_check = true; }
|
"--netcheck" => {
|
||||||
|
netcheck = true;
|
||||||
|
}
|
||||||
|
"--version-check" => {
|
||||||
|
version_check = true;
|
||||||
|
}
|
||||||
"--help" | "-h" => {
|
"--help" | "-h" => {
|
||||||
eprintln!("Usage: wzp-client [options] [relay-addr]");
|
eprintln!("Usage: wzp-client [options] [relay-addr]");
|
||||||
eprintln!();
|
eprintln!();
|
||||||
@@ -193,13 +205,19 @@ fn parse_args() -> CliArgs {
|
|||||||
eprintln!(" --record <file.raw> Record received audio to raw PCM file");
|
eprintln!(" --record <file.raw> Record received audio to raw PCM file");
|
||||||
eprintln!(" --echo-test <secs> Run automated echo quality test");
|
eprintln!(" --echo-test <secs> Run automated echo quality test");
|
||||||
eprintln!(" --drift-test <secs> Run automated clock-drift measurement");
|
eprintln!(" --drift-test <secs> Run automated clock-drift measurement");
|
||||||
eprintln!(" --sweep Run jitter buffer parameter sweep (local, no network)");
|
eprintln!(
|
||||||
eprintln!(" --seed <hex> Identity seed (64 hex chars, featherChat compatible)");
|
" --sweep Run jitter buffer parameter sweep (local, no network)"
|
||||||
|
);
|
||||||
|
eprintln!(
|
||||||
|
" --seed <hex> Identity seed (64 hex chars, featherChat compatible)"
|
||||||
|
);
|
||||||
eprintln!(" --mnemonic <words...> Identity seed as BIP39 mnemonic (24 words)");
|
eprintln!(" --mnemonic <words...> Identity seed as BIP39 mnemonic (24 words)");
|
||||||
eprintln!(" --room <name> Room name (hashed for privacy before sending)");
|
eprintln!(" --room <name> Room name (hashed for privacy before sending)");
|
||||||
eprintln!(" --token <token> featherChat bearer token for relay auth");
|
eprintln!(" --token <token> featherChat bearer token for relay auth");
|
||||||
eprintln!(" --metrics-file <path> Write JSONL telemetry to file (1 line/sec)");
|
eprintln!(" --metrics-file <path> Write JSONL telemetry to file (1 line/sec)");
|
||||||
eprintln!(" (48kHz mono s16le, play with ffplay -f s16le -ar 48000 -ch_layout mono file.raw)");
|
eprintln!(
|
||||||
|
" (48kHz mono s16le, play with ffplay -f s16le -ar 48000 -ch_layout mono file.raw)"
|
||||||
|
);
|
||||||
eprintln!();
|
eprintln!();
|
||||||
eprintln!("Default relay: 127.0.0.1:4433");
|
eprintln!("Default relay: 127.0.0.1:4433");
|
||||||
std::process::exit(0);
|
std::process::exit(0);
|
||||||
@@ -238,6 +256,7 @@ fn parse_args() -> CliArgs {
|
|||||||
version_check,
|
version_check,
|
||||||
signal,
|
signal,
|
||||||
call_target,
|
call_target,
|
||||||
|
netcheck,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -256,12 +275,28 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --netcheck: run network diagnostic and exit
|
||||||
|
if cli.netcheck {
|
||||||
|
let config = wzp_client::netcheck::NetcheckConfig {
|
||||||
|
stun_config: wzp_client::stun::StunConfig::default(),
|
||||||
|
relays: vec![("relay".into(), cli.relay_addr)],
|
||||||
|
timeout: std::time::Duration::from_secs(5),
|
||||||
|
test_portmap: true,
|
||||||
|
test_ipv6: true,
|
||||||
|
local_port: 0,
|
||||||
|
};
|
||||||
|
let report = wzp_client::netcheck::run_netcheck(&config).await;
|
||||||
|
print!("{}", wzp_client::netcheck::format_report(&report));
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
// --version-check: query relay version over QUIC and exit
|
// --version-check: query relay version over QUIC and exit
|
||||||
if cli.version_check {
|
if cli.version_check {
|
||||||
let client_config = wzp_transport::client_config();
|
let client_config = wzp_transport::client_config();
|
||||||
let bind_addr: SocketAddr = "0.0.0.0:0".parse()?;
|
let bind_addr: SocketAddr = "0.0.0.0:0".parse()?;
|
||||||
let endpoint = wzp_transport::create_endpoint(bind_addr, None)?;
|
let endpoint = wzp_transport::create_endpoint(bind_addr, None)?;
|
||||||
let conn = wzp_transport::connect(&endpoint, cli.relay_addr, "version", client_config).await?;
|
let conn =
|
||||||
|
wzp_transport::connect(&endpoint, cli.relay_addr, "version", client_config).await?;
|
||||||
match conn.accept_uni().await {
|
match conn.accept_uni().await {
|
||||||
Ok(mut recv) => {
|
Ok(mut recv) => {
|
||||||
let data = recv.read_to_end(256).await.unwrap_or_default();
|
let data = recv.read_to_end(256).await.unwrap_or_default();
|
||||||
@@ -269,7 +304,10 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
println!("{} {}", cli.relay_addr, version.trim());
|
println!("{} {}", cli.relay_addr, version.trim());
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
eprintln!("relay {} does not support version query: {e}", cli.relay_addr);
|
eprintln!(
|
||||||
|
"relay {} does not support version query: {e}",
|
||||||
|
cli.relay_addr
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
endpoint.close(0u32.into(), b"done");
|
endpoint.close(0u32.into(), b"done");
|
||||||
@@ -309,8 +347,7 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
"0.0.0.0:0".parse()?
|
"0.0.0.0:0".parse()?
|
||||||
};
|
};
|
||||||
let endpoint = wzp_transport::create_endpoint(bind_addr, None)?;
|
let endpoint = wzp_transport::create_endpoint(bind_addr, None)?;
|
||||||
let connection =
|
let connection = wzp_transport::connect(&endpoint, cli.relay_addr, &sni, client_config).await?;
|
||||||
wzp_transport::connect(&endpoint, cli.relay_addr, &sni, client_config).await?;
|
|
||||||
|
|
||||||
info!("Connected to relay");
|
info!("Connected to relay");
|
||||||
|
|
||||||
@@ -321,9 +358,11 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
{
|
{
|
||||||
let shutdown_transport = transport.clone();
|
let shutdown_transport = transport.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
|
let mut sigterm =
|
||||||
|
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
|
||||||
.expect("failed to register SIGTERM handler");
|
.expect("failed to register SIGTERM handler");
|
||||||
let mut sigint = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt())
|
let mut sigint =
|
||||||
|
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt())
|
||||||
.expect("failed to register SIGINT handler");
|
.expect("failed to register SIGINT handler");
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
_ = sigterm.recv() => { info!("SIGTERM received, closing connection..."); }
|
_ = sigterm.recv() => { info!("SIGTERM received, closing connection..."); }
|
||||||
@@ -332,13 +371,16 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
// Close the QUIC connection immediately (APPLICATION_CLOSE frame).
|
// Close the QUIC connection immediately (APPLICATION_CLOSE frame).
|
||||||
// Don't call process::exit — let the main task detect the closed
|
// Don't call process::exit — let the main task detect the closed
|
||||||
// connection and perform clean shutdown (e.g., save recordings).
|
// connection and perform clean shutdown (e.g., save recordings).
|
||||||
shutdown_transport.connection().close(0u32.into(), b"shutdown");
|
shutdown_transport
|
||||||
|
.connection()
|
||||||
|
.close(0u32.into(), b"shutdown");
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send auth token if provided (relay with --auth-url expects this first)
|
// Send auth token if provided (relay with --auth-url expects this first)
|
||||||
if let Some(ref token) = cli.token {
|
if let Some(ref token) = cli.token {
|
||||||
let auth = wzp_proto::SignalMessage::AuthToken {
|
let auth = wzp_proto::SignalMessage::AuthToken {
|
||||||
|
version: default_signal_version(),
|
||||||
token: token.clone(),
|
token: token.clone(),
|
||||||
};
|
};
|
||||||
transport.send_signal(&auth).await?;
|
transport.send_signal(&auth).await?;
|
||||||
@@ -346,21 +388,29 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Crypto handshake — establishes verified identity + session key
|
// Crypto handshake — establishes verified identity + session key
|
||||||
let _crypto_session = wzp_client::handshake::perform_handshake(
|
let hs = wzp_client::handshake::perform_handshake(
|
||||||
&*transport,
|
&*transport,
|
||||||
&seed.0,
|
&seed.0,
|
||||||
None, // alias — desktop client doesn't set one yet
|
None, // alias — desktop client doesn't set one yet
|
||||||
).await?;
|
)
|
||||||
info!("crypto handshake complete");
|
.await?;
|
||||||
|
info!(video_codec = ?hs.video_codec, "crypto handshake complete");
|
||||||
|
|
||||||
|
// Wrap the transport so all media I/O goes through AEAD encryption.
|
||||||
|
let enc_transport: Arc<dyn wzp_proto::MediaTransport> = Arc::new(
|
||||||
|
wzp_client::encrypted_transport::EncryptingTransport::new(transport.clone(), hs.session),
|
||||||
|
);
|
||||||
|
|
||||||
if cli.live {
|
if cli.live {
|
||||||
#[cfg(feature = "audio")]
|
#[cfg(feature = "audio")]
|
||||||
{
|
{
|
||||||
return run_live(transport).await;
|
return run_live(enc_transport).await;
|
||||||
}
|
}
|
||||||
#[cfg(not(feature = "audio"))]
|
#[cfg(not(feature = "audio"))]
|
||||||
{
|
{
|
||||||
anyhow::bail!("--live requires the 'audio' feature (build with: cargo build --features audio)");
|
anyhow::bail!(
|
||||||
|
"--live requires the 'audio' feature (build with: cargo build --features audio)"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
} else if let Some(secs) = cli.echo_test_secs {
|
} else if let Some(secs) = cli.echo_test_secs {
|
||||||
let result = wzp_client::echo_test::run_echo_test(&*transport, secs, 5.0).await?;
|
let result = wzp_client::echo_test::run_echo_test(&*transport, secs, 5.0).await?;
|
||||||
@@ -377,14 +427,20 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
transport.close().await?;
|
transport.close().await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
} else if cli.send_tone_secs.is_some() || cli.send_file.is_some() || cli.record_file.is_some() {
|
} else if cli.send_tone_secs.is_some() || cli.send_file.is_some() || cli.record_file.is_some() {
|
||||||
run_file_mode(transport, cli.send_tone_secs, cli.send_file, cli.record_file).await
|
run_file_mode(
|
||||||
|
enc_transport,
|
||||||
|
cli.send_tone_secs,
|
||||||
|
cli.send_file,
|
||||||
|
cli.record_file,
|
||||||
|
)
|
||||||
|
.await
|
||||||
} else {
|
} else {
|
||||||
run_silence(transport).await
|
run_silence(enc_transport).await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Send silence frames (connectivity test).
|
/// Send silence frames (connectivity test).
|
||||||
async fn run_silence(transport: Arc<wzp_transport::QuinnTransport>) -> anyhow::Result<()> {
|
async fn run_silence(transport: Arc<dyn wzp_proto::MediaTransport>) -> anyhow::Result<()> {
|
||||||
let config = CallConfig::default();
|
let config = CallConfig::default();
|
||||||
let mut encoder = CallEncoder::new(&config);
|
let mut encoder = CallEncoder::new(&config);
|
||||||
|
|
||||||
@@ -398,7 +454,7 @@ async fn run_silence(transport: Arc<wzp_transport::QuinnTransport>) -> anyhow::R
|
|||||||
for i in 0..250u32 {
|
for i in 0..250u32 {
|
||||||
let packets = encoder.encode_frame(&pcm)?;
|
let packets = encoder.encode_frame(&pcm)?;
|
||||||
for pkt in &packets {
|
for pkt in &packets {
|
||||||
if pkt.header.is_repair {
|
if pkt.header.is_repair() {
|
||||||
total_repair += 1;
|
total_repair += 1;
|
||||||
} else {
|
} else {
|
||||||
total_source += 1;
|
total_source += 1;
|
||||||
@@ -423,6 +479,7 @@ async fn run_silence(transport: Arc<wzp_transport::QuinnTransport>) -> anyhow::R
|
|||||||
|
|
||||||
info!(total_source, total_repair, total_bytes, "done — closing");
|
info!(total_source, total_repair, total_bytes, "done — closing");
|
||||||
let hangup = wzp_proto::SignalMessage::Hangup {
|
let hangup = wzp_proto::SignalMessage::Hangup {
|
||||||
|
version: default_signal_version(),
|
||||||
reason: wzp_proto::HangupReason::Normal,
|
reason: wzp_proto::HangupReason::Normal,
|
||||||
call_id: None,
|
call_id: None,
|
||||||
};
|
};
|
||||||
@@ -433,7 +490,7 @@ async fn run_silence(transport: Arc<wzp_transport::QuinnTransport>) -> anyhow::R
|
|||||||
|
|
||||||
/// File/tone mode: send a test tone or audio file, and/or record received audio.
|
/// File/tone mode: send a test tone or audio file, and/or record received audio.
|
||||||
async fn run_file_mode(
|
async fn run_file_mode(
|
||||||
transport: Arc<wzp_transport::QuinnTransport>,
|
transport: Arc<dyn wzp_proto::MediaTransport>,
|
||||||
send_tone_secs: Option<u32>,
|
send_tone_secs: Option<u32>,
|
||||||
send_file: Option<String>,
|
send_file: Option<String>,
|
||||||
record_file: Option<String>,
|
record_file: Option<String>,
|
||||||
@@ -448,21 +505,28 @@ async fn run_file_mode(
|
|||||||
// Read raw PCM file (48kHz mono s16le)
|
// Read raw PCM file (48kHz mono s16le)
|
||||||
let bytes = match std::fs::read(path) {
|
let bytes = match std::fs::read(path) {
|
||||||
Ok(b) => b,
|
Ok(b) => b,
|
||||||
Err(e) => { error!("read {path}: {e}"); return; }
|
Err(e) => {
|
||||||
|
error!("read {path}: {e}");
|
||||||
|
return;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
let samples: Vec<i16> = bytes.chunks_exact(2)
|
let samples: Vec<i16> = bytes
|
||||||
|
.chunks_exact(2)
|
||||||
.map(|c| i16::from_le_bytes([c[0], c[1]]))
|
.map(|c| i16::from_le_bytes([c[0], c[1]]))
|
||||||
.collect();
|
.collect();
|
||||||
let duration = samples.len() as f64 / 48_000.0;
|
let duration = samples.len() as f64 / 48_000.0;
|
||||||
info!(file = %path, duration = format!("{:.1}s", duration), "sending audio file");
|
info!(file = %path, duration = format!("{:.1}s", duration), "sending audio file");
|
||||||
samples.chunks(FRAME_SAMPLES)
|
samples
|
||||||
|
.chunks(FRAME_SAMPLES)
|
||||||
.filter(|c| c.len() == FRAME_SAMPLES)
|
.filter(|c| c.len() == FRAME_SAMPLES)
|
||||||
.map(|c| c.to_vec())
|
.map(|c| c.to_vec())
|
||||||
.collect()
|
.collect()
|
||||||
} else if let Some(secs) = send_tone_secs {
|
} else if let Some(secs) = send_tone_secs {
|
||||||
let total = (secs as u64) * 50;
|
let total = (secs as u64) * 50;
|
||||||
info!(seconds = secs, frames = total, "sending 440Hz tone");
|
info!(seconds = secs, frames = total, "sending 440Hz tone");
|
||||||
(0..total).map(|i| generate_sine_frame(440.0, 48_000, i)).collect()
|
(0..total)
|
||||||
|
.map(|i| generate_sine_frame(440.0, 48_000, i))
|
||||||
|
.collect()
|
||||||
} else {
|
} else {
|
||||||
// No sending, just wait
|
// No sending, just wait
|
||||||
tokio::signal::ctrl_c().await.ok();
|
tokio::signal::ctrl_c().await.ok();
|
||||||
@@ -486,7 +550,7 @@ async fn run_file_mode(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
for pkt in &packets {
|
for pkt in &packets {
|
||||||
if pkt.header.is_repair {
|
if pkt.header.is_repair() {
|
||||||
total_repair += 1;
|
total_repair += 1;
|
||||||
} else {
|
} else {
|
||||||
total_source += 1;
|
total_source += 1;
|
||||||
@@ -534,7 +598,7 @@ async fn run_file_mode(
|
|||||||
result = recv_transport.recv_media() => {
|
result = recv_transport.recv_media() => {
|
||||||
match result {
|
match result {
|
||||||
Ok(Some(pkt)) => {
|
Ok(Some(pkt)) => {
|
||||||
let is_repair = pkt.header.is_repair;
|
let is_repair = pkt.header.is_repair();
|
||||||
decoder.ingest(pkt);
|
decoder.ingest(pkt);
|
||||||
if !is_repair {
|
if !is_repair {
|
||||||
if let Some(n) = decoder.decode_next(&mut pcm_buf) {
|
if let Some(n) = decoder.decode_next(&mut pcm_buf) {
|
||||||
@@ -575,6 +639,7 @@ async fn run_file_mode(
|
|||||||
|
|
||||||
// Send Hangup signal so the relay knows we're done
|
// Send Hangup signal so the relay knows we're done
|
||||||
let hangup = wzp_proto::SignalMessage::Hangup {
|
let hangup = wzp_proto::SignalMessage::Hangup {
|
||||||
|
version: default_signal_version(),
|
||||||
reason: wzp_proto::HangupReason::Normal,
|
reason: wzp_proto::HangupReason::Normal,
|
||||||
call_id: None,
|
call_id: None,
|
||||||
};
|
};
|
||||||
@@ -614,7 +679,7 @@ async fn run_file_mode(
|
|||||||
|
|
||||||
/// Live mode: capture from mic, encode, send; receive, decode, play.
|
/// Live mode: capture from mic, encode, send; receive, decode, play.
|
||||||
#[cfg(feature = "audio")]
|
#[cfg(feature = "audio")]
|
||||||
async fn run_live(transport: Arc<wzp_transport::QuinnTransport>) -> anyhow::Result<()> {
|
async fn run_live(transport: Arc<dyn wzp_proto::MediaTransport>) -> anyhow::Result<()> {
|
||||||
use wzp_client::audio_io::{AudioCapture, AudioPlayback};
|
use wzp_client::audio_io::{AudioCapture, AudioPlayback};
|
||||||
|
|
||||||
let capture = AudioCapture::start()?;
|
let capture = AudioCapture::start()?;
|
||||||
@@ -667,7 +732,7 @@ async fn run_live(transport: Arc<wzp_transport::QuinnTransport>) -> anyhow::Resu
|
|||||||
loop {
|
loop {
|
||||||
match recv_transport.recv_media().await {
|
match recv_transport.recv_media().await {
|
||||||
Ok(Some(pkt)) => {
|
Ok(Some(pkt)) => {
|
||||||
let is_repair = pkt.header.is_repair;
|
let is_repair = pkt.header.is_repair();
|
||||||
decoder.ingest(pkt);
|
decoder.ingest(pkt);
|
||||||
// Only decode for source packets (1 source = 1 audio frame).
|
// Only decode for source packets (1 source = 1 audio frame).
|
||||||
// Repair packets feed the FEC decoder but don't produce audio.
|
// Repair packets feed the FEC decoder but don't produce audio.
|
||||||
@@ -712,7 +777,7 @@ async fn run_signal_mode(
|
|||||||
token: Option<String>,
|
token: Option<String>,
|
||||||
call_target: Option<String>,
|
call_target: Option<String>,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
use wzp_proto::SignalMessage;
|
use wzp_proto::{SignalMessage, default_signal_version};
|
||||||
|
|
||||||
let identity = seed.derive_identity();
|
let identity = seed.derive_identity();
|
||||||
let pub_id = identity.public_identity();
|
let pub_id = identity.public_identity();
|
||||||
@@ -734,22 +799,34 @@ async fn run_signal_mode(
|
|||||||
|
|
||||||
// Auth if token provided
|
// Auth if token provided
|
||||||
if let Some(ref tok) = token {
|
if let Some(ref tok) = token {
|
||||||
transport.send_signal(&SignalMessage::AuthToken { token: tok.clone() }).await?;
|
transport
|
||||||
|
.send_signal(&SignalMessage::AuthToken {
|
||||||
|
version: default_signal_version(),
|
||||||
|
token: tok.clone(),
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register presence (signature not verified in Phase 1)
|
// Register presence (signature not verified in Phase 1)
|
||||||
transport.send_signal(&SignalMessage::RegisterPresence {
|
transport
|
||||||
|
.send_signal(&SignalMessage::RegisterPresence {
|
||||||
|
version: default_signal_version(),
|
||||||
identity_pub,
|
identity_pub,
|
||||||
signature: vec![], // Phase 1: not verified
|
signature: vec![], // Phase 1: not verified
|
||||||
alias: None,
|
alias: None,
|
||||||
}).await?;
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
// Wait for ack
|
// Wait for ack
|
||||||
match transport.recv_signal().await? {
|
match transport.recv_signal().await? {
|
||||||
Some(SignalMessage::RegisterPresenceAck { success: true, .. }) => {
|
Some(SignalMessage::RegisterPresenceAck { success: true, .. }) => {
|
||||||
info!(fingerprint = %fp, "registered on relay — waiting for calls");
|
info!(fingerprint = %fp, "registered on relay — waiting for calls");
|
||||||
}
|
}
|
||||||
Some(SignalMessage::RegisterPresenceAck { success: false, error, .. }) => {
|
Some(SignalMessage::RegisterPresenceAck {
|
||||||
|
success: false,
|
||||||
|
error,
|
||||||
|
..
|
||||||
|
}) => {
|
||||||
anyhow::bail!("registration failed: {}", error.unwrap_or_default());
|
anyhow::bail!("registration failed: {}", error.unwrap_or_default());
|
||||||
}
|
}
|
||||||
other => {
|
other => {
|
||||||
@@ -760,10 +837,17 @@ async fn run_signal_mode(
|
|||||||
// If --call specified, place the call
|
// If --call specified, place the call
|
||||||
if let Some(ref target) = call_target {
|
if let Some(ref target) = call_target {
|
||||||
info!(target = %target, "placing direct call...");
|
info!(target = %target, "placing direct call...");
|
||||||
let call_id = format!("{:016x}", std::time::SystemTime::now()
|
let call_id = format!(
|
||||||
.duration_since(std::time::UNIX_EPOCH).unwrap().as_nanos());
|
"{:016x}",
|
||||||
|
std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap()
|
||||||
|
.as_nanos()
|
||||||
|
);
|
||||||
|
|
||||||
transport.send_signal(&SignalMessage::DirectCallOffer {
|
transport
|
||||||
|
.send_signal(&SignalMessage::DirectCallOffer {
|
||||||
|
version: default_signal_version(),
|
||||||
caller_fingerprint: fp.clone(),
|
caller_fingerprint: fp.clone(),
|
||||||
caller_alias: None,
|
caller_alias: None,
|
||||||
target_fingerprint: target.clone(),
|
target_fingerprint: target.clone(),
|
||||||
@@ -776,8 +860,10 @@ async fn run_signal_mode(
|
|||||||
// relay-path.
|
// relay-path.
|
||||||
caller_reflexive_addr: None,
|
caller_reflexive_addr: None,
|
||||||
caller_local_addrs: Vec::new(),
|
caller_local_addrs: Vec::new(),
|
||||||
|
caller_mapped_addr: None,
|
||||||
caller_build_version: None,
|
caller_build_version: None,
|
||||||
}).await?;
|
})
|
||||||
|
.await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Signal recv loop — handle incoming signals
|
// Signal recv loop — handle incoming signals
|
||||||
@@ -788,10 +874,15 @@ async fn run_signal_mode(
|
|||||||
loop {
|
loop {
|
||||||
match signal_transport.recv_signal().await {
|
match signal_transport.recv_signal().await {
|
||||||
Ok(Some(msg)) => match msg {
|
Ok(Some(msg)) => match msg {
|
||||||
SignalMessage::CallRinging { call_id } => {
|
SignalMessage::CallRinging { call_id, .. } => {
|
||||||
info!(call_id = %call_id, "ringing...");
|
info!(call_id = %call_id, "ringing...");
|
||||||
}
|
}
|
||||||
SignalMessage::DirectCallOffer { caller_fingerprint, caller_alias, call_id, .. } => {
|
SignalMessage::DirectCallOffer {
|
||||||
|
caller_fingerprint,
|
||||||
|
caller_alias,
|
||||||
|
call_id,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
info!(
|
info!(
|
||||||
from = %caller_fingerprint,
|
from = %caller_fingerprint,
|
||||||
alias = ?caller_alias,
|
alias = ?caller_alias,
|
||||||
@@ -799,7 +890,9 @@ async fn run_signal_mode(
|
|||||||
"incoming call — auto-accepting (generic)"
|
"incoming call — auto-accepting (generic)"
|
||||||
);
|
);
|
||||||
// Auto-accept for CLI testing
|
// Auto-accept for CLI testing
|
||||||
let _ = signal_transport.send_signal(&SignalMessage::DirectCallAnswer {
|
let _ = signal_transport
|
||||||
|
.send_signal(&SignalMessage::DirectCallAnswer {
|
||||||
|
version: default_signal_version(),
|
||||||
call_id,
|
call_id,
|
||||||
accept_mode: wzp_proto::CallAcceptMode::AcceptGeneric,
|
accept_mode: wzp_proto::CallAcceptMode::AcceptGeneric,
|
||||||
identity_pub: Some(identity_pub),
|
identity_pub: Some(identity_pub),
|
||||||
@@ -810,13 +903,27 @@ async fn run_signal_mode(
|
|||||||
// so callee addr stays hidden from the caller.
|
// so callee addr stays hidden from the caller.
|
||||||
callee_reflexive_addr: None,
|
callee_reflexive_addr: None,
|
||||||
callee_local_addrs: Vec::new(),
|
callee_local_addrs: Vec::new(),
|
||||||
|
callee_mapped_addr: None,
|
||||||
callee_build_version: None,
|
callee_build_version: None,
|
||||||
}).await;
|
})
|
||||||
|
.await;
|
||||||
}
|
}
|
||||||
SignalMessage::DirectCallAnswer { call_id, accept_mode, .. } => {
|
SignalMessage::DirectCallAnswer {
|
||||||
|
call_id,
|
||||||
|
accept_mode,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
info!(call_id = %call_id, mode = ?accept_mode, "call answered");
|
info!(call_id = %call_id, mode = ?accept_mode, "call answered");
|
||||||
}
|
}
|
||||||
SignalMessage::CallSetup { call_id, room, relay_addr: setup_relay, peer_direct_addr: _, peer_local_addrs: _ } => {
|
SignalMessage::CallSetup {
|
||||||
|
call_id,
|
||||||
|
room,
|
||||||
|
relay_addr: setup_relay,
|
||||||
|
peer_direct_addr: _,
|
||||||
|
peer_local_addrs: _,
|
||||||
|
peer_mapped_addr: _,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
info!(call_id = %call_id, room = %room, relay = %setup_relay, "call setup — connecting to media room");
|
info!(call_id = %call_id, room = %room, relay = %setup_relay, "call setup — connecting to media room");
|
||||||
|
|
||||||
// Connect to the media room
|
// Connect to the media room
|
||||||
@@ -824,18 +931,28 @@ async fn run_signal_mode(
|
|||||||
let media_cfg = wzp_transport::client_config();
|
let media_cfg = wzp_transport::client_config();
|
||||||
match wzp_transport::connect(&endpoint, media_relay, &room, media_cfg).await {
|
match wzp_transport::connect(&endpoint, media_relay, &room, media_cfg).await {
|
||||||
Ok(media_conn) => {
|
Ok(media_conn) => {
|
||||||
let media_transport = Arc::new(wzp_transport::QuinnTransport::new(media_conn));
|
let media_transport =
|
||||||
|
Arc::new(wzp_transport::QuinnTransport::new(media_conn));
|
||||||
|
|
||||||
// Crypto handshake
|
// Crypto handshake
|
||||||
match wzp_client::handshake::perform_handshake(&*media_transport, &my_seed, None).await {
|
match wzp_client::handshake::perform_handshake(
|
||||||
Ok(_session) => {
|
&*media_transport,
|
||||||
info!("media connected — sending tone (press Ctrl+C to hang up)");
|
&my_seed,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(_hs) => {
|
||||||
|
info!(
|
||||||
|
"media connected — sending tone (press Ctrl+C to hang up)"
|
||||||
|
);
|
||||||
|
|
||||||
// Simple tone sender for testing
|
// Simple tone sender for testing
|
||||||
let mt = media_transport.clone();
|
let mt = media_transport.clone();
|
||||||
let send_task = tokio::spawn(async move {
|
let send_task = tokio::spawn(async move {
|
||||||
let config = wzp_client::call::CallConfig::default();
|
let config = wzp_client::call::CallConfig::default();
|
||||||
let mut encoder = wzp_client::call::CallEncoder::new(&config);
|
let mut encoder =
|
||||||
|
wzp_client::call::CallEncoder::new(&config);
|
||||||
let duration = tokio::time::Duration::from_millis(20);
|
let duration = tokio::time::Duration::from_millis(20);
|
||||||
loop {
|
loop {
|
||||||
let pcm: Vec<i16> = (0..FRAME_SAMPLES)
|
let pcm: Vec<i16> = (0..FRAME_SAMPLES)
|
||||||
@@ -843,7 +960,9 @@ async fn run_signal_mode(
|
|||||||
.collect();
|
.collect();
|
||||||
if let Ok(pkts) = encoder.encode_frame(&pcm) {
|
if let Ok(pkts) = encoder.encode_frame(&pcm) {
|
||||||
for pkt in &pkts {
|
for pkt in &pkts {
|
||||||
if mt.send_media(pkt).await.is_err() { return; }
|
if mt.send_media(pkt).await.is_err() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
tokio::time::sleep(duration).await;
|
tokio::time::sleep(duration).await;
|
||||||
@@ -866,6 +985,7 @@ async fn run_signal_mode(
|
|||||||
_ = tokio::signal::ctrl_c() => {
|
_ = tokio::signal::ctrl_c() => {
|
||||||
info!("hanging up...");
|
info!("hanging up...");
|
||||||
let _ = signal_transport.send_signal(&SignalMessage::Hangup {
|
let _ = signal_transport.send_signal(&SignalMessage::Hangup {
|
||||||
|
version: default_signal_version(),
|
||||||
reason: wzp_proto::HangupReason::Normal,
|
reason: wzp_proto::HangupReason::Normal,
|
||||||
call_id: None,
|
call_id: None,
|
||||||
}).await;
|
}).await;
|
||||||
|
|||||||
@@ -144,7 +144,7 @@ pub async fn run_drift_test(
|
|||||||
}
|
}
|
||||||
match tokio::time::timeout(Duration::from_millis(2), transport.recv_media()).await {
|
match tokio::time::timeout(Duration::from_millis(2), transport.recv_media()).await {
|
||||||
Ok(Ok(Some(pkt))) => {
|
Ok(Ok(Some(pkt))) => {
|
||||||
let is_repair = pkt.header.is_repair;
|
let is_repair = pkt.header.is_repair();
|
||||||
decoder.ingest(pkt);
|
decoder.ingest(pkt);
|
||||||
if !is_repair {
|
if !is_repair {
|
||||||
if let Some(_n) = decoder.decode_next(&mut pcm_buf) {
|
if let Some(_n) = decoder.decode_next(&mut pcm_buf) {
|
||||||
@@ -180,7 +180,7 @@ pub async fn run_drift_test(
|
|||||||
while Instant::now() < drain_deadline {
|
while Instant::now() < drain_deadline {
|
||||||
match tokio::time::timeout(Duration::from_millis(100), transport.recv_media()).await {
|
match tokio::time::timeout(Duration::from_millis(100), transport.recv_media()).await {
|
||||||
Ok(Ok(Some(pkt))) => {
|
Ok(Ok(Some(pkt))) => {
|
||||||
let is_repair = pkt.header.is_repair;
|
let is_repair = pkt.header.is_repair();
|
||||||
decoder.ingest(pkt);
|
decoder.ingest(pkt);
|
||||||
if !is_repair {
|
if !is_repair {
|
||||||
if let Some(_n) = decoder.decode_next(&mut pcm_buf) {
|
if let Some(_n) = decoder.decode_next(&mut pcm_buf) {
|
||||||
@@ -234,7 +234,10 @@ pub fn print_drift_report(result: &DriftResult) {
|
|||||||
println!();
|
println!();
|
||||||
println!("Expected duration: {} ms", result.expected_duration_ms);
|
println!("Expected duration: {} ms", result.expected_duration_ms);
|
||||||
println!("Actual duration: {} ms", result.actual_duration_ms);
|
println!("Actual duration: {} ms", result.actual_duration_ms);
|
||||||
println!("Drift: {} ms ({:+.4}%)", result.drift_ms, result.drift_pct);
|
println!(
|
||||||
|
"Drift: {} ms ({:+.4}%)",
|
||||||
|
result.drift_ms, result.drift_pct
|
||||||
|
);
|
||||||
println!();
|
println!();
|
||||||
|
|
||||||
// Interpretation
|
// Interpretation
|
||||||
@@ -246,9 +249,15 @@ pub fn print_drift_report(result: &DriftResult) {
|
|||||||
} else if abs_drift < 20 {
|
} else if abs_drift < 20 {
|
||||||
println!("Result: GOOD -- drift is within acceptable bounds (<20 ms).");
|
println!("Result: GOOD -- drift is within acceptable bounds (<20 ms).");
|
||||||
} else if abs_drift < 100 {
|
} else if abs_drift < 100 {
|
||||||
println!("Result: FAIR -- noticeable drift ({} ms). Clock sync may be needed.", abs_drift);
|
println!(
|
||||||
|
"Result: FAIR -- noticeable drift ({} ms). Clock sync may be needed.",
|
||||||
|
abs_drift
|
||||||
|
);
|
||||||
} else {
|
} else {
|
||||||
println!("Result: POOR -- significant drift ({} ms). Investigate clock sources.", abs_drift);
|
println!(
|
||||||
|
"Result: POOR -- significant drift ({} ms). Investigate clock sources.",
|
||||||
|
abs_drift
|
||||||
|
);
|
||||||
}
|
}
|
||||||
println!();
|
println!();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -38,6 +38,15 @@ pub enum WinningPath {
|
|||||||
Relay,
|
Relay,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Diagnostic info for a single candidate dial attempt.
|
||||||
|
#[derive(Debug, Clone, serde::Serialize)]
|
||||||
|
pub struct CandidateDiag {
|
||||||
|
pub index: usize,
|
||||||
|
pub addr: String,
|
||||||
|
pub result: String, // "ok", "skipped:ipv6", "error:..."
|
||||||
|
pub elapsed_ms: Option<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
/// Phase 6: the race now returns BOTH transports (when available)
|
/// Phase 6: the race now returns BOTH transports (when available)
|
||||||
/// so the connect command can negotiate with the peer before
|
/// so the connect command can negotiate with the peer before
|
||||||
/// committing. The negotiation decides which transport to use
|
/// committing. The negotiation decides which transport to use
|
||||||
@@ -54,6 +63,8 @@ pub struct RaceResult {
|
|||||||
/// Informational — the actual path used is decided by the
|
/// Informational — the actual path used is decided by the
|
||||||
/// Phase 6 negotiation after both sides exchange reports.
|
/// Phase 6 negotiation after both sides exchange reports.
|
||||||
pub local_winner: WinningPath,
|
pub local_winner: WinningPath,
|
||||||
|
/// Per-candidate diagnostic info for debugging.
|
||||||
|
pub candidate_diags: Vec<CandidateDiag>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Attempt a direct QUIC connection to the peer in parallel with
|
/// Attempt a direct QUIC connection to the peer in parallel with
|
||||||
@@ -88,19 +99,30 @@ pub struct PeerCandidates {
|
|||||||
/// same-LAN pairs — direct dials to these bypass the NAT
|
/// same-LAN pairs — direct dials to these bypass the NAT
|
||||||
/// entirely.
|
/// entirely.
|
||||||
pub local: Vec<SocketAddr>,
|
pub local: Vec<SocketAddr>,
|
||||||
|
/// Phase 8 (Tailscale-inspired): peer's port-mapped external
|
||||||
|
/// address from NAT-PMP/PCP/UPnP. When the router supports
|
||||||
|
/// port mapping, this gives a stable external address even
|
||||||
|
/// behind symmetric NATs.
|
||||||
|
pub mapped: Option<SocketAddr>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PeerCandidates {
|
impl PeerCandidates {
|
||||||
/// Flatten into the list of addrs the D-role should dial.
|
/// Flatten into the list of addrs the D-role should dial.
|
||||||
/// Order: LAN host candidates first (fastest when they
|
/// Order: LAN host candidates first (fastest when they
|
||||||
/// work), then reflexive (covers the non-LAN case).
|
/// work), then port-mapped (stable even behind symmetric
|
||||||
|
/// NATs), then reflexive (covers the non-LAN case).
|
||||||
pub fn dial_order(&self) -> Vec<SocketAddr> {
|
pub fn dial_order(&self) -> Vec<SocketAddr> {
|
||||||
let mut out = Vec::with_capacity(self.local.len() + 1);
|
let mut out = Vec::with_capacity(self.local.len() + 2);
|
||||||
out.extend(self.local.iter().copied());
|
out.extend(self.local.iter().copied());
|
||||||
|
// Port-mapped address goes before reflexive — it's
|
||||||
|
// more reliable on symmetric NATs where the reflexive
|
||||||
|
// addr might not match what the peer actually sees.
|
||||||
|
if let Some(a) = self.mapped {
|
||||||
|
if !out.contains(&a) {
|
||||||
|
out.push(a);
|
||||||
|
}
|
||||||
|
}
|
||||||
if let Some(a) = self.reflexive {
|
if let Some(a) = self.reflexive {
|
||||||
// Only add if it's not already in the list (some
|
|
||||||
// edge cases on same-LAN could have the same addr
|
|
||||||
// in both).
|
|
||||||
if !out.contains(&a) {
|
if !out.contains(&a) {
|
||||||
out.push(a);
|
out.push(a);
|
||||||
}
|
}
|
||||||
@@ -108,10 +130,54 @@ impl PeerCandidates {
|
|||||||
out
|
out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Smart dial order: filters out candidates that can't possibly
|
||||||
|
/// work given our own reflexive address.
|
||||||
|
///
|
||||||
|
/// - **LAN candidates**: only included if peer's public IP
|
||||||
|
/// matches ours (same network). Private IPs are unreachable
|
||||||
|
/// cross-network.
|
||||||
|
/// - **IPv6 candidates**: stripped entirely (Phase 7 disabled).
|
||||||
|
/// - **Reflexive + mapped**: always included.
|
||||||
|
pub fn smart_dial_order(&self, own_reflexive: Option<&SocketAddr>) -> Vec<SocketAddr> {
|
||||||
|
let own_public_ip = own_reflexive.map(|a| a.ip());
|
||||||
|
let peer_public_ip = self.reflexive.map(|a| a.ip());
|
||||||
|
let same_network = match (own_public_ip, peer_public_ip) {
|
||||||
|
(Some(a), Some(b)) => a == b,
|
||||||
|
_ => false,
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut out = Vec::with_capacity(self.local.len() + 2);
|
||||||
|
|
||||||
|
// LAN candidates only when on the same network.
|
||||||
|
if same_network {
|
||||||
|
for addr in &self.local {
|
||||||
|
if !addr.is_ipv6() {
|
||||||
|
out.push(*addr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Port-mapped (always useful — it's a public addr).
|
||||||
|
if let Some(a) = self.mapped {
|
||||||
|
if !a.is_ipv6() && !out.contains(&a) {
|
||||||
|
out.push(a);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reflexive (always useful — it's the peer's public addr).
|
||||||
|
if let Some(a) = self.reflexive {
|
||||||
|
if !a.is_ipv6() && !out.contains(&a) {
|
||||||
|
out.push(a);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
/// Is there anything for the D-role to dial? If not, the
|
/// Is there anything for the D-role to dial? If not, the
|
||||||
/// race reduces to relay-only.
|
/// race reduces to relay-only.
|
||||||
pub fn is_empty(&self) -> bool {
|
pub fn is_empty(&self) -> bool {
|
||||||
self.reflexive.is_none() && self.local.is_empty()
|
self.reflexive.is_none() && self.local.is_empty() && self.mapped.is_none()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -122,6 +188,9 @@ pub async fn race(
|
|||||||
relay_addr: SocketAddr,
|
relay_addr: SocketAddr,
|
||||||
room_sni: String,
|
room_sni: String,
|
||||||
call_sni: String,
|
call_sni: String,
|
||||||
|
// Our own reflexive address — used to filter LAN candidates
|
||||||
|
// that can't work cross-network.
|
||||||
|
own_reflexive: Option<SocketAddr>,
|
||||||
// Phase 5: when `Some`, reuse this endpoint for BOTH the
|
// Phase 5: when `Some`, reuse this endpoint for BOTH the
|
||||||
// direct-path branch AND the relay dial. Pass the signal
|
// direct-path branch AND the relay dial. Pass the signal
|
||||||
// endpoint. The endpoint MUST be server-capable (created
|
// endpoint. The endpoint MUST be server-capable (created
|
||||||
@@ -141,6 +210,10 @@ pub async fn race(
|
|||||||
// is created. Install attempt is idempotent.
|
// is created. Install attempt is idempotent.
|
||||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||||
|
|
||||||
|
// Shared diagnostic collector for per-candidate results.
|
||||||
|
let diags_collector: Arc<std::sync::Mutex<Vec<CandidateDiag>>> =
|
||||||
|
Arc::new(std::sync::Mutex::new(Vec::new()));
|
||||||
|
|
||||||
// Build the direct-path endpoint + future based on role.
|
// Build the direct-path endpoint + future based on role.
|
||||||
//
|
//
|
||||||
// A-role: one accept future on the shared endpoint. The
|
// A-role: one accept future on the shared endpoint. The
|
||||||
@@ -196,7 +269,84 @@ pub async fn race(
|
|||||||
// as dial — IPv6 connections die on datagram send).
|
// as dial — IPv6 connections die on datagram send).
|
||||||
// Accept on IPv4 shared endpoint only.
|
// Accept on IPv4 shared endpoint only.
|
||||||
let _v6_ep_unused = ipv6_endpoint.clone();
|
let _v6_ep_unused = ipv6_endpoint.clone();
|
||||||
|
// Collect peer addrs for NAT tickle (Acceptor-side).
|
||||||
|
let tickle_addrs: Vec<SocketAddr> = peer_candidates
|
||||||
|
.smart_dial_order(own_reflexive.as_ref())
|
||||||
|
.into_iter()
|
||||||
|
.filter(|a| !a.ip().is_loopback() && !a.ip().is_unspecified())
|
||||||
|
.collect();
|
||||||
direct_fut = Box::pin(async move {
|
direct_fut = Box::pin(async move {
|
||||||
|
// NAT tickle: send a small UDP packet to each of the
|
||||||
|
// Dialer's candidate addresses FROM our shared endpoint.
|
||||||
|
// This opens our NAT's pinhole for return traffic from
|
||||||
|
// those IPs — critical for address-restricted NATs that
|
||||||
|
// only allow inbound from IPs they've seen outbound
|
||||||
|
// traffic to. Without this, the Dialer's QUIC Initial
|
||||||
|
// gets dropped by our NAT.
|
||||||
|
if !tickle_addrs.is_empty() {
|
||||||
|
if let Ok(local_addr) = ep_for_fut.local_addr() {
|
||||||
|
// Send a tickle to each peer candidate address
|
||||||
|
// to open our NAT for return traffic from that IP.
|
||||||
|
//
|
||||||
|
// We use a socket2 socket with SO_REUSEADDR +
|
||||||
|
// SO_REUSEPORT on the SAME port as the quinn
|
||||||
|
// endpoint. This is necessary because quinn
|
||||||
|
// already holds the port — a plain bind() would
|
||||||
|
// fail with EADDRINUSE.
|
||||||
|
let tickle_result: Result<(), String> = (|| {
|
||||||
|
use std::net::UdpSocket as StdUdpSocket;
|
||||||
|
let sock = socket2::Socket::new(
|
||||||
|
socket2::Domain::IPV4,
|
||||||
|
socket2::Type::DGRAM,
|
||||||
|
Some(socket2::Protocol::UDP),
|
||||||
|
)
|
||||||
|
.map_err(|e| format!("socket: {e}"))?;
|
||||||
|
sock.set_reuse_address(true)
|
||||||
|
.map_err(|e| format!("reuseaddr: {e}"))?;
|
||||||
|
// macOS/BSD/Linux also need SO_REUSEPORT
|
||||||
|
#[cfg(any(
|
||||||
|
target_os = "macos",
|
||||||
|
target_os = "linux",
|
||||||
|
target_os = "android"
|
||||||
|
))]
|
||||||
|
{
|
||||||
|
// socket2 exposes set_reuse_port on unix
|
||||||
|
unsafe {
|
||||||
|
let optval: libc::c_int = 1;
|
||||||
|
libc::setsockopt(
|
||||||
|
std::os::unix::io::AsRawFd::as_raw_fd(&sock),
|
||||||
|
libc::SOL_SOCKET,
|
||||||
|
libc::SO_REUSEPORT,
|
||||||
|
&optval as *const _ as *const libc::c_void,
|
||||||
|
std::mem::size_of::<libc::c_int>() as libc::socklen_t,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sock.set_nonblocking(true)
|
||||||
|
.map_err(|e| format!("nonblock: {e}"))?;
|
||||||
|
let bind_addr: SocketAddr = SocketAddr::new(
|
||||||
|
std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED),
|
||||||
|
local_addr.port(),
|
||||||
|
);
|
||||||
|
sock.bind(&bind_addr.into())
|
||||||
|
.map_err(|e| format!("bind :{}: {e}", local_addr.port()))?;
|
||||||
|
let std_sock: StdUdpSocket = sock.into();
|
||||||
|
for addr in &tickle_addrs {
|
||||||
|
let _ = std_sock.send_to(&[0u8; 1], addr);
|
||||||
|
tracing::info!(
|
||||||
|
%addr,
|
||||||
|
local_port = local_addr.port(),
|
||||||
|
"dual_path: A-role sent NAT tickle"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
})();
|
||||||
|
if let Err(e) = tickle_result {
|
||||||
|
tracing::warn!(error = %e, "dual_path: A-role NAT tickle failed");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Accept loop: retry if we get a stale/closed
|
// Accept loop: retry if we get a stale/closed
|
||||||
// connection from a previous call. Max 3 retries
|
// connection from a previous call. Max 3 retries
|
||||||
// to avoid spinning until the race timeout.
|
// to avoid spinning until the race timeout.
|
||||||
@@ -270,8 +420,9 @@ pub async fn race(
|
|||||||
};
|
};
|
||||||
let ep_for_fut = ep.clone();
|
let ep_for_fut = ep.clone();
|
||||||
let _v6_ep_for_dial = ipv6_endpoint.clone();
|
let _v6_ep_for_dial = ipv6_endpoint.clone();
|
||||||
let dial_order = peer_candidates.dial_order();
|
let dial_order = peer_candidates.smart_dial_order(own_reflexive.as_ref());
|
||||||
let sni = call_sni.clone();
|
let sni = call_sni.clone();
|
||||||
|
let diags = diags_collector.clone();
|
||||||
direct_fut = Box::pin(async move {
|
direct_fut = Box::pin(async move {
|
||||||
if dial_order.is_empty() {
|
if dial_order.is_empty() {
|
||||||
// No candidates — the race reduces to
|
// No candidates — the race reduces to
|
||||||
@@ -300,24 +451,47 @@ pub async fn race(
|
|||||||
// Re-enable once IPv6 datagram delivery is
|
// Re-enable once IPv6 datagram delivery is
|
||||||
// verified on target networks.
|
// verified on target networks.
|
||||||
if candidate.is_ipv6() {
|
if candidate.is_ipv6() {
|
||||||
tracing::debug!(
|
tracing::info!(
|
||||||
%candidate,
|
%candidate,
|
||||||
candidate_idx = idx,
|
candidate_idx = idx,
|
||||||
"dual_path: skipping IPv6 candidate (disabled)"
|
"dual_path: skipping IPv6 candidate (disabled)"
|
||||||
);
|
);
|
||||||
|
if let Ok(mut d) = diags.lock() {
|
||||||
|
d.push(CandidateDiag {
|
||||||
|
index: idx,
|
||||||
|
addr: candidate.to_string(),
|
||||||
|
result: "skipped:ipv6".into(),
|
||||||
|
elapsed_ms: None,
|
||||||
|
});
|
||||||
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
let ep = ep_for_fut.clone();
|
let ep = ep_for_fut.clone();
|
||||||
let client_cfg = wzp_transport::client_config();
|
let client_cfg = wzp_transport::client_config();
|
||||||
let sni = sni.clone();
|
let sni = sni.clone();
|
||||||
|
let diags_inner = diags.clone();
|
||||||
set.spawn(async move {
|
set.spawn(async move {
|
||||||
let result = wzp_transport::connect(
|
let start = std::time::Instant::now();
|
||||||
&ep,
|
tracing::info!(
|
||||||
candidate,
|
%candidate,
|
||||||
&sni,
|
candidate_idx = idx,
|
||||||
client_cfg,
|
"dual_path: dialing candidate"
|
||||||
)
|
);
|
||||||
.await;
|
let result =
|
||||||
|
wzp_transport::connect(&ep, candidate, &sni, client_cfg).await;
|
||||||
|
let elapsed = start.elapsed().as_millis() as u32;
|
||||||
|
let diag_result = match &result {
|
||||||
|
Ok(_) => "ok".to_string(),
|
||||||
|
Err(e) => format!("error:{e}"),
|
||||||
|
};
|
||||||
|
if let Ok(mut d) = diags_inner.lock() {
|
||||||
|
d.push(CandidateDiag {
|
||||||
|
index: idx,
|
||||||
|
addr: candidate.to_string(),
|
||||||
|
result: diag_result,
|
||||||
|
elapsed_ms: Some(elapsed),
|
||||||
|
});
|
||||||
|
}
|
||||||
(idx, candidate, result)
|
(idx, candidate, result)
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -346,7 +520,7 @@ pub async fn race(
|
|||||||
return Ok(QuinnTransport::new(conn));
|
return Ok(QuinnTransport::new(conn));
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::debug!(
|
tracing::info!(
|
||||||
%candidate,
|
%candidate,
|
||||||
candidate_idx = idx,
|
candidate_idx = idx,
|
||||||
error = %e,
|
error = %e,
|
||||||
@@ -423,16 +597,17 @@ pub async fn race(
|
|||||||
// RaceResult with both transports (when available) and uses the
|
// RaceResult with both transports (when available) and uses the
|
||||||
// Phase 6 MediaPathReport exchange to decide which one to
|
// Phase 6 MediaPathReport exchange to decide which one to
|
||||||
// actually use for media.
|
// actually use for media.
|
||||||
|
let smart_order = peer_candidates.smart_dial_order(own_reflexive.as_ref());
|
||||||
tracing::info!(
|
tracing::info!(
|
||||||
?role,
|
?role,
|
||||||
candidates = ?peer_candidates.dial_order(),
|
raw_candidates = ?peer_candidates.dial_order(),
|
||||||
|
filtered_candidates = ?smart_order,
|
||||||
|
?own_reflexive,
|
||||||
%relay_addr,
|
%relay_addr,
|
||||||
"dual_path: racing direct vs relay"
|
"dual_path: racing direct vs relay"
|
||||||
);
|
);
|
||||||
|
|
||||||
let mut direct_task = tokio::spawn(
|
let mut direct_task = tokio::spawn(tokio::time::timeout(Duration::from_secs(4), direct_fut));
|
||||||
tokio::time::timeout(Duration::from_secs(2), direct_fut),
|
|
||||||
);
|
|
||||||
let mut relay_task = tokio::spawn(async move {
|
let mut relay_task = tokio::spawn(async move {
|
||||||
// Keep the 500ms head start so direct has a chance
|
// Keep the 500ms head start so direct has a chance
|
||||||
tokio::time::sleep(Duration::from_millis(500)).await;
|
tokio::time::sleep(Duration::from_millis(500)).await;
|
||||||
@@ -464,9 +639,25 @@ pub async fn race(
|
|||||||
local_winner = WinningPath::Relay; // direct failed → relay is our only hope
|
local_winner = WinningPath::Relay; // direct failed → relay is our only hope
|
||||||
}
|
}
|
||||||
Ok(Err(_)) => {
|
Ok(Err(_)) => {
|
||||||
tracing::warn!("dual_path: direct timed out (2s)");
|
tracing::warn!("dual_path: direct timed out (4s)");
|
||||||
direct_result = Some(Err(anyhow::anyhow!("direct timeout")));
|
direct_result = Some(Err(anyhow::anyhow!("direct timeout")));
|
||||||
local_winner = WinningPath::Relay;
|
local_winner = WinningPath::Relay;
|
||||||
|
// Record timeout diag for candidates that were
|
||||||
|
// still in-flight when the timeout fired.
|
||||||
|
if let Ok(mut d) = diags_collector.lock() {
|
||||||
|
let recorded_indices: std::collections::HashSet<usize> =
|
||||||
|
d.iter().map(|diag| diag.index).collect();
|
||||||
|
for (idx, addr) in smart_order.iter().enumerate() {
|
||||||
|
if !recorded_indices.contains(&idx) {
|
||||||
|
d.push(CandidateDiag {
|
||||||
|
index: idx,
|
||||||
|
addr: addr.to_string(),
|
||||||
|
result: "timeout:4s".into(),
|
||||||
|
elapsed_ms: Some(4000),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::warn!(error = %e, "dual_path: direct task panicked");
|
tracing::warn!(error = %e, "dual_path: direct task panicked");
|
||||||
@@ -505,16 +696,43 @@ pub async fn race(
|
|||||||
// If it doesn't, we still proceed with just the winner.
|
// If it doesn't, we still proceed with just the winner.
|
||||||
if direct_result.is_none() {
|
if direct_result.is_none() {
|
||||||
match tokio::time::timeout(Duration::from_secs(1), direct_task).await {
|
match tokio::time::timeout(Duration::from_secs(1), direct_task).await {
|
||||||
Ok(Ok(Ok(Ok(t)))) => { direct_result = Some(Ok(t)); }
|
Ok(Ok(Ok(Ok(t)))) => {
|
||||||
Ok(Ok(Ok(Err(e)))) => { direct_result = Some(Err(anyhow::anyhow!("{e}"))); }
|
direct_result = Some(Ok(t));
|
||||||
_ => { direct_result = Some(Err(anyhow::anyhow!("direct: no result in grace period"))); }
|
}
|
||||||
|
Ok(Ok(Ok(Err(e)))) => {
|
||||||
|
direct_result = Some(Err(anyhow::anyhow!("{e}")));
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
direct_result = Some(Err(anyhow::anyhow!("direct: no result in grace period")));
|
||||||
|
// Fill timeout diags for candidates that never reported.
|
||||||
|
if let Ok(mut d) = diags_collector.lock() {
|
||||||
|
let recorded: std::collections::HashSet<usize> =
|
||||||
|
d.iter().map(|diag| diag.index).collect();
|
||||||
|
for (idx, addr) in smart_order.iter().enumerate() {
|
||||||
|
if !recorded.contains(&idx) {
|
||||||
|
d.push(CandidateDiag {
|
||||||
|
index: idx,
|
||||||
|
addr: addr.to_string(),
|
||||||
|
result: "timeout:grace".into(),
|
||||||
|
elapsed_ms: None,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if relay_result.is_none() {
|
if relay_result.is_none() {
|
||||||
match tokio::time::timeout(Duration::from_secs(1), relay_task).await {
|
match tokio::time::timeout(Duration::from_secs(1), relay_task).await {
|
||||||
Ok(Ok(Ok(Ok(t)))) => { relay_result = Some(Ok(t)); }
|
Ok(Ok(Ok(Ok(t)))) => {
|
||||||
Ok(Ok(Ok(Err(e)))) => { relay_result = Some(Err(anyhow::anyhow!("{e}"))); }
|
relay_result = Some(Ok(t));
|
||||||
_ => { relay_result = Some(Err(anyhow::anyhow!("relay: no result in grace period"))); }
|
}
|
||||||
|
Ok(Ok(Ok(Err(e)))) => {
|
||||||
|
relay_result = Some(Err(anyhow::anyhow!("{e}")));
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
relay_result = Some(Err(anyhow::anyhow!("relay: no result in grace period")));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -529,18 +747,230 @@ pub async fn race(
|
|||||||
);
|
);
|
||||||
|
|
||||||
if !direct_ok && !relay_ok {
|
if !direct_ok && !relay_ok {
|
||||||
return Err(anyhow::anyhow!("both paths failed: no media transport available"));
|
return Err(anyhow::anyhow!(
|
||||||
|
"both paths failed: no media transport available"
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
let _ = (direct_ep, relay_ep, ipv6_endpoint);
|
let _ = (direct_ep, relay_ep, ipv6_endpoint);
|
||||||
|
|
||||||
|
let candidate_diags = diags_collector
|
||||||
|
.lock()
|
||||||
|
.map(|d| d.clone())
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
Ok(RaceResult {
|
Ok(RaceResult {
|
||||||
direct_transport: direct_result
|
direct_transport: direct_result.and_then(|r| r.ok()).map(|t| Arc::new(t)),
|
||||||
.and_then(|r| r.ok())
|
relay_transport: relay_result.and_then(|r| r.ok()).map(|t| Arc::new(t)),
|
||||||
.map(|t| Arc::new(t)),
|
|
||||||
relay_transport: relay_result
|
|
||||||
.and_then(|r| r.ok())
|
|
||||||
.map(|t| Arc::new(t)),
|
|
||||||
local_winner,
|
local_winner,
|
||||||
|
candidate_diags,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn peer_candidates_dial_order_all_types() {
|
||||||
|
let candidates = PeerCandidates {
|
||||||
|
reflexive: Some("203.0.113.5:4433".parse().unwrap()),
|
||||||
|
local: vec![
|
||||||
|
"192.168.1.10:4433".parse().unwrap(),
|
||||||
|
"10.0.0.5:4433".parse().unwrap(),
|
||||||
|
],
|
||||||
|
mapped: Some("198.51.100.42:12345".parse().unwrap()),
|
||||||
|
};
|
||||||
|
|
||||||
|
let order = candidates.dial_order();
|
||||||
|
// Order: local first, then mapped, then reflexive
|
||||||
|
assert_eq!(order.len(), 4);
|
||||||
|
assert_eq!(order[0], "192.168.1.10:4433".parse::<SocketAddr>().unwrap());
|
||||||
|
assert_eq!(order[1], "10.0.0.5:4433".parse::<SocketAddr>().unwrap());
|
||||||
|
assert_eq!(
|
||||||
|
order[2],
|
||||||
|
"198.51.100.42:12345".parse::<SocketAddr>().unwrap()
|
||||||
|
);
|
||||||
|
assert_eq!(order[3], "203.0.113.5:4433".parse::<SocketAddr>().unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn peer_candidates_dial_order_no_mapped() {
|
||||||
|
let candidates = PeerCandidates {
|
||||||
|
reflexive: Some("203.0.113.5:4433".parse().unwrap()),
|
||||||
|
local: vec!["192.168.1.10:4433".parse().unwrap()],
|
||||||
|
mapped: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let order = candidates.dial_order();
|
||||||
|
assert_eq!(order.len(), 2);
|
||||||
|
assert_eq!(order[0], "192.168.1.10:4433".parse::<SocketAddr>().unwrap());
|
||||||
|
assert_eq!(order[1], "203.0.113.5:4433".parse::<SocketAddr>().unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn peer_candidates_dial_order_only_mapped() {
|
||||||
|
let candidates = PeerCandidates {
|
||||||
|
reflexive: None,
|
||||||
|
local: vec![],
|
||||||
|
mapped: Some("198.51.100.42:12345".parse().unwrap()),
|
||||||
|
};
|
||||||
|
|
||||||
|
let order = candidates.dial_order();
|
||||||
|
assert_eq!(order.len(), 1);
|
||||||
|
assert_eq!(
|
||||||
|
order[0],
|
||||||
|
"198.51.100.42:12345".parse::<SocketAddr>().unwrap()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn peer_candidates_dial_order_dedup_mapped_equals_reflexive() {
|
||||||
|
let addr: SocketAddr = "203.0.113.5:4433".parse().unwrap();
|
||||||
|
let candidates = PeerCandidates {
|
||||||
|
reflexive: Some(addr),
|
||||||
|
local: vec![],
|
||||||
|
mapped: Some(addr), // same as reflexive
|
||||||
|
};
|
||||||
|
|
||||||
|
let order = candidates.dial_order();
|
||||||
|
// Should be deduped to 1
|
||||||
|
assert_eq!(order.len(), 1);
|
||||||
|
assert_eq!(order[0], addr);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn peer_candidates_dial_order_dedup_mapped_in_local() {
|
||||||
|
let addr: SocketAddr = "192.168.1.10:4433".parse().unwrap();
|
||||||
|
let candidates = PeerCandidates {
|
||||||
|
reflexive: None,
|
||||||
|
local: vec![addr],
|
||||||
|
mapped: Some(addr), // same as a local addr
|
||||||
|
};
|
||||||
|
|
||||||
|
let order = candidates.dial_order();
|
||||||
|
assert_eq!(order.len(), 1);
|
||||||
|
assert_eq!(order[0], addr);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn peer_candidates_is_empty() {
|
||||||
|
let empty = PeerCandidates::default();
|
||||||
|
assert!(empty.is_empty());
|
||||||
|
|
||||||
|
let with_reflexive = PeerCandidates {
|
||||||
|
reflexive: Some("1.2.3.4:5".parse().unwrap()),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
assert!(!with_reflexive.is_empty());
|
||||||
|
|
||||||
|
let with_local = PeerCandidates {
|
||||||
|
local: vec!["10.0.0.1:5".parse().unwrap()],
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
assert!(!with_local.is_empty());
|
||||||
|
|
||||||
|
let with_mapped = PeerCandidates {
|
||||||
|
mapped: Some("1.2.3.4:5".parse().unwrap()),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
assert!(!with_mapped.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn peer_candidates_empty_dial_order() {
|
||||||
|
let empty = PeerCandidates::default();
|
||||||
|
assert!(empty.dial_order().is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn winning_path_debug() {
|
||||||
|
// Just verify Debug impl doesn't panic
|
||||||
|
let _ = format!("{:?}", WinningPath::Direct);
|
||||||
|
let _ = format!("{:?}", WinningPath::Relay);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── smart_dial_order tests ─────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn smart_dial_order_same_network_includes_lan() {
|
||||||
|
let candidates = PeerCandidates {
|
||||||
|
reflexive: Some("203.0.113.5:4433".parse().unwrap()),
|
||||||
|
local: vec![
|
||||||
|
"192.168.1.10:4433".parse().unwrap(),
|
||||||
|
"10.0.0.5:4433".parse().unwrap(),
|
||||||
|
],
|
||||||
|
mapped: None,
|
||||||
|
};
|
||||||
|
let own: SocketAddr = "203.0.113.5:12345".parse().unwrap();
|
||||||
|
let order = candidates.smart_dial_order(Some(&own));
|
||||||
|
// Same public IP → LAN candidates included
|
||||||
|
assert!(order.contains(&"192.168.1.10:4433".parse().unwrap()));
|
||||||
|
assert!(order.contains(&"10.0.0.5:4433".parse().unwrap()));
|
||||||
|
assert!(order.contains(&"203.0.113.5:4433".parse().unwrap()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn smart_dial_order_different_network_strips_lan() {
|
||||||
|
let candidates = PeerCandidates {
|
||||||
|
reflexive: Some("150.228.49.65:4433".parse().unwrap()),
|
||||||
|
local: vec![
|
||||||
|
"172.16.81.126:4433".parse().unwrap(),
|
||||||
|
"10.0.0.5:4433".parse().unwrap(),
|
||||||
|
],
|
||||||
|
mapped: None,
|
||||||
|
};
|
||||||
|
// Different public IP → LAN candidates stripped
|
||||||
|
let own: SocketAddr = "185.115.4.212:12345".parse().unwrap();
|
||||||
|
let order = candidates.smart_dial_order(Some(&own));
|
||||||
|
assert!(!order.contains(&"172.16.81.126:4433".parse().unwrap()));
|
||||||
|
assert!(!order.contains(&"10.0.0.5:4433".parse().unwrap()));
|
||||||
|
// Reflexive still included
|
||||||
|
assert!(order.contains(&"150.228.49.65:4433".parse().unwrap()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn smart_dial_order_strips_ipv6() {
|
||||||
|
let candidates = PeerCandidates {
|
||||||
|
reflexive: Some("150.228.49.65:4433".parse().unwrap()),
|
||||||
|
local: vec![
|
||||||
|
"[2a0d:3344:692c::1]:4433".parse().unwrap(),
|
||||||
|
"172.16.81.126:4433".parse().unwrap(),
|
||||||
|
],
|
||||||
|
mapped: None,
|
||||||
|
};
|
||||||
|
// Same network, but IPv6 should be stripped
|
||||||
|
let own: SocketAddr = "150.228.49.65:5555".parse().unwrap();
|
||||||
|
let order = candidates.smart_dial_order(Some(&own));
|
||||||
|
assert!(!order.iter().any(|a| a.is_ipv6()));
|
||||||
|
assert!(order.contains(&"172.16.81.126:4433".parse().unwrap()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn smart_dial_order_no_own_reflexive_strips_lan() {
|
||||||
|
let candidates = PeerCandidates {
|
||||||
|
reflexive: Some("150.228.49.65:4433".parse().unwrap()),
|
||||||
|
local: vec!["172.16.81.126:4433".parse().unwrap()],
|
||||||
|
mapped: Some("198.51.100.42:12345".parse().unwrap()),
|
||||||
|
};
|
||||||
|
// No own reflexive → can't determine same network → strip LAN
|
||||||
|
let order = candidates.smart_dial_order(None);
|
||||||
|
assert!(!order.contains(&"172.16.81.126:4433".parse().unwrap()));
|
||||||
|
assert!(order.contains(&"198.51.100.42:12345".parse().unwrap()));
|
||||||
|
assert!(order.contains(&"150.228.49.65:4433".parse().unwrap()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn smart_dial_order_mapped_always_included() {
|
||||||
|
let candidates = PeerCandidates {
|
||||||
|
reflexive: Some("150.228.49.65:4433".parse().unwrap()),
|
||||||
|
local: vec![],
|
||||||
|
mapped: Some("198.51.100.42:12345".parse().unwrap()),
|
||||||
|
};
|
||||||
|
let own: SocketAddr = "185.115.4.212:12345".parse().unwrap();
|
||||||
|
let order = candidates.smart_dial_order(Some(&own));
|
||||||
|
assert_eq!(order.len(), 2); // mapped + reflexive
|
||||||
|
assert!(order.contains(&"198.51.100.42:12345".parse().unwrap()));
|
||||||
|
assert!(order.contains(&"150.228.49.65:4433".parse().unwrap()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -166,7 +166,7 @@ pub async fn run_echo_test(
|
|||||||
match tokio::time::timeout(Duration::from_millis(2), transport.recv_media()).await {
|
match tokio::time::timeout(Duration::from_millis(2), transport.recv_media()).await {
|
||||||
Ok(Ok(Some(pkt))) => {
|
Ok(Ok(Some(pkt))) => {
|
||||||
total_packets_received += 1;
|
total_packets_received += 1;
|
||||||
let is_repair = pkt.header.is_repair;
|
let is_repair = pkt.header.is_repair();
|
||||||
decoder.ingest(pkt);
|
decoder.ingest(pkt);
|
||||||
if !is_repair {
|
if !is_repair {
|
||||||
if let Some(n) = decoder.decode_next(&mut pcm_buf) {
|
if let Some(n) = decoder.decode_next(&mut pcm_buf) {
|
||||||
@@ -184,7 +184,8 @@ pub async fn run_echo_test(
|
|||||||
let time_offset = start.elapsed().as_secs_f64();
|
let time_offset = start.elapsed().as_secs_f64();
|
||||||
|
|
||||||
// Compare sent vs received for this window
|
// Compare sent vs received for this window
|
||||||
let sent_start = (window_idx as u64 * frames_per_window * FRAME_SAMPLES as u64) as usize;
|
let sent_start =
|
||||||
|
(window_idx as u64 * frames_per_window * FRAME_SAMPLES as u64) as usize;
|
||||||
let sent_end = sent_start + (window_frames_sent as usize * FRAME_SAMPLES);
|
let sent_end = sent_start + (window_frames_sent as usize * FRAME_SAMPLES);
|
||||||
let sent_window = if sent_end <= sent_pcm.len() {
|
let sent_window = if sent_end <= sent_pcm.len() {
|
||||||
&sent_pcm[sent_start..sent_end]
|
&sent_pcm[sent_start..sent_end]
|
||||||
@@ -192,7 +193,9 @@ pub async fn run_echo_test(
|
|||||||
&sent_pcm[sent_start..]
|
&sent_pcm[sent_start..]
|
||||||
};
|
};
|
||||||
|
|
||||||
let recv_start = recv_pcm.len().saturating_sub(window_frames_received as usize * FRAME_SAMPLES);
|
let recv_start = recv_pcm
|
||||||
|
.len()
|
||||||
|
.saturating_sub(window_frames_received as usize * FRAME_SAMPLES);
|
||||||
let recv_window = &recv_pcm[recv_start..];
|
let recv_window = &recv_pcm[recv_start..];
|
||||||
|
|
||||||
let peak = recv_window.iter().map(|s| s.abs()).max().unwrap_or(0);
|
let peak = recv_window.iter().map(|s| s.abs()).max().unwrap_or(0);
|
||||||
@@ -256,7 +259,7 @@ pub async fn run_echo_test(
|
|||||||
match tokio::time::timeout(Duration::from_millis(100), transport.recv_media()).await {
|
match tokio::time::timeout(Duration::from_millis(100), transport.recv_media()).await {
|
||||||
Ok(Ok(Some(pkt))) => {
|
Ok(Ok(Some(pkt))) => {
|
||||||
total_packets_received += 1;
|
total_packets_received += 1;
|
||||||
let is_repair = pkt.header.is_repair;
|
let is_repair = pkt.header.is_repair();
|
||||||
decoder.ingest(pkt);
|
decoder.ingest(pkt);
|
||||||
if !is_repair {
|
if !is_repair {
|
||||||
decoder.decode_next(&mut pcm_buf);
|
decoder.decode_next(&mut pcm_buf);
|
||||||
@@ -310,8 +313,14 @@ pub fn print_report(result: &EchoTestResult) {
|
|||||||
let status = if w.is_silent { " !" } else { " " };
|
let status = if w.is_silent { " !" } else { " " };
|
||||||
println!(
|
println!(
|
||||||
"│ {:>3}{} │ {:>5.1}s │ {:>4} │ {:>4} │ {:>5.1}% │ {:>5.1} │ {:.3} │",
|
"│ {:>3}{} │ {:>5.1}s │ {:>4} │ {:>4} │ {:>5.1}% │ {:>5.1} │ {:.3} │",
|
||||||
w.index, status, w.time_offset_secs, w.frames_sent, w.frames_received,
|
w.index,
|
||||||
w.loss_pct, w.snr_db, w.correlation
|
status,
|
||||||
|
w.time_offset_secs,
|
||||||
|
w.frames_sent,
|
||||||
|
w.frames_received,
|
||||||
|
w.loss_pct,
|
||||||
|
w.snr_db,
|
||||||
|
w.correlation
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
println!("└───────┴─────────┴──────┴──────┴─────────┴───────┴───────┘");
|
println!("└───────┴─────────┴──────┴──────┴─────────┴───────┴───────┘");
|
||||||
@@ -321,18 +330,28 @@ pub fn print_report(result: &EchoTestResult) {
|
|||||||
let first_half: Vec<_> = result.windows[..result.windows.len() / 2].to_vec();
|
let first_half: Vec<_> = result.windows[..result.windows.len() / 2].to_vec();
|
||||||
let second_half: Vec<_> = result.windows[result.windows.len() / 2..].to_vec();
|
let second_half: Vec<_> = result.windows[result.windows.len() / 2..].to_vec();
|
||||||
|
|
||||||
let avg_loss_first = first_half.iter().map(|w| w.loss_pct).sum::<f32>() / first_half.len() as f32;
|
let avg_loss_first =
|
||||||
let avg_loss_second = second_half.iter().map(|w| w.loss_pct).sum::<f32>() / second_half.len() as f32;
|
first_half.iter().map(|w| w.loss_pct).sum::<f32>() / first_half.len() as f32;
|
||||||
let avg_corr_first = first_half.iter().map(|w| w.correlation).sum::<f32>() / first_half.len() as f32;
|
let avg_loss_second =
|
||||||
let avg_corr_second = second_half.iter().map(|w| w.correlation).sum::<f32>() / second_half.len() as f32;
|
second_half.iter().map(|w| w.loss_pct).sum::<f32>() / second_half.len() as f32;
|
||||||
|
let avg_corr_first =
|
||||||
|
first_half.iter().map(|w| w.correlation).sum::<f32>() / first_half.len() as f32;
|
||||||
|
let avg_corr_second =
|
||||||
|
second_half.iter().map(|w| w.correlation).sum::<f32>() / second_half.len() as f32;
|
||||||
|
|
||||||
println!();
|
println!();
|
||||||
if avg_loss_second > avg_loss_first + 5.0 {
|
if avg_loss_second > avg_loss_first + 5.0 {
|
||||||
println!("WARNING: Quality degradation detected!");
|
println!("WARNING: Quality degradation detected!");
|
||||||
println!(" Loss increased from {:.1}% to {:.1}% over time", avg_loss_first, avg_loss_second);
|
println!(
|
||||||
|
" Loss increased from {:.1}% to {:.1}% over time",
|
||||||
|
avg_loss_first, avg_loss_second
|
||||||
|
);
|
||||||
}
|
}
|
||||||
if avg_corr_second < avg_corr_first - 0.1 {
|
if avg_corr_second < avg_corr_first - 0.1 {
|
||||||
println!("WARNING: Signal correlation dropped from {:.3} to {:.3}", avg_corr_first, avg_corr_second);
|
println!(
|
||||||
|
"WARNING: Signal correlation dropped from {:.3} to {:.3}",
|
||||||
|
avg_corr_first, avg_corr_second
|
||||||
|
);
|
||||||
}
|
}
|
||||||
if avg_loss_second <= avg_loss_first + 5.0 && avg_corr_second >= avg_corr_first - 0.1 {
|
if avg_loss_second <= avg_loss_first + 5.0 && avg_corr_second >= avg_corr_first - 0.1 {
|
||||||
println!("Quality is STABLE over the test duration.");
|
println!("Quality is STABLE over the test duration.");
|
||||||
|
|||||||
213
crates/wzp-client/src/encrypted_transport.rs
Normal file
213
crates/wzp-client/src/encrypted_transport.rs
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
//! `EncryptingTransport` — wraps any `MediaTransport` with a `CryptoSession`.
|
||||||
|
//!
|
||||||
|
//! All outbound `send_media` calls encrypt the payload before handing off to
|
||||||
|
//! the inner transport; all inbound `recv_media` calls decrypt after receiving.
|
||||||
|
//! Signal, quality, and close are forwarded unchanged.
|
||||||
|
//!
|
||||||
|
//! The quality report travels in plaintext so the relay can make QoS decisions
|
||||||
|
//! without being able to decrypt media content.
|
||||||
|
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use bytes::Bytes;
|
||||||
|
use wzp_proto::{
|
||||||
|
CryptoSession, MediaHeader, MediaPacket, MediaTransport, PathQuality, SignalMessage,
|
||||||
|
TransportError,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Wraps a `MediaTransport` and applies AEAD encryption/decryption to media payloads.
|
||||||
|
pub struct EncryptingTransport {
|
||||||
|
inner: Arc<dyn MediaTransport>,
|
||||||
|
session: Mutex<Box<dyn CryptoSession>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EncryptingTransport {
|
||||||
|
pub fn new(inner: Arc<dyn MediaTransport>, session: Box<dyn CryptoSession>) -> Self {
|
||||||
|
Self {
|
||||||
|
inner,
|
||||||
|
session: Mutex::new(session),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl MediaTransport for EncryptingTransport {
|
||||||
|
async fn send_media(&self, packet: &MediaPacket) -> Result<(), TransportError> {
|
||||||
|
let mut header_bytes = Vec::with_capacity(MediaHeader::WIRE_SIZE);
|
||||||
|
packet.header.write_to(&mut header_bytes);
|
||||||
|
|
||||||
|
let mut ciphertext = Vec::new();
|
||||||
|
self.session
|
||||||
|
.lock()
|
||||||
|
.unwrap()
|
||||||
|
.encrypt(&header_bytes, &packet.payload, &mut ciphertext)
|
||||||
|
.map_err(|e| TransportError::Internal(format!("encrypt: {e}")))?;
|
||||||
|
|
||||||
|
let encrypted = MediaPacket {
|
||||||
|
header: packet.header,
|
||||||
|
payload: Bytes::from(ciphertext),
|
||||||
|
quality_report: packet.quality_report.clone(),
|
||||||
|
};
|
||||||
|
self.inner.send_media(&encrypted).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn recv_media(&self) -> Result<Option<MediaPacket>, TransportError> {
|
||||||
|
let packet = match self.inner.recv_media().await? {
|
||||||
|
Some(p) => p,
|
||||||
|
None => return Ok(None),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut header_bytes = Vec::with_capacity(MediaHeader::WIRE_SIZE);
|
||||||
|
packet.header.write_to(&mut header_bytes);
|
||||||
|
|
||||||
|
let mut plaintext = Vec::new();
|
||||||
|
self.session
|
||||||
|
.lock()
|
||||||
|
.unwrap()
|
||||||
|
.decrypt(&header_bytes, &packet.payload, &mut plaintext)
|
||||||
|
.map_err(|e| TransportError::Internal(format!("decrypt: {e}")))?;
|
||||||
|
|
||||||
|
Ok(Some(MediaPacket {
|
||||||
|
header: packet.header,
|
||||||
|
payload: Bytes::from(plaintext),
|
||||||
|
quality_report: packet.quality_report,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send_signal(&self, msg: &SignalMessage) -> Result<(), TransportError> {
|
||||||
|
self.inner.send_signal(msg).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn recv_signal(&self) -> Result<Option<SignalMessage>, TransportError> {
|
||||||
|
self.inner.recv_signal().await
|
||||||
|
}
|
||||||
|
|
||||||
|
fn path_quality(&self) -> PathQuality {
|
||||||
|
self.inner.path_quality()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn close(&self) -> Result<(), TransportError> {
|
||||||
|
self.inner.close().await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use std::sync::Mutex as StdMutex;
|
||||||
|
use wzp_crypto::ChaChaSession;
|
||||||
|
use wzp_proto::{CodecId, MediaType};
|
||||||
|
|
||||||
|
struct LoopbackTransport {
|
||||||
|
sent: StdMutex<Vec<MediaPacket>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LoopbackTransport {
|
||||||
|
fn new() -> Arc<Self> {
|
||||||
|
Arc::new(Self {
|
||||||
|
sent: StdMutex::new(Vec::new()),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
fn take_sent(&self) -> Vec<MediaPacket> {
|
||||||
|
self.sent.lock().unwrap().drain(..).collect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl MediaTransport for LoopbackTransport {
|
||||||
|
async fn send_media(&self, packet: &MediaPacket) -> Result<(), TransportError> {
|
||||||
|
self.sent.lock().unwrap().push(packet.clone());
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
async fn recv_media(&self) -> Result<Option<MediaPacket>, TransportError> {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
async fn send_signal(&self, _msg: &SignalMessage) -> Result<(), TransportError> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
async fn recv_signal(&self) -> Result<Option<SignalMessage>, TransportError> {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
fn path_quality(&self) -> PathQuality {
|
||||||
|
PathQuality::default()
|
||||||
|
}
|
||||||
|
async fn close(&self) -> Result<(), TransportError> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_header(seq: u32) -> MediaHeader {
|
||||||
|
MediaHeader {
|
||||||
|
version: 2,
|
||||||
|
flags: 0,
|
||||||
|
media_type: MediaType::Audio,
|
||||||
|
codec_id: CodecId::Opus24k,
|
||||||
|
stream_id: 0,
|
||||||
|
fec_ratio: 0,
|
||||||
|
seq,
|
||||||
|
timestamp: seq * 20,
|
||||||
|
fec_block: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn payload_is_encrypted_on_wire() {
|
||||||
|
let key = [0x42u8; 32];
|
||||||
|
let session: Box<dyn CryptoSession> = Box::new(ChaChaSession::new(key));
|
||||||
|
let loopback = LoopbackTransport::new();
|
||||||
|
let enc = EncryptingTransport::new(loopback.clone(), session);
|
||||||
|
|
||||||
|
let header = make_header(1);
|
||||||
|
let plaintext = b"secret audio frame";
|
||||||
|
let pkt = MediaPacket {
|
||||||
|
header,
|
||||||
|
payload: Bytes::from_static(plaintext),
|
||||||
|
quality_report: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
enc.send_media(&pkt).await.unwrap();
|
||||||
|
|
||||||
|
let sent = loopback.take_sent();
|
||||||
|
assert_eq!(sent.len(), 1);
|
||||||
|
assert_eq!(sent[0].header, header, "header must be preserved");
|
||||||
|
assert_ne!(
|
||||||
|
sent[0].payload.as_ref(),
|
||||||
|
plaintext.as_ref(),
|
||||||
|
"plaintext must not appear on wire"
|
||||||
|
);
|
||||||
|
// Ciphertext is longer by exactly the AEAD tag (16 bytes)
|
||||||
|
assert_eq!(sent[0].payload.len(), plaintext.len() + 16);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn encrypt_then_decrypt_roundtrip() {
|
||||||
|
let key = [0x42u8; 32];
|
||||||
|
let send_session: Box<dyn CryptoSession> = Box::new(ChaChaSession::new(key));
|
||||||
|
let mut recv_session = ChaChaSession::new(key);
|
||||||
|
|
||||||
|
let loopback = LoopbackTransport::new();
|
||||||
|
let enc = EncryptingTransport::new(loopback.clone(), send_session);
|
||||||
|
|
||||||
|
let header = make_header(5);
|
||||||
|
let plaintext = b"hello encrypted world";
|
||||||
|
let pkt = MediaPacket {
|
||||||
|
header,
|
||||||
|
payload: Bytes::from_static(plaintext),
|
||||||
|
quality_report: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
enc.send_media(&pkt).await.unwrap();
|
||||||
|
|
||||||
|
let sent = loopback.take_sent();
|
||||||
|
let wire_pkt = &sent[0];
|
||||||
|
|
||||||
|
let mut header_bytes = Vec::new();
|
||||||
|
header.write_to(&mut header_bytes);
|
||||||
|
let mut decrypted = Vec::new();
|
||||||
|
recv_session
|
||||||
|
.decrypt(&header_bytes, &wire_pkt.payload, &mut decrypted)
|
||||||
|
.expect("decrypt should succeed with matching key");
|
||||||
|
assert_eq!(&decrypted[..], plaintext);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -99,14 +99,15 @@ pub fn signal_to_call_type(signal: &SignalMessage) -> CallSignalType {
|
|||||||
SignalMessage::LossRecoveryUpdate { .. } => CallSignalType::Offer, // reuse (telemetry)
|
SignalMessage::LossRecoveryUpdate { .. } => CallSignalType::Offer, // reuse (telemetry)
|
||||||
SignalMessage::Ping { .. } | SignalMessage::Pong { .. } => CallSignalType::Offer,
|
SignalMessage::Ping { .. } | SignalMessage::Pong { .. } => CallSignalType::Offer,
|
||||||
SignalMessage::AuthToken { .. } => CallSignalType::Offer,
|
SignalMessage::AuthToken { .. } => CallSignalType::Offer,
|
||||||
SignalMessage::Hold => CallSignalType::Hold,
|
SignalMessage::Hold { .. } => CallSignalType::Hold,
|
||||||
SignalMessage::Unhold => CallSignalType::Unhold,
|
SignalMessage::Unhold { .. } => CallSignalType::Unhold,
|
||||||
SignalMessage::Mute => CallSignalType::Mute,
|
SignalMessage::Mute { .. } => CallSignalType::Mute,
|
||||||
SignalMessage::Unmute => CallSignalType::Unmute,
|
SignalMessage::Unmute { .. } => CallSignalType::Unmute,
|
||||||
SignalMessage::Transfer { .. } => CallSignalType::Transfer,
|
SignalMessage::Transfer { .. } => CallSignalType::Transfer,
|
||||||
SignalMessage::TransferAck => CallSignalType::Offer, // reuse
|
SignalMessage::TransferAck { .. } => CallSignalType::Offer, // reuse
|
||||||
SignalMessage::PresenceUpdate { .. } => CallSignalType::Offer, // reuse
|
SignalMessage::PresenceUpdate { .. } => CallSignalType::Offer, // reuse
|
||||||
SignalMessage::RouteQuery { .. } => CallSignalType::Offer, // reuse
|
SignalMessage::RouteQuery { .. } => CallSignalType::Offer, // reuse
|
||||||
|
SignalMessage::TransportFeedback { .. } => CallSignalType::Offer, // reuse (BWE)
|
||||||
SignalMessage::RouteResponse { .. } => CallSignalType::Offer, // reuse
|
SignalMessage::RouteResponse { .. } => CallSignalType::Offer, // reuse
|
||||||
SignalMessage::SessionForward { .. } => CallSignalType::Offer, // reuse
|
SignalMessage::SessionForward { .. } => CallSignalType::Offer, // reuse
|
||||||
SignalMessage::SessionForwardAck { .. } => CallSignalType::Offer, // reuse
|
SignalMessage::SessionForwardAck { .. } => CallSignalType::Offer, // reuse
|
||||||
@@ -118,19 +119,31 @@ pub fn signal_to_call_type(signal: &SignalMessage) -> CallSignalType {
|
|||||||
SignalMessage::DirectCallAnswer { .. } => CallSignalType::Answer,
|
SignalMessage::DirectCallAnswer { .. } => CallSignalType::Answer,
|
||||||
SignalMessage::CallSetup { .. } => CallSignalType::Offer, // relay-only
|
SignalMessage::CallSetup { .. } => CallSignalType::Offer, // relay-only
|
||||||
SignalMessage::CallRinging { .. } => CallSignalType::Ringing,
|
SignalMessage::CallRinging { .. } => CallSignalType::Ringing,
|
||||||
SignalMessage::RegisterPresence { .. }
|
SignalMessage::RegisterPresence { .. } | SignalMessage::RegisterPresenceAck { .. } => {
|
||||||
| SignalMessage::RegisterPresenceAck { .. } => CallSignalType::Offer, // relay-only
|
CallSignalType::Offer
|
||||||
|
} // relay-only
|
||||||
// NAT reflection is a client↔relay control exchange that
|
// NAT reflection is a client↔relay control exchange that
|
||||||
// never crosses the featherChat bridge — if it ever reaches
|
// never crosses the featherChat bridge — if it ever reaches
|
||||||
// this mapper something is wrong, but we still have to give
|
// this mapper something is wrong, but we still have to give
|
||||||
// an answer. "Offer" is the generic catch-all.
|
// an answer. "Offer" is the generic catch-all.
|
||||||
SignalMessage::Reflect
|
SignalMessage::Reflect | SignalMessage::ReflectResponse { .. } => CallSignalType::Offer, // control-plane
|
||||||
| SignalMessage::ReflectResponse { .. } => CallSignalType::Offer, // control-plane
|
|
||||||
// Phase 4 cross-relay forwarding envelope — strictly a
|
// Phase 4 cross-relay forwarding envelope — strictly a
|
||||||
// relay-to-relay message, never rides the featherChat
|
// relay-to-relay message, never rides the featherChat
|
||||||
// bridge. Catch-all mapping for completeness.
|
// bridge. Catch-all mapping for completeness.
|
||||||
SignalMessage::FederatedSignalForward { .. } => CallSignalType::Offer,
|
SignalMessage::FederatedSignalForward { .. } => CallSignalType::Offer,
|
||||||
SignalMessage::MediaPathReport { .. } => CallSignalType::Offer, // control-plane
|
SignalMessage::MediaPathReport { .. } => CallSignalType::Offer, // control-plane
|
||||||
|
SignalMessage::CandidateUpdate { .. } => CallSignalType::IceCandidate, // mid-call re-gather
|
||||||
|
SignalMessage::HardNatProbe { .. } => CallSignalType::IceCandidate, // hard NAT coordination
|
||||||
|
SignalMessage::HardNatBirthdayStart { .. } => CallSignalType::IceCandidate, // birthday attack
|
||||||
|
SignalMessage::UpgradeProposal { .. }
|
||||||
|
| SignalMessage::UpgradeResponse { .. }
|
||||||
|
| SignalMessage::UpgradeConfirm { .. }
|
||||||
|
| SignalMessage::QualityCapability { .. } => CallSignalType::Offer, // quality negotiation
|
||||||
|
SignalMessage::PresenceList { .. } => CallSignalType::Offer, // lobby presence
|
||||||
|
SignalMessage::QualityDirective { .. } => CallSignalType::Offer, // relay-initiated
|
||||||
|
SignalMessage::Nack { .. }
|
||||||
|
| SignalMessage::PictureLossIndication { .. }
|
||||||
|
| SignalMessage::SetPriorityMode { .. } => CallSignalType::Offer, // relay-initiated (video loss recovery)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -138,15 +151,20 @@ pub fn signal_to_call_type(signal: &SignalMessage) -> CallSignalType {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use wzp_proto::QualityProfile;
|
use wzp_proto::QualityProfile;
|
||||||
|
use wzp_proto::default_signal_version;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn payload_roundtrip() {
|
fn payload_roundtrip() {
|
||||||
let signal = SignalMessage::CallOffer {
|
let signal = SignalMessage::CallOffer {
|
||||||
|
version: default_signal_version(),
|
||||||
identity_pub: [1u8; 32],
|
identity_pub: [1u8; 32],
|
||||||
ephemeral_pub: [2u8; 32],
|
ephemeral_pub: [2u8; 32],
|
||||||
signature: vec![3u8; 64],
|
signature: vec![3u8; 64],
|
||||||
supported_profiles: vec![QualityProfile::GOOD],
|
supported_profiles: vec![QualityProfile::GOOD],
|
||||||
alias: None,
|
alias: None,
|
||||||
|
protocol_version: 2,
|
||||||
|
supported_versions: vec![2],
|
||||||
|
video_codecs: vec![],
|
||||||
};
|
};
|
||||||
|
|
||||||
let encoded = encode_call_payload(&signal, Some("relay.example.com:4433"), Some("myroom"));
|
let encoded = encode_call_payload(&signal, Some("relay.example.com:4433"), Some("myroom"));
|
||||||
@@ -160,29 +178,53 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn signal_type_mapping() {
|
fn signal_type_mapping() {
|
||||||
let offer = SignalMessage::CallOffer {
|
let offer = SignalMessage::CallOffer {
|
||||||
|
version: default_signal_version(),
|
||||||
identity_pub: [0; 32],
|
identity_pub: [0; 32],
|
||||||
ephemeral_pub: [0; 32],
|
ephemeral_pub: [0; 32],
|
||||||
signature: vec![],
|
signature: vec![],
|
||||||
supported_profiles: vec![],
|
supported_profiles: vec![],
|
||||||
alias: None,
|
alias: None,
|
||||||
|
protocol_version: 2,
|
||||||
|
supported_versions: vec![2],
|
||||||
|
video_codecs: vec![],
|
||||||
};
|
};
|
||||||
assert!(matches!(signal_to_call_type(&offer), CallSignalType::Offer));
|
assert!(matches!(signal_to_call_type(&offer), CallSignalType::Offer));
|
||||||
|
|
||||||
let hangup = SignalMessage::Hangup {
|
let hangup = SignalMessage::Hangup {
|
||||||
|
version: default_signal_version(),
|
||||||
reason: wzp_proto::HangupReason::Normal,
|
reason: wzp_proto::HangupReason::Normal,
|
||||||
call_id: None,
|
call_id: None,
|
||||||
};
|
};
|
||||||
assert!(matches!(signal_to_call_type(&hangup), CallSignalType::Hangup));
|
assert!(matches!(
|
||||||
|
signal_to_call_type(&hangup),
|
||||||
|
CallSignalType::Hangup
|
||||||
|
));
|
||||||
|
|
||||||
assert!(matches!(signal_to_call_type(&SignalMessage::Hold), CallSignalType::Hold));
|
assert!(matches!(
|
||||||
assert!(matches!(signal_to_call_type(&SignalMessage::Unhold), CallSignalType::Unhold));
|
signal_to_call_type(&SignalMessage::Hold { version: default_signal_version() }),
|
||||||
assert!(matches!(signal_to_call_type(&SignalMessage::Mute), CallSignalType::Mute));
|
CallSignalType::Hold
|
||||||
assert!(matches!(signal_to_call_type(&SignalMessage::Unmute), CallSignalType::Unmute));
|
));
|
||||||
|
assert!(matches!(
|
||||||
|
signal_to_call_type(&SignalMessage::Unhold { version: default_signal_version() }),
|
||||||
|
CallSignalType::Unhold
|
||||||
|
));
|
||||||
|
assert!(matches!(
|
||||||
|
signal_to_call_type(&SignalMessage::Mute { version: default_signal_version() }),
|
||||||
|
CallSignalType::Mute
|
||||||
|
));
|
||||||
|
assert!(matches!(
|
||||||
|
signal_to_call_type(&SignalMessage::Unmute { version: default_signal_version() }),
|
||||||
|
CallSignalType::Unmute
|
||||||
|
));
|
||||||
|
|
||||||
let transfer = SignalMessage::Transfer {
|
let transfer = SignalMessage::Transfer {
|
||||||
|
version: default_signal_version(),
|
||||||
target_fingerprint: "abc".to_string(),
|
target_fingerprint: "abc".to_string(),
|
||||||
relay_addr: None,
|
relay_addr: None,
|
||||||
};
|
};
|
||||||
assert!(matches!(signal_to_call_type(&transfer), CallSignalType::Transfer));
|
assert!(matches!(
|
||||||
|
signal_to_call_type(&transfer),
|
||||||
|
CallSignalType::Transfer
|
||||||
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,62 @@
|
|||||||
//! send `CallOffer` → recv `CallAnswer` → derive shared `CryptoSession`.
|
//! send `CallOffer` → recv `CallAnswer` → derive shared `CryptoSession`.
|
||||||
|
|
||||||
use wzp_crypto::{CryptoSession, KeyExchange, WarzoneKeyExchange};
|
use wzp_crypto::{CryptoSession, KeyExchange, WarzoneKeyExchange};
|
||||||
use wzp_proto::{MediaTransport, QualityProfile, SignalMessage};
|
use wzp_proto::{
|
||||||
|
CodecId, HangupReason, MediaTransport, QualityProfile, SignalMessage, default_signal_version,
|
||||||
|
};
|
||||||
|
|
||||||
|
const SUPPORTED_VIDEO_CODECS: &[CodecId] = &[CodecId::H264Baseline];
|
||||||
|
|
||||||
|
/// Result of a successful client-side handshake.
|
||||||
|
pub struct HandshakeResult {
|
||||||
|
pub session: Box<dyn CryptoSession>,
|
||||||
|
/// Video codec agreed with the relay. `None` if peer is audio-only.
|
||||||
|
pub video_codec: Option<CodecId>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Errors that can occur during the client-side cryptographic handshake.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum HandshakeError {
|
||||||
|
ConnectionClosed,
|
||||||
|
ProtocolVersionMismatch { server_supported: Vec<u8> },
|
||||||
|
UnexpectedSignal(&'static str),
|
||||||
|
SignatureVerificationFailed,
|
||||||
|
KeyDerivation(String),
|
||||||
|
Transport(wzp_proto::TransportError),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for HandshakeError {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
Self::ConnectionClosed => write!(f, "connection closed before receiving CallAnswer"),
|
||||||
|
Self::ProtocolVersionMismatch { server_supported } => {
|
||||||
|
write!(
|
||||||
|
f,
|
||||||
|
"protocol version mismatch: server supports {server_supported:?}"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
Self::UnexpectedSignal(expected) => write!(f, "expected CallAnswer, got {expected}"),
|
||||||
|
Self::SignatureVerificationFailed => write!(f, "callee signature verification failed"),
|
||||||
|
Self::KeyDerivation(msg) => write!(f, "key derivation failed: {msg}"),
|
||||||
|
Self::Transport(e) => write!(f, "transport error: {e}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for HandshakeError {
|
||||||
|
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||||
|
match self {
|
||||||
|
Self::Transport(e) => Some(e),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<wzp_proto::TransportError> for HandshakeError {
|
||||||
|
fn from(e: wzp_proto::TransportError) -> Self {
|
||||||
|
Self::Transport(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Perform the client (caller) side of the cryptographic handshake.
|
/// Perform the client (caller) side of the cryptographic handshake.
|
||||||
///
|
///
|
||||||
@@ -18,7 +73,17 @@ pub async fn perform_handshake(
|
|||||||
transport: &dyn MediaTransport,
|
transport: &dyn MediaTransport,
|
||||||
seed: &[u8; 32],
|
seed: &[u8; 32],
|
||||||
alias: Option<&str>,
|
alias: Option<&str>,
|
||||||
) -> Result<Box<dyn CryptoSession>, anyhow::Error> {
|
) -> Result<HandshakeResult, HandshakeError> {
|
||||||
|
perform_handshake_with_video_codecs(transport, seed, alias, SUPPORTED_VIDEO_CODECS.to_vec())
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn perform_handshake_with_video_codecs(
|
||||||
|
transport: &dyn MediaTransport,
|
||||||
|
seed: &[u8; 32],
|
||||||
|
alias: Option<&str>,
|
||||||
|
video_codecs: Vec<CodecId>,
|
||||||
|
) -> Result<HandshakeResult, HandshakeError> {
|
||||||
// 1. Create key exchange from identity seed
|
// 1. Create key exchange from identity seed
|
||||||
let mut kx = WarzoneKeyExchange::from_identity_seed(seed);
|
let mut kx = WarzoneKeyExchange::from_identity_seed(seed);
|
||||||
let identity_pub = kx.identity_public_key();
|
let identity_pub = kx.identity_public_key();
|
||||||
@@ -34,6 +99,7 @@ pub async fn perform_handshake(
|
|||||||
|
|
||||||
// 4. Send CallOffer
|
// 4. Send CallOffer
|
||||||
let offer = SignalMessage::CallOffer {
|
let offer = SignalMessage::CallOffer {
|
||||||
|
version: default_signal_version(),
|
||||||
identity_pub,
|
identity_pub,
|
||||||
ephemeral_pub,
|
ephemeral_pub,
|
||||||
signature,
|
signature,
|
||||||
@@ -46,28 +112,46 @@ pub async fn perform_handshake(
|
|||||||
QualityProfile::CATASTROPHIC,
|
QualityProfile::CATASTROPHIC,
|
||||||
],
|
],
|
||||||
alias: alias.map(|s| s.to_string()),
|
alias: alias.map(|s| s.to_string()),
|
||||||
|
protocol_version: 2,
|
||||||
|
supported_versions: vec![2],
|
||||||
|
video_codecs,
|
||||||
};
|
};
|
||||||
transport.send_signal(&offer).await?;
|
transport
|
||||||
|
.send_signal(&offer)
|
||||||
|
.await
|
||||||
|
.map_err(HandshakeError::Transport)?;
|
||||||
|
|
||||||
// 5. Wait for CallAnswer
|
// 5. Wait for CallAnswer — 10s timeout guards against relay not responding.
|
||||||
let answer = transport
|
let answer = tokio::time::timeout(std::time::Duration::from_secs(10), transport.recv_signal())
|
||||||
.recv_signal()
|
.await
|
||||||
.await?
|
.map_err(|_| HandshakeError::Transport(wzp_proto::TransportError::Timeout { ms: 10_000 }))?
|
||||||
.ok_or_else(|| anyhow::anyhow!("connection closed before receiving CallAnswer"))?;
|
.map_err(HandshakeError::Transport)?
|
||||||
|
.ok_or(HandshakeError::ConnectionClosed)?;
|
||||||
|
|
||||||
let (callee_identity_pub, callee_ephemeral_pub, callee_signature, _chosen_profile) = match answer
|
let (callee_identity_pub, callee_ephemeral_pub, callee_signature, _chosen_profile, video_codec) =
|
||||||
{
|
match answer {
|
||||||
SignalMessage::CallAnswer {
|
SignalMessage::CallAnswer {
|
||||||
identity_pub,
|
identity_pub,
|
||||||
ephemeral_pub,
|
ephemeral_pub,
|
||||||
signature,
|
signature,
|
||||||
chosen_profile,
|
chosen_profile,
|
||||||
} => (identity_pub, ephemeral_pub, signature, chosen_profile),
|
video_codec,
|
||||||
other => {
|
..
|
||||||
return Err(anyhow::anyhow!(
|
} => (
|
||||||
"expected CallAnswer, got {:?}",
|
identity_pub,
|
||||||
std::mem::discriminant(&other)
|
ephemeral_pub,
|
||||||
))
|
signature,
|
||||||
|
chosen_profile,
|
||||||
|
video_codec,
|
||||||
|
),
|
||||||
|
SignalMessage::Hangup {
|
||||||
|
reason: HangupReason::ProtocolVersionMismatch { server_supported },
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
return Err(HandshakeError::ProtocolVersionMismatch { server_supported });
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
return Err(HandshakeError::UnexpectedSignal("CallAnswer"));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -76,13 +160,18 @@ pub async fn perform_handshake(
|
|||||||
verify_data.extend_from_slice(&callee_ephemeral_pub);
|
verify_data.extend_from_slice(&callee_ephemeral_pub);
|
||||||
verify_data.extend_from_slice(b"call-answer");
|
verify_data.extend_from_slice(b"call-answer");
|
||||||
if !WarzoneKeyExchange::verify(&callee_identity_pub, &verify_data, &callee_signature) {
|
if !WarzoneKeyExchange::verify(&callee_identity_pub, &verify_data, &callee_signature) {
|
||||||
return Err(anyhow::anyhow!("callee signature verification failed"));
|
return Err(HandshakeError::SignatureVerificationFailed);
|
||||||
}
|
}
|
||||||
|
|
||||||
// 7. Derive session
|
// 7. Derive session
|
||||||
let session = kx.derive_session(&callee_ephemeral_pub)?;
|
let session = kx
|
||||||
|
.derive_session(&callee_ephemeral_pub)
|
||||||
|
.map_err(|e| HandshakeError::KeyDerivation(e.to_string()))?;
|
||||||
|
|
||||||
Ok(session)
|
Ok(HandshakeResult {
|
||||||
|
session,
|
||||||
|
video_codec,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -104,4 +193,34 @@ mod tests {
|
|||||||
&sig,
|
&sig,
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn handshake_result_carries_video_codec() {
|
||||||
|
// Verify that HandshakeResult has both fields accessible and that
|
||||||
|
// None is the correct default for audio-only peers.
|
||||||
|
let mut kx = WarzoneKeyExchange::from_identity_seed(&[0x55; 32]);
|
||||||
|
kx.generate_ephemeral();
|
||||||
|
let session = kx.derive_session(&[0u8; 32]).unwrap();
|
||||||
|
let hs = HandshakeResult {
|
||||||
|
session,
|
||||||
|
video_codec: None,
|
||||||
|
};
|
||||||
|
assert!(hs.video_codec.is_none());
|
||||||
|
|
||||||
|
let mut kx2 = WarzoneKeyExchange::from_identity_seed(&[0x66; 32]);
|
||||||
|
kx2.generate_ephemeral();
|
||||||
|
let session2 = kx2.derive_session(&[0u8; 32]).unwrap();
|
||||||
|
let hs2 = HandshakeResult {
|
||||||
|
session: session2,
|
||||||
|
video_codec: Some(CodecId::H264Baseline),
|
||||||
|
};
|
||||||
|
assert_eq!(hs2.video_codec, Some(CodecId::H264Baseline));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn offer_contains_h264_only() {
|
||||||
|
// Keep room video on the common denominator until Android AV1/HEVC
|
||||||
|
// send paths are proven in-device.
|
||||||
|
assert_eq!(SUPPORTED_VIDEO_CODECS, &[CodecId::H264Baseline]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
440
crates/wzp-client/src/ice_agent.rs
Normal file
440
crates/wzp-client/src/ice_agent.rs
Normal file
@@ -0,0 +1,440 @@
|
|||||||
|
//! Phase 8 (Tailscale-inspired): ICE agent for candidate lifecycle
|
||||||
|
//! management and mid-call re-gathering.
|
||||||
|
//!
|
||||||
|
//! The `IceAgent` owns the state of all candidate discovery
|
||||||
|
//! mechanisms (STUN, port mapping, host candidates) and provides:
|
||||||
|
//!
|
||||||
|
//! - `gather()`: initial candidate gathering during call setup
|
||||||
|
//! - `re_gather()`: triggered on network change, produces a
|
||||||
|
//! `CandidateUpdate` to send to the peer
|
||||||
|
//! - `apply_peer_update()`: processes peer's candidate updates
|
||||||
|
//!
|
||||||
|
//! This is NOT a full ICE agent (RFC 8445). It's the Tailscale-style
|
||||||
|
//! "gather all candidates, race them all in parallel, pick the
|
||||||
|
//! winner" approach, adapted for QUIC transport.
|
||||||
|
|
||||||
|
use std::net::SocketAddr;
|
||||||
|
use std::sync::atomic::{AtomicU32, Ordering};
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use wzp_proto::{SignalMessage, default_signal_version};
|
||||||
|
|
||||||
|
use crate::dual_path::PeerCandidates;
|
||||||
|
use crate::portmap;
|
||||||
|
use crate::reflect;
|
||||||
|
use crate::stun;
|
||||||
|
|
||||||
|
/// All candidates gathered for the local side.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct CandidateSet {
|
||||||
|
/// STUN-discovered server-reflexive address.
|
||||||
|
pub reflexive: Option<SocketAddr>,
|
||||||
|
/// LAN host candidates from local interfaces.
|
||||||
|
pub local: Vec<SocketAddr>,
|
||||||
|
/// Port-mapped address from NAT-PMP/PCP/UPnP.
|
||||||
|
pub mapped: Option<SocketAddr>,
|
||||||
|
/// Generation counter (monotonically increasing per call).
|
||||||
|
pub generation: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Configuration for the ICE agent.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct IceAgentConfig {
|
||||||
|
/// STUN servers to use for reflexive discovery.
|
||||||
|
pub stun_config: stun::StunConfig,
|
||||||
|
/// Whether to attempt port mapping.
|
||||||
|
pub enable_portmap: bool,
|
||||||
|
/// Timeout for each discovery mechanism.
|
||||||
|
pub gather_timeout: Duration,
|
||||||
|
/// The QUIC endpoint's local port (for host candidate pairing).
|
||||||
|
pub local_v4_port: u16,
|
||||||
|
/// Optional IPv6 port.
|
||||||
|
pub local_v6_port: Option<u16>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for IceAgentConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
stun_config: stun::StunConfig::default(),
|
||||||
|
enable_portmap: true,
|
||||||
|
gather_timeout: Duration::from_secs(3),
|
||||||
|
local_v4_port: 0,
|
||||||
|
local_v6_port: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// ICE agent managing candidate lifecycle.
|
||||||
|
pub struct IceAgent {
|
||||||
|
config: IceAgentConfig,
|
||||||
|
generation: AtomicU32,
|
||||||
|
call_id: String,
|
||||||
|
/// Last-seen peer generation (to filter stale updates).
|
||||||
|
peer_generation: AtomicU32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IceAgent {
|
||||||
|
pub fn new(call_id: String, config: IceAgentConfig) -> Self {
|
||||||
|
Self {
|
||||||
|
config,
|
||||||
|
generation: AtomicU32::new(0),
|
||||||
|
call_id,
|
||||||
|
peer_generation: AtomicU32::new(0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Initial candidate gathering. Runs all discovery mechanisms
|
||||||
|
/// in parallel and returns the full candidate set.
|
||||||
|
pub async fn gather(&self) -> CandidateSet {
|
||||||
|
let generation = self.generation.fetch_add(1, Ordering::Relaxed);
|
||||||
|
|
||||||
|
// Run STUN + port mapping + host candidates in parallel.
|
||||||
|
let stun_fut = stun::discover_reflexive(&self.config.stun_config);
|
||||||
|
let portmap_fut = async {
|
||||||
|
if self.config.enable_portmap && self.config.local_v4_port > 0 {
|
||||||
|
portmap::acquire_port_mapping(self.config.local_v4_port, None)
|
||||||
|
.await
|
||||||
|
.ok()
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let (stun_result, portmap_result) = tokio::join!(
|
||||||
|
tokio::time::timeout(self.config.gather_timeout, stun_fut),
|
||||||
|
tokio::time::timeout(self.config.gather_timeout, portmap_fut),
|
||||||
|
);
|
||||||
|
|
||||||
|
let reflexive = stun_result.ok().and_then(|r| r.ok());
|
||||||
|
let mapped = portmap_result.ok().flatten().map(|m| m.external_addr);
|
||||||
|
let local =
|
||||||
|
reflect::local_host_candidates(self.config.local_v4_port, self.config.local_v6_port);
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
generation,
|
||||||
|
reflexive = ?reflexive,
|
||||||
|
mapped = ?mapped,
|
||||||
|
local_count = local.len(),
|
||||||
|
"ice_agent: gathered candidates"
|
||||||
|
);
|
||||||
|
|
||||||
|
CandidateSet {
|
||||||
|
reflexive,
|
||||||
|
local,
|
||||||
|
mapped,
|
||||||
|
generation,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Re-gather candidates after a network change. Increments the
|
||||||
|
/// generation counter and returns a `CandidateUpdate` signal
|
||||||
|
/// message to send to the peer.
|
||||||
|
pub async fn re_gather(&self) -> (CandidateSet, SignalMessage) {
|
||||||
|
let candidates = self.gather().await;
|
||||||
|
|
||||||
|
let update = SignalMessage::CandidateUpdate {
|
||||||
|
version: default_signal_version(),
|
||||||
|
call_id: self.call_id.clone(),
|
||||||
|
reflexive_addr: candidates.reflexive.map(|a| a.to_string()),
|
||||||
|
local_addrs: candidates.local.iter().map(|a| a.to_string()).collect(),
|
||||||
|
mapped_addr: candidates.mapped.map(|a| a.to_string()),
|
||||||
|
generation: candidates.generation,
|
||||||
|
};
|
||||||
|
|
||||||
|
(candidates, update)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Process a peer's candidate update. Returns `Some(PeerCandidates)`
|
||||||
|
/// if the update is newer than the last-seen generation, `None`
|
||||||
|
/// if it's stale.
|
||||||
|
pub fn apply_peer_update(&self, update: &SignalMessage) -> Option<PeerCandidates> {
|
||||||
|
let (reflexive_addr, local_addrs, mapped_addr, generation) = match update {
|
||||||
|
SignalMessage::CandidateUpdate {
|
||||||
|
reflexive_addr,
|
||||||
|
local_addrs,
|
||||||
|
mapped_addr,
|
||||||
|
generation,
|
||||||
|
..
|
||||||
|
} => (reflexive_addr, local_addrs, mapped_addr, *generation),
|
||||||
|
_ => return None,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Only accept if newer than last-seen generation.
|
||||||
|
let prev = self.peer_generation.fetch_max(generation, Ordering::AcqRel);
|
||||||
|
if generation <= prev {
|
||||||
|
tracing::debug!(
|
||||||
|
generation,
|
||||||
|
prev,
|
||||||
|
"ice_agent: ignoring stale CandidateUpdate"
|
||||||
|
);
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let reflexive = reflexive_addr.as_deref().and_then(|s| s.parse().ok());
|
||||||
|
let local: Vec<SocketAddr> = local_addrs.iter().filter_map(|s| s.parse().ok()).collect();
|
||||||
|
let mapped = mapped_addr.as_deref().and_then(|s| s.parse().ok());
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
generation,
|
||||||
|
reflexive = ?reflexive,
|
||||||
|
mapped = ?mapped,
|
||||||
|
local_count = local.len(),
|
||||||
|
"ice_agent: applied peer candidate update"
|
||||||
|
);
|
||||||
|
|
||||||
|
Some(PeerCandidates {
|
||||||
|
reflexive,
|
||||||
|
local,
|
||||||
|
mapped,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the current generation counter.
|
||||||
|
pub fn generation(&self) -> u32 {
|
||||||
|
self.generation.load(Ordering::Relaxed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Tests ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn apply_peer_update_rejects_stale() {
|
||||||
|
let agent = IceAgent::new("test-call".into(), IceAgentConfig::default());
|
||||||
|
|
||||||
|
// First update (gen=1) should succeed.
|
||||||
|
let update1 = SignalMessage::CandidateUpdate {
|
||||||
|
version: default_signal_version(),
|
||||||
|
call_id: "test-call".into(),
|
||||||
|
reflexive_addr: Some("203.0.113.5:4433".into()),
|
||||||
|
local_addrs: vec!["192.168.1.10:4433".into()],
|
||||||
|
mapped_addr: None,
|
||||||
|
generation: 1,
|
||||||
|
};
|
||||||
|
let result = agent.apply_peer_update(&update1);
|
||||||
|
assert!(result.is_some());
|
||||||
|
let candidates = result.unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
candidates.reflexive,
|
||||||
|
Some("203.0.113.5:4433".parse().unwrap())
|
||||||
|
);
|
||||||
|
assert_eq!(candidates.local.len(), 1);
|
||||||
|
|
||||||
|
// Same generation (gen=1) should be rejected.
|
||||||
|
let update1b = SignalMessage::CandidateUpdate {
|
||||||
|
version: default_signal_version(),
|
||||||
|
call_id: "test-call".into(),
|
||||||
|
reflexive_addr: Some("198.51.100.9:4433".into()),
|
||||||
|
local_addrs: vec![],
|
||||||
|
mapped_addr: None,
|
||||||
|
generation: 1,
|
||||||
|
};
|
||||||
|
assert!(agent.apply_peer_update(&update1b).is_none());
|
||||||
|
|
||||||
|
// Older generation (gen=0) should be rejected.
|
||||||
|
let update0 = SignalMessage::CandidateUpdate {
|
||||||
|
version: default_signal_version(),
|
||||||
|
call_id: "test-call".into(),
|
||||||
|
reflexive_addr: Some("10.0.0.1:4433".into()),
|
||||||
|
local_addrs: vec![],
|
||||||
|
mapped_addr: None,
|
||||||
|
generation: 0,
|
||||||
|
};
|
||||||
|
assert!(agent.apply_peer_update(&update0).is_none());
|
||||||
|
|
||||||
|
// Newer generation (gen=2) should succeed.
|
||||||
|
let update2 = SignalMessage::CandidateUpdate {
|
||||||
|
version: default_signal_version(),
|
||||||
|
call_id: "test-call".into(),
|
||||||
|
reflexive_addr: Some("198.51.100.9:5555".into()),
|
||||||
|
local_addrs: vec![],
|
||||||
|
mapped_addr: Some("203.0.113.5:12345".into()),
|
||||||
|
generation: 2,
|
||||||
|
};
|
||||||
|
let result = agent.apply_peer_update(&update2);
|
||||||
|
assert!(result.is_some());
|
||||||
|
let candidates = result.unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
candidates.reflexive,
|
||||||
|
Some("198.51.100.9:5555".parse().unwrap())
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
candidates.mapped,
|
||||||
|
Some("203.0.113.5:12345".parse().unwrap())
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn apply_wrong_signal_returns_none() {
|
||||||
|
let agent = IceAgent::new("test-call".into(), IceAgentConfig::default());
|
||||||
|
let wrong = SignalMessage::Reflect;
|
||||||
|
assert!(agent.apply_peer_update(&wrong).is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn generation_increments() {
|
||||||
|
let agent = IceAgent::new("test".into(), IceAgentConfig::default());
|
||||||
|
assert_eq!(agent.generation(), 0);
|
||||||
|
// Simulate what gather() does internally
|
||||||
|
let g1 = agent.generation.fetch_add(1, Ordering::Relaxed);
|
||||||
|
assert_eq!(g1, 0);
|
||||||
|
assert_eq!(agent.generation(), 1);
|
||||||
|
let g2 = agent.generation.fetch_add(1, Ordering::Relaxed);
|
||||||
|
assert_eq!(g2, 1);
|
||||||
|
assert_eq!(agent.generation(), 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn apply_peer_update_parses_all_fields() {
|
||||||
|
let agent = IceAgent::new("test-call".into(), IceAgentConfig::default());
|
||||||
|
|
||||||
|
let update = SignalMessage::CandidateUpdate {
|
||||||
|
version: default_signal_version(),
|
||||||
|
call_id: "test-call".into(),
|
||||||
|
reflexive_addr: Some("203.0.113.5:4433".into()),
|
||||||
|
local_addrs: vec!["192.168.1.10:4433".into(), "10.0.0.5:4433".into()],
|
||||||
|
mapped_addr: Some("198.51.100.42:12345".into()),
|
||||||
|
generation: 1,
|
||||||
|
};
|
||||||
|
|
||||||
|
let candidates = agent.apply_peer_update(&update).unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
candidates.reflexive,
|
||||||
|
Some("203.0.113.5:4433".parse().unwrap())
|
||||||
|
);
|
||||||
|
assert_eq!(candidates.local.len(), 2);
|
||||||
|
assert_eq!(
|
||||||
|
candidates.local[0],
|
||||||
|
"192.168.1.10:4433".parse::<SocketAddr>().unwrap()
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
candidates.mapped,
|
||||||
|
Some("198.51.100.42:12345".parse().unwrap())
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn apply_peer_update_handles_empty_fields() {
|
||||||
|
let agent = IceAgent::new("test".into(), IceAgentConfig::default());
|
||||||
|
|
||||||
|
let update = SignalMessage::CandidateUpdate {
|
||||||
|
version: default_signal_version(),
|
||||||
|
call_id: "test".into(),
|
||||||
|
reflexive_addr: None,
|
||||||
|
local_addrs: vec![],
|
||||||
|
mapped_addr: None,
|
||||||
|
generation: 1,
|
||||||
|
};
|
||||||
|
|
||||||
|
let candidates = agent.apply_peer_update(&update).unwrap();
|
||||||
|
assert!(candidates.reflexive.is_none());
|
||||||
|
assert!(candidates.local.is_empty());
|
||||||
|
assert!(candidates.mapped.is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn apply_peer_update_skips_unparseable_addrs() {
|
||||||
|
let agent = IceAgent::new("test".into(), IceAgentConfig::default());
|
||||||
|
|
||||||
|
let update = SignalMessage::CandidateUpdate {
|
||||||
|
version: default_signal_version(),
|
||||||
|
call_id: "test".into(),
|
||||||
|
reflexive_addr: Some("not-an-addr".into()),
|
||||||
|
local_addrs: vec![
|
||||||
|
"192.168.1.10:4433".into(),
|
||||||
|
"garbage".into(),
|
||||||
|
"10.0.0.5:4433".into(),
|
||||||
|
],
|
||||||
|
mapped_addr: Some("also-bad".into()),
|
||||||
|
generation: 1,
|
||||||
|
};
|
||||||
|
|
||||||
|
let candidates = agent.apply_peer_update(&update).unwrap();
|
||||||
|
assert!(candidates.reflexive.is_none()); // unparseable
|
||||||
|
assert_eq!(candidates.local.len(), 2); // garbage filtered
|
||||||
|
assert!(candidates.mapped.is_none()); // unparseable
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn default_config_values() {
|
||||||
|
let cfg = IceAgentConfig::default();
|
||||||
|
assert!(cfg.enable_portmap);
|
||||||
|
assert!(cfg.gather_timeout.as_secs() > 0);
|
||||||
|
assert!(!cfg.stun_config.servers.is_empty());
|
||||||
|
assert_eq!(cfg.local_v4_port, 0);
|
||||||
|
assert!(cfg.local_v6_port.is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn gather_returns_candidates_even_with_no_stun() {
|
||||||
|
// With default config (port 0 = no portmap, STUN will timeout
|
||||||
|
// quickly on loopback), gather should still return host candidates.
|
||||||
|
let agent = IceAgent::new(
|
||||||
|
"test".into(),
|
||||||
|
IceAgentConfig {
|
||||||
|
stun_config: stun::StunConfig {
|
||||||
|
servers: vec![], // no servers = quick failure
|
||||||
|
timeout: Duration::from_millis(100),
|
||||||
|
},
|
||||||
|
enable_portmap: false,
|
||||||
|
gather_timeout: Duration::from_millis(200),
|
||||||
|
local_v4_port: 12345,
|
||||||
|
local_v6_port: None,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
let candidates = agent.gather().await;
|
||||||
|
assert_eq!(candidates.generation, 0);
|
||||||
|
// Reflexive should be None (no STUN servers)
|
||||||
|
assert!(candidates.reflexive.is_none());
|
||||||
|
// Mapped should be None (portmap disabled)
|
||||||
|
assert!(candidates.mapped.is_none());
|
||||||
|
// Local candidates depend on the machine's interfaces
|
||||||
|
// but gather() should not panic.
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn re_gather_produces_signal_message() {
|
||||||
|
let agent = IceAgent::new(
|
||||||
|
"call-42".into(),
|
||||||
|
IceAgentConfig {
|
||||||
|
stun_config: stun::StunConfig {
|
||||||
|
servers: vec![],
|
||||||
|
timeout: Duration::from_millis(50),
|
||||||
|
},
|
||||||
|
enable_portmap: false,
|
||||||
|
gather_timeout: Duration::from_millis(100),
|
||||||
|
local_v4_port: 4433,
|
||||||
|
local_v6_port: None,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
let (candidates, signal) = agent.re_gather().await;
|
||||||
|
assert_eq!(candidates.generation, 0);
|
||||||
|
|
||||||
|
match signal {
|
||||||
|
SignalMessage::CandidateUpdate {
|
||||||
|
call_id,
|
||||||
|
generation,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
assert_eq!(call_id, "call-42");
|
||||||
|
assert_eq!(generation, 0);
|
||||||
|
}
|
||||||
|
_ => panic!("expected CandidateUpdate"),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second re_gather increments generation
|
||||||
|
let (candidates2, signal2) = agent.re_gather().await;
|
||||||
|
assert_eq!(candidates2.generation, 1);
|
||||||
|
match signal2 {
|
||||||
|
SignalMessage::CandidateUpdate { generation, .. } => {
|
||||||
|
assert_eq!(generation, 1);
|
||||||
|
}
|
||||||
|
_ => panic!("expected CandidateUpdate"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -27,14 +27,21 @@ pub mod audio_wasapi;
|
|||||||
#[cfg(all(feature = "linux-aec", target_os = "linux"))]
|
#[cfg(all(feature = "linux-aec", target_os = "linux"))]
|
||||||
pub mod audio_linux_aec;
|
pub mod audio_linux_aec;
|
||||||
pub mod bench;
|
pub mod bench;
|
||||||
|
pub mod birthday;
|
||||||
pub mod call;
|
pub mod call;
|
||||||
|
pub mod encrypted_transport;
|
||||||
pub mod drift_test;
|
pub mod drift_test;
|
||||||
|
pub mod dual_path;
|
||||||
pub mod echo_test;
|
pub mod echo_test;
|
||||||
pub mod featherchat;
|
pub mod featherchat;
|
||||||
pub mod handshake;
|
pub mod handshake;
|
||||||
pub mod dual_path;
|
pub mod ice_agent;
|
||||||
pub mod metrics;
|
pub mod metrics;
|
||||||
|
pub mod netcheck;
|
||||||
|
pub mod portmap;
|
||||||
pub mod reflect;
|
pub mod reflect;
|
||||||
|
pub mod relay_map;
|
||||||
|
pub mod stun;
|
||||||
pub mod sweep;
|
pub mod sweep;
|
||||||
|
|
||||||
// AudioPlayback: three possible backends depending on feature flags.
|
// AudioPlayback: three possible backends depending on feature flags.
|
||||||
|
|||||||
@@ -178,7 +178,10 @@ mod tests {
|
|||||||
|
|
||||||
// Immediate second write should be skipped (60s interval).
|
// Immediate second write should be skipped (60s interval).
|
||||||
let second = writer.maybe_write(&snap).unwrap();
|
let second = writer.maybe_write(&snap).unwrap();
|
||||||
assert!(!second, "second write should be skipped — interval not elapsed");
|
assert!(
|
||||||
|
!second,
|
||||||
|
"second write should be skipped — interval not elapsed"
|
||||||
|
);
|
||||||
|
|
||||||
// Clean up.
|
// Clean up.
|
||||||
let _ = std::fs::remove_file(&path);
|
let _ = std::fs::remove_file(&path);
|
||||||
|
|||||||
537
crates/wzp-client/src/netcheck.rs
Normal file
537
crates/wzp-client/src/netcheck.rs
Normal file
@@ -0,0 +1,537 @@
|
|||||||
|
//! Phase 8 (Tailscale-inspired): Comprehensive network diagnostic.
|
||||||
|
//!
|
||||||
|
//! Probes STUN servers, relay infrastructure, port mapping
|
||||||
|
//! capabilities, IPv6 reachability, and NAT hairpinning in parallel
|
||||||
|
//! to produce a `NetcheckReport` that captures the client's network
|
||||||
|
//! environment at a point in time.
|
||||||
|
//!
|
||||||
|
//! Used for:
|
||||||
|
//! - Troubleshooting connectivity issues
|
||||||
|
//! - Automatic relay selection (Phase 5)
|
||||||
|
//! - Pre-call NAT assessment
|
||||||
|
//! - Quality prediction
|
||||||
|
|
||||||
|
use std::net::SocketAddr;
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
|
use serde::Serialize;
|
||||||
|
|
||||||
|
use crate::portmap::{self, PortMapProtocol};
|
||||||
|
use crate::reflect::{self, NatType};
|
||||||
|
use crate::stun::{self, StunConfig};
|
||||||
|
|
||||||
|
/// Complete network diagnostic report.
|
||||||
|
#[derive(Debug, Clone, Serialize)]
|
||||||
|
pub struct NetcheckReport {
|
||||||
|
/// NAT type classification (from combined STUN + relay probes).
|
||||||
|
pub nat_type: NatType,
|
||||||
|
/// Server-reflexive address (consensus from probes).
|
||||||
|
pub reflexive_addr: Option<String>,
|
||||||
|
/// Whether IPv4 connectivity is available.
|
||||||
|
pub ipv4_reachable: bool,
|
||||||
|
/// Whether IPv6 connectivity is available.
|
||||||
|
pub ipv6_reachable: bool,
|
||||||
|
/// Whether the NAT supports hairpinning (loopback to own
|
||||||
|
/// reflexive address).
|
||||||
|
pub hairpin_works: Option<bool>,
|
||||||
|
/// Which port mapping protocol is available (if any).
|
||||||
|
pub port_mapping: Option<PortMapProtocol>,
|
||||||
|
/// Per-relay latency measurements.
|
||||||
|
pub relay_latencies: Vec<RelayLatency>,
|
||||||
|
/// Preferred relay (lowest latency).
|
||||||
|
pub preferred_relay: Option<String>,
|
||||||
|
/// STUN latency to first responding server (ms).
|
||||||
|
pub stun_latency_ms: Option<u32>,
|
||||||
|
/// Whether UPnP is available on the gateway.
|
||||||
|
pub upnp_available: bool,
|
||||||
|
/// Whether PCP is available on the gateway.
|
||||||
|
pub pcp_available: bool,
|
||||||
|
/// Whether NAT-PMP is available on the gateway.
|
||||||
|
pub nat_pmp_available: bool,
|
||||||
|
/// Default gateway address.
|
||||||
|
pub gateway: Option<String>,
|
||||||
|
/// Total time taken for the diagnostic (ms).
|
||||||
|
pub duration_ms: u32,
|
||||||
|
/// Individual STUN probe results.
|
||||||
|
pub stun_probes: Vec<reflect::NatProbeResult>,
|
||||||
|
/// NAT port allocation pattern (sequential vs random).
|
||||||
|
pub port_allocation: Option<stun::PortAllocation>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Latency to a specific relay.
|
||||||
|
#[derive(Debug, Clone, Serialize)]
|
||||||
|
pub struct RelayLatency {
|
||||||
|
pub name: String,
|
||||||
|
pub addr: String,
|
||||||
|
pub rtt_ms: Option<u32>,
|
||||||
|
pub error: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Configuration for the netcheck run.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct NetcheckConfig {
|
||||||
|
/// STUN servers to probe.
|
||||||
|
pub stun_config: StunConfig,
|
||||||
|
/// Relay servers to probe (name, address pairs).
|
||||||
|
pub relays: Vec<(String, SocketAddr)>,
|
||||||
|
/// Per-probe timeout.
|
||||||
|
pub timeout: Duration,
|
||||||
|
/// Whether to test port mapping.
|
||||||
|
pub test_portmap: bool,
|
||||||
|
/// Whether to test IPv6.
|
||||||
|
pub test_ipv6: bool,
|
||||||
|
/// Local port for port mapping test (0 = skip).
|
||||||
|
pub local_port: u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for NetcheckConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
stun_config: StunConfig::default(),
|
||||||
|
relays: Vec::new(),
|
||||||
|
timeout: Duration::from_secs(5),
|
||||||
|
test_portmap: true,
|
||||||
|
test_ipv6: true,
|
||||||
|
local_port: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run a comprehensive network diagnostic.
|
||||||
|
///
|
||||||
|
/// Probes run in parallel for speed — the total time is bounded
|
||||||
|
/// by the slowest individual probe, not the sum.
|
||||||
|
pub async fn run_netcheck(config: &NetcheckConfig) -> NetcheckReport {
|
||||||
|
let start = Instant::now();
|
||||||
|
|
||||||
|
// Run all probes in parallel.
|
||||||
|
let stun_fut = stun::probe_stun_servers(&config.stun_config);
|
||||||
|
let relay_fut = probe_relays(&config.relays, config.timeout);
|
||||||
|
let portmap_fut = probe_portmap(config.test_portmap, config.local_port);
|
||||||
|
let gateway_fut = portmap::default_gateway();
|
||||||
|
let ipv6_fut = test_ipv6(config.test_ipv6, config.timeout);
|
||||||
|
let port_alloc_fut = stun::detect_port_allocation(&config.stun_config);
|
||||||
|
|
||||||
|
let (
|
||||||
|
stun_probes,
|
||||||
|
relay_latencies,
|
||||||
|
portmap_result,
|
||||||
|
gateway_result,
|
||||||
|
ipv6_reachable,
|
||||||
|
port_alloc_result,
|
||||||
|
) = tokio::join!(
|
||||||
|
stun_fut,
|
||||||
|
relay_fut,
|
||||||
|
portmap_fut,
|
||||||
|
gateway_result_fut(gateway_fut),
|
||||||
|
ipv6_fut,
|
||||||
|
port_alloc_fut
|
||||||
|
);
|
||||||
|
|
||||||
|
// Classify NAT from STUN probes.
|
||||||
|
let (nat_type, consensus_addr) = reflect::classify_nat(&stun_probes);
|
||||||
|
|
||||||
|
// Determine STUN latency (first successful probe).
|
||||||
|
let stun_latency_ms = stun_probes.iter().filter_map(|p| p.latency_ms).min();
|
||||||
|
|
||||||
|
// IPv4 reachable if any STUN probe succeeded.
|
||||||
|
let ipv4_reachable = stun_probes.iter().any(|p| p.observed_addr.is_some());
|
||||||
|
|
||||||
|
// Preferred relay = lowest RTT.
|
||||||
|
let preferred_relay = relay_latencies
|
||||||
|
.iter()
|
||||||
|
.filter_map(|r| r.rtt_ms.map(|rtt| (r.name.clone(), rtt)))
|
||||||
|
.min_by_key(|(_, rtt)| *rtt)
|
||||||
|
.map(|(name, _)| name);
|
||||||
|
|
||||||
|
// Port mapping availability.
|
||||||
|
let (port_mapping, nat_pmp_available, pcp_available, upnp_available) = match portmap_result {
|
||||||
|
Some(mapping) => {
|
||||||
|
let proto = mapping.protocol;
|
||||||
|
(
|
||||||
|
Some(proto),
|
||||||
|
proto == PortMapProtocol::NatPmp,
|
||||||
|
proto == PortMapProtocol::Pcp,
|
||||||
|
proto == PortMapProtocol::UPnP,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
None => (None, false, false, false),
|
||||||
|
};
|
||||||
|
|
||||||
|
let gateway = match gateway_result {
|
||||||
|
Ok(gw) => Some(gw.to_string()),
|
||||||
|
Err(_) => None,
|
||||||
|
};
|
||||||
|
|
||||||
|
NetcheckReport {
|
||||||
|
nat_type,
|
||||||
|
reflexive_addr: consensus_addr,
|
||||||
|
ipv4_reachable,
|
||||||
|
ipv6_reachable,
|
||||||
|
hairpin_works: None, // TODO: implement hairpin test
|
||||||
|
port_mapping,
|
||||||
|
relay_latencies,
|
||||||
|
preferred_relay,
|
||||||
|
stun_latency_ms,
|
||||||
|
upnp_available,
|
||||||
|
pcp_available,
|
||||||
|
nat_pmp_available,
|
||||||
|
gateway,
|
||||||
|
duration_ms: start.elapsed().as_millis() as u32,
|
||||||
|
stun_probes,
|
||||||
|
port_allocation: Some(port_alloc_result.allocation),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Probe relay latencies via reflect.
|
||||||
|
async fn probe_relays(relays: &[(String, SocketAddr)], timeout: Duration) -> Vec<RelayLatency> {
|
||||||
|
if relays.is_empty() {
|
||||||
|
return Vec::new();
|
||||||
|
}
|
||||||
|
|
||||||
|
let timeout_ms = timeout.as_millis() as u64;
|
||||||
|
let mut set = tokio::task::JoinSet::new();
|
||||||
|
|
||||||
|
for (name, addr) in relays {
|
||||||
|
let name = name.clone();
|
||||||
|
let addr = *addr;
|
||||||
|
set.spawn(async move {
|
||||||
|
let start = Instant::now();
|
||||||
|
match reflect::probe_reflect_addr(addr, timeout_ms, None).await {
|
||||||
|
Ok((_observed, _latency)) => RelayLatency {
|
||||||
|
name,
|
||||||
|
addr: addr.to_string(),
|
||||||
|
rtt_ms: Some(start.elapsed().as_millis() as u32),
|
||||||
|
error: None,
|
||||||
|
},
|
||||||
|
Err(e) => RelayLatency {
|
||||||
|
name,
|
||||||
|
addr: addr.to_string(),
|
||||||
|
rtt_ms: None,
|
||||||
|
error: Some(e),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut results = Vec::with_capacity(relays.len());
|
||||||
|
while let Some(join_result) = set.join_next().await {
|
||||||
|
match join_result {
|
||||||
|
Ok(r) => results.push(r),
|
||||||
|
Err(_) => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort by RTT (lowest first).
|
||||||
|
results.sort_by_key(|r| r.rtt_ms.unwrap_or(u32::MAX));
|
||||||
|
results
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Attempt port mapping and return the mapping if successful.
|
||||||
|
async fn probe_portmap(enabled: bool, local_port: u16) -> Option<portmap::PortMapping> {
|
||||||
|
if !enabled || local_port == 0 {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
portmap::acquire_port_mapping(local_port, None).await.ok()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Wrap the gateway future to handle the Result.
|
||||||
|
async fn gateway_result_fut(
|
||||||
|
fut: impl std::future::Future<Output = Result<std::net::Ipv4Addr, portmap::PortMapError>>,
|
||||||
|
) -> Result<std::net::Ipv4Addr, portmap::PortMapError> {
|
||||||
|
fut.await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Test IPv6 connectivity by attempting to bind and send on an IPv6 socket.
|
||||||
|
async fn test_ipv6(enabled: bool, timeout: Duration) -> bool {
|
||||||
|
if !enabled {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to resolve and connect to an IPv6 STUN server.
|
||||||
|
let result = tokio::time::timeout(timeout, async {
|
||||||
|
let sock = tokio::net::UdpSocket::bind("[::]:0").await.ok()?;
|
||||||
|
// Try Google's IPv6 STUN — if DNS resolves to an AAAA record
|
||||||
|
// and we can send a packet, IPv6 is working.
|
||||||
|
let addr = stun::resolve_stun_server("stun.l.google.com:19302")
|
||||||
|
.await
|
||||||
|
.ok()?;
|
||||||
|
if addr.is_ipv6() {
|
||||||
|
sock.send_to(&[0u8; 1], addr).await.ok()?;
|
||||||
|
Some(true)
|
||||||
|
} else {
|
||||||
|
// Server resolved to IPv4 — try binding to [::] at least
|
||||||
|
Some(false)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(Some(true)) => true,
|
||||||
|
_ => {
|
||||||
|
// Fallback: can we at least bind an IPv6 socket?
|
||||||
|
tokio::net::UdpSocket::bind("[::]:0").await.is_ok()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Format a netcheck report as a human-readable string.
|
||||||
|
pub fn format_report(report: &NetcheckReport) -> String {
|
||||||
|
let mut out = String::new();
|
||||||
|
|
||||||
|
out.push_str(&format!("=== WarzonePhone Netcheck ===\n\n"));
|
||||||
|
out.push_str(&format!("NAT Type: {:?}\n", report.nat_type));
|
||||||
|
out.push_str(&format!(
|
||||||
|
"Reflexive Addr: {}\n",
|
||||||
|
report.reflexive_addr.as_deref().unwrap_or("(unknown)")
|
||||||
|
));
|
||||||
|
out.push_str(&format!(
|
||||||
|
"IPv4: {}\n",
|
||||||
|
if report.ipv4_reachable { "yes" } else { "no" }
|
||||||
|
));
|
||||||
|
out.push_str(&format!(
|
||||||
|
"IPv6: {}\n",
|
||||||
|
if report.ipv6_reachable { "yes" } else { "no" }
|
||||||
|
));
|
||||||
|
out.push_str(&format!(
|
||||||
|
"Gateway: {}\n",
|
||||||
|
report.gateway.as_deref().unwrap_or("(unknown)")
|
||||||
|
));
|
||||||
|
|
||||||
|
if let Some(ref alloc) = report.port_allocation {
|
||||||
|
out.push_str(&format!("Port Alloc: {alloc}\n"));
|
||||||
|
}
|
||||||
|
|
||||||
|
out.push_str(&format!("\n--- Port Mapping ---\n"));
|
||||||
|
out.push_str(&format!(
|
||||||
|
"NAT-PMP: {} PCP: {} UPnP: {}\n",
|
||||||
|
if report.nat_pmp_available {
|
||||||
|
"yes"
|
||||||
|
} else {
|
||||||
|
"no"
|
||||||
|
},
|
||||||
|
if report.pcp_available { "yes" } else { "no" },
|
||||||
|
if report.upnp_available { "yes" } else { "no" },
|
||||||
|
));
|
||||||
|
if let Some(proto) = &report.port_mapping {
|
||||||
|
out.push_str(&format!("Active mapping: {:?}\n", proto));
|
||||||
|
}
|
||||||
|
|
||||||
|
if !report.stun_probes.is_empty() {
|
||||||
|
out.push_str(&format!("\n--- STUN Probes ---\n"));
|
||||||
|
for p in &report.stun_probes {
|
||||||
|
out.push_str(&format!(
|
||||||
|
" {} → {} ({}ms){}\n",
|
||||||
|
p.relay_name,
|
||||||
|
p.observed_addr.as_deref().unwrap_or("failed"),
|
||||||
|
p.latency_ms
|
||||||
|
.map(|ms| ms.to_string())
|
||||||
|
.unwrap_or_else(|| "-".into()),
|
||||||
|
p.error
|
||||||
|
.as_ref()
|
||||||
|
.map(|e| format!(" [{e}]"))
|
||||||
|
.unwrap_or_default(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !report.relay_latencies.is_empty() {
|
||||||
|
out.push_str(&format!("\n--- Relay Latencies ---\n"));
|
||||||
|
for r in &report.relay_latencies {
|
||||||
|
out.push_str(&format!(
|
||||||
|
" {} ({}) → {}ms{}\n",
|
||||||
|
r.name,
|
||||||
|
r.addr,
|
||||||
|
r.rtt_ms
|
||||||
|
.map(|ms| ms.to_string())
|
||||||
|
.unwrap_or_else(|| "-".into()),
|
||||||
|
r.error
|
||||||
|
.as_ref()
|
||||||
|
.map(|e| format!(" [{e}]"))
|
||||||
|
.unwrap_or_default(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
if let Some(ref pref) = report.preferred_relay {
|
||||||
|
out.push_str(&format!(" Preferred: {pref}\n"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
out.push_str(&format!("\nCompleted in {}ms\n", report.duration_ms));
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Tests ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn default_config_has_stun_servers() {
|
||||||
|
let config = NetcheckConfig::default();
|
||||||
|
assert!(!config.stun_config.servers.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn format_report_produces_output() {
|
||||||
|
let report = NetcheckReport {
|
||||||
|
nat_type: NatType::Cone,
|
||||||
|
reflexive_addr: Some("203.0.113.5:4433".into()),
|
||||||
|
ipv4_reachable: true,
|
||||||
|
ipv6_reachable: false,
|
||||||
|
hairpin_works: None,
|
||||||
|
port_mapping: None,
|
||||||
|
relay_latencies: vec![RelayLatency {
|
||||||
|
name: "relay-1".into(),
|
||||||
|
addr: "10.0.0.1:4433".into(),
|
||||||
|
rtt_ms: Some(25),
|
||||||
|
error: None,
|
||||||
|
}],
|
||||||
|
preferred_relay: Some("relay-1".into()),
|
||||||
|
stun_latency_ms: Some(15),
|
||||||
|
upnp_available: false,
|
||||||
|
pcp_available: false,
|
||||||
|
nat_pmp_available: false,
|
||||||
|
gateway: Some("192.168.1.1".into()),
|
||||||
|
duration_ms: 1500,
|
||||||
|
stun_probes: vec![],
|
||||||
|
port_allocation: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let text = format_report(&report);
|
||||||
|
assert!(text.contains("Cone"));
|
||||||
|
assert!(text.contains("203.0.113.5:4433"));
|
||||||
|
assert!(text.contains("relay-1"));
|
||||||
|
assert!(text.contains("1500ms"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn report_serializes_to_json() {
|
||||||
|
let report = NetcheckReport {
|
||||||
|
nat_type: NatType::Cone,
|
||||||
|
reflexive_addr: Some("203.0.113.5:4433".into()),
|
||||||
|
ipv4_reachable: true,
|
||||||
|
ipv6_reachable: false,
|
||||||
|
hairpin_works: None,
|
||||||
|
port_mapping: Some(PortMapProtocol::NatPmp),
|
||||||
|
relay_latencies: vec![],
|
||||||
|
preferred_relay: None,
|
||||||
|
stun_latency_ms: Some(25),
|
||||||
|
upnp_available: false,
|
||||||
|
pcp_available: false,
|
||||||
|
nat_pmp_available: true,
|
||||||
|
gateway: Some("192.168.1.1".into()),
|
||||||
|
duration_ms: 500,
|
||||||
|
stun_probes: vec![],
|
||||||
|
port_allocation: Some(stun::PortAllocation::Sequential { delta: 1 }),
|
||||||
|
};
|
||||||
|
let json = serde_json::to_string(&report).unwrap();
|
||||||
|
assert!(json.contains("Cone"));
|
||||||
|
assert!(json.contains("203.0.113.5:4433"));
|
||||||
|
assert!(json.contains("NatPmp"));
|
||||||
|
|
||||||
|
// Roundtrip
|
||||||
|
let decoded: serde_json::Value = serde_json::from_str(&json).unwrap();
|
||||||
|
assert_eq!(decoded["ipv4_reachable"], true);
|
||||||
|
assert_eq!(decoded["ipv6_reachable"], false);
|
||||||
|
assert_eq!(decoded["stun_latency_ms"], 25);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn relay_latency_serializes() {
|
||||||
|
let lat = RelayLatency {
|
||||||
|
name: "eu-west".into(),
|
||||||
|
addr: "10.0.0.1:4433".into(),
|
||||||
|
rtt_ms: Some(42),
|
||||||
|
error: None,
|
||||||
|
};
|
||||||
|
let json = serde_json::to_string(&lat).unwrap();
|
||||||
|
assert!(json.contains("eu-west"));
|
||||||
|
assert!(json.contains("42"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn format_report_empty_relays() {
|
||||||
|
let report = NetcheckReport {
|
||||||
|
nat_type: NatType::Unknown,
|
||||||
|
reflexive_addr: None,
|
||||||
|
ipv4_reachable: false,
|
||||||
|
ipv6_reachable: false,
|
||||||
|
hairpin_works: None,
|
||||||
|
port_mapping: None,
|
||||||
|
relay_latencies: vec![],
|
||||||
|
preferred_relay: None,
|
||||||
|
stun_latency_ms: None,
|
||||||
|
upnp_available: false,
|
||||||
|
pcp_available: false,
|
||||||
|
nat_pmp_available: false,
|
||||||
|
gateway: None,
|
||||||
|
duration_ms: 100,
|
||||||
|
stun_probes: vec![],
|
||||||
|
port_allocation: None,
|
||||||
|
};
|
||||||
|
let text = format_report(&report);
|
||||||
|
assert!(text.contains("Unknown"));
|
||||||
|
assert!(text.contains("(unknown)")); // reflexive addr
|
||||||
|
assert!(text.contains("100ms"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn format_report_with_stun_probes() {
|
||||||
|
let report = NetcheckReport {
|
||||||
|
nat_type: NatType::SymmetricPort,
|
||||||
|
reflexive_addr: None,
|
||||||
|
ipv4_reachable: true,
|
||||||
|
ipv6_reachable: true,
|
||||||
|
hairpin_works: Some(false),
|
||||||
|
port_mapping: Some(PortMapProtocol::UPnP),
|
||||||
|
relay_latencies: vec![
|
||||||
|
RelayLatency {
|
||||||
|
name: "us-east".into(),
|
||||||
|
addr: "10.0.0.1:4433".into(),
|
||||||
|
rtt_ms: Some(15),
|
||||||
|
error: None,
|
||||||
|
},
|
||||||
|
RelayLatency {
|
||||||
|
name: "eu-west".into(),
|
||||||
|
addr: "10.0.0.2:4433".into(),
|
||||||
|
rtt_ms: None,
|
||||||
|
error: Some("timeout".into()),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
preferred_relay: Some("us-east".into()),
|
||||||
|
stun_latency_ms: Some(20),
|
||||||
|
upnp_available: true,
|
||||||
|
pcp_available: false,
|
||||||
|
nat_pmp_available: false,
|
||||||
|
gateway: Some("192.168.0.1".into()),
|
||||||
|
duration_ms: 3000,
|
||||||
|
stun_probes: vec![reflect::NatProbeResult {
|
||||||
|
relay_name: "stun:google".into(),
|
||||||
|
relay_addr: "74.125.250.129:19302".into(),
|
||||||
|
observed_addr: Some("203.0.113.5:12345".into()),
|
||||||
|
latency_ms: Some(20),
|
||||||
|
error: None,
|
||||||
|
}],
|
||||||
|
port_allocation: Some(stun::PortAllocation::Random),
|
||||||
|
};
|
||||||
|
let text = format_report(&report);
|
||||||
|
assert!(text.contains("SymmetricPort"));
|
||||||
|
assert!(text.contains("us-east"));
|
||||||
|
assert!(text.contains("eu-west"));
|
||||||
|
assert!(text.contains("Preferred: us-east"));
|
||||||
|
assert!(text.contains("UPnP: yes"));
|
||||||
|
assert!(text.contains("stun:google"));
|
||||||
|
assert!(text.contains("3000ms"));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Integration test: run actual netcheck (requires network).
|
||||||
|
#[tokio::test]
|
||||||
|
#[ignore]
|
||||||
|
async fn integration_netcheck() {
|
||||||
|
let config = NetcheckConfig::default();
|
||||||
|
let report = run_netcheck(&config).await;
|
||||||
|
println!("{}", format_report(&report));
|
||||||
|
assert!(report.duration_ms > 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
1164
crates/wzp-client/src/portmap.rs
Normal file
1164
crates/wzp-client/src/portmap.rs
Normal file
File diff suppressed because it is too large
Load Diff
@@ -30,8 +30,8 @@ use std::net::SocketAddr;
|
|||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use wzp_proto::{MediaTransport, SignalMessage};
|
use wzp_proto::{MediaTransport, SignalMessage, default_signal_version};
|
||||||
use wzp_transport::{client_config, create_endpoint, QuinnTransport};
|
use wzp_transport::{QuinnTransport, client_config, create_endpoint};
|
||||||
|
|
||||||
/// Result of one probe against one relay. Always returned so the
|
/// Result of one probe against one relay. Always returned so the
|
||||||
/// UI can render per-relay status even when some fail.
|
/// UI can render per-relay status even when some fail.
|
||||||
@@ -110,8 +110,7 @@ pub async fn probe_reflect_addr(
|
|||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
let probe = async {
|
let probe = async {
|
||||||
// Open the signal connection.
|
// Open the signal connection.
|
||||||
let conn =
|
let conn = wzp_transport::connect(&endpoint, relay, "_signal", client_config())
|
||||||
wzp_transport::connect(&endpoint, relay, "_signal", client_config())
|
|
||||||
.await
|
.await
|
||||||
.map_err(|e| format!("connect: {e}"))?;
|
.map_err(|e| format!("connect: {e}"))?;
|
||||||
let transport = QuinnTransport::new(conn);
|
let transport = QuinnTransport::new(conn);
|
||||||
@@ -124,6 +123,7 @@ pub async fn probe_reflect_addr(
|
|||||||
// path does in desktop/src-tauri/src/lib.rs register_signal.
|
// path does in desktop/src-tauri/src/lib.rs register_signal.
|
||||||
transport
|
transport
|
||||||
.send_signal(&SignalMessage::RegisterPresence {
|
.send_signal(&SignalMessage::RegisterPresence {
|
||||||
|
version: default_signal_version(),
|
||||||
identity_pub: [0u8; 32],
|
identity_pub: [0u8; 32],
|
||||||
signature: vec![],
|
signature: vec![],
|
||||||
alias: None,
|
alias: None,
|
||||||
@@ -151,7 +151,7 @@ pub async fn probe_reflect_addr(
|
|||||||
.map_err(|e| format!("send Reflect: {e}"))?;
|
.map_err(|e| format!("send Reflect: {e}"))?;
|
||||||
|
|
||||||
match transport.recv_signal().await {
|
match transport.recv_signal().await {
|
||||||
Ok(Some(SignalMessage::ReflectResponse { observed_addr })) => {
|
Ok(Some(SignalMessage::ReflectResponse { observed_addr, .. })) => {
|
||||||
let parsed: SocketAddr = observed_addr
|
let parsed: SocketAddr = observed_addr
|
||||||
.parse()
|
.parse()
|
||||||
.map_err(|e| format!("parse observed_addr {observed_addr:?}: {e}"))?;
|
.map_err(|e| format!("parse observed_addr {observed_addr:?}: {e}"))?;
|
||||||
@@ -473,6 +473,40 @@ pub fn classify_nat(probes: &[NatProbeResult]) -> (NatType, Option<String>) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Enhanced NAT detection that combines relay-based reflection with
|
||||||
|
/// public STUN server probes for more robust classification.
|
||||||
|
///
|
||||||
|
/// Runs both probe sets concurrently:
|
||||||
|
/// 1. Relay probes via `detect_nat_type` (existing behavior)
|
||||||
|
/// 2. Public STUN probes via `probe_stun_servers`
|
||||||
|
///
|
||||||
|
/// Merges all results and classifies. More probes = higher confidence
|
||||||
|
/// in the NAT type classification. Falls back gracefully: if STUN
|
||||||
|
/// servers are unreachable, relay probes still work (and vice versa).
|
||||||
|
pub async fn detect_nat_type_with_stun(
|
||||||
|
relays: Vec<(String, SocketAddr)>,
|
||||||
|
timeout_ms: u64,
|
||||||
|
shared_endpoint: Option<wzp_transport::Endpoint>,
|
||||||
|
stun_config: &crate::stun::StunConfig,
|
||||||
|
) -> NatDetection {
|
||||||
|
// Run relay probes and STUN probes concurrently.
|
||||||
|
let relay_fut = detect_nat_type(relays, timeout_ms, shared_endpoint);
|
||||||
|
let stun_fut = crate::stun::probe_stun_servers(stun_config);
|
||||||
|
|
||||||
|
let (relay_detection, stun_probes) = tokio::join!(relay_fut, stun_fut);
|
||||||
|
|
||||||
|
// Merge all probes and re-classify.
|
||||||
|
let mut all_probes = relay_detection.probes;
|
||||||
|
all_probes.extend(stun_probes);
|
||||||
|
|
||||||
|
let (nat_type, consensus_addr) = classify_nat(&all_probes);
|
||||||
|
NatDetection {
|
||||||
|
probes: all_probes,
|
||||||
|
nat_type,
|
||||||
|
consensus_addr,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ── Unit tests for the pure classifier ───────────────────────────
|
// ── Unit tests for the pure classifier ───────────────────────────
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -506,10 +540,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn classify_two_identical_is_cone() {
|
fn classify_two_identical_is_cone() {
|
||||||
let probes = vec![
|
let probes = vec![mk(Some("192.0.2.1:4433")), mk(Some("192.0.2.1:4433"))];
|
||||||
mk(Some("192.0.2.1:4433")),
|
|
||||||
mk(Some("192.0.2.1:4433")),
|
|
||||||
];
|
|
||||||
let (nt, addr) = classify_nat(&probes);
|
let (nt, addr) = classify_nat(&probes);
|
||||||
assert_eq!(nt, NatType::Cone);
|
assert_eq!(nt, NatType::Cone);
|
||||||
assert_eq!(addr.as_deref(), Some("192.0.2.1:4433"));
|
assert_eq!(addr.as_deref(), Some("192.0.2.1:4433"));
|
||||||
@@ -517,10 +548,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn classify_same_ip_different_ports_is_symmetric() {
|
fn classify_same_ip_different_ports_is_symmetric() {
|
||||||
let probes = vec![
|
let probes = vec![mk(Some("192.0.2.1:4433")), mk(Some("192.0.2.1:51234"))];
|
||||||
mk(Some("192.0.2.1:4433")),
|
|
||||||
mk(Some("192.0.2.1:51234")),
|
|
||||||
];
|
|
||||||
let (nt, addr) = classify_nat(&probes);
|
let (nt, addr) = classify_nat(&probes);
|
||||||
assert_eq!(nt, NatType::SymmetricPort);
|
assert_eq!(nt, NatType::SymmetricPort);
|
||||||
assert!(addr.is_none());
|
assert!(addr.is_none());
|
||||||
@@ -528,10 +556,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn classify_different_ips_is_multiple() {
|
fn classify_different_ips_is_multiple() {
|
||||||
let probes = vec![
|
let probes = vec![mk(Some("192.0.2.1:4433")), mk(Some("198.51.100.9:4433"))];
|
||||||
mk(Some("192.0.2.1:4433")),
|
|
||||||
mk(Some("198.51.100.9:4433")),
|
|
||||||
];
|
|
||||||
let (nt, addr) = classify_nat(&probes);
|
let (nt, addr) = classify_nat(&probes);
|
||||||
assert_eq!(nt, NatType::Multiple);
|
assert_eq!(nt, NatType::Multiple);
|
||||||
assert!(addr.is_none());
|
assert!(addr.is_none());
|
||||||
|
|||||||
337
crates/wzp-client/src/relay_map.rs
Normal file
337
crates/wzp-client/src/relay_map.rs
Normal file
@@ -0,0 +1,337 @@
|
|||||||
|
//! Phase 8 (Tailscale-inspired): Relay map for automatic relay
|
||||||
|
//! selection based on latency.
|
||||||
|
//!
|
||||||
|
//! Maintains a sorted list of known relays with their measured
|
||||||
|
//! latencies. Used during call setup to pick the lowest-latency
|
||||||
|
//! relay, and by netcheck to report relay health.
|
||||||
|
|
||||||
|
use std::net::SocketAddr;
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
|
use serde::Serialize;
|
||||||
|
|
||||||
|
/// A known relay endpoint with measured latency.
|
||||||
|
#[derive(Debug, Clone, Serialize)]
|
||||||
|
pub struct RelayEntry {
|
||||||
|
/// Human-readable name (e.g., "us-east", "eu-west").
|
||||||
|
pub name: String,
|
||||||
|
/// Relay address.
|
||||||
|
pub addr: SocketAddr,
|
||||||
|
/// Geographic region (from RegisterPresenceAck).
|
||||||
|
pub region: Option<String>,
|
||||||
|
/// Last measured RTT (ms).
|
||||||
|
pub rtt_ms: Option<u32>,
|
||||||
|
/// When the RTT was last measured.
|
||||||
|
#[serde(skip)]
|
||||||
|
pub last_probed: Option<Instant>,
|
||||||
|
/// Whether this relay is currently reachable.
|
||||||
|
pub reachable: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sorted relay map. Entries are ordered by RTT (lowest first).
|
||||||
|
#[derive(Debug, Clone, Default)]
|
||||||
|
pub struct RelayMap {
|
||||||
|
entries: Vec<RelayEntry>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RelayMap {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
entries: Vec::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add or update a relay entry.
|
||||||
|
pub fn upsert(&mut self, name: &str, addr: SocketAddr, region: Option<String>) {
|
||||||
|
if let Some(entry) = self.entries.iter_mut().find(|e| e.addr == addr) {
|
||||||
|
entry.name = name.to_string();
|
||||||
|
if region.is_some() {
|
||||||
|
entry.region = region;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
self.entries.push(RelayEntry {
|
||||||
|
name: name.to_string(),
|
||||||
|
addr,
|
||||||
|
region,
|
||||||
|
rtt_ms: None,
|
||||||
|
last_probed: None,
|
||||||
|
reachable: false,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Update RTT measurement for a relay.
|
||||||
|
pub fn update_rtt(&mut self, addr: SocketAddr, rtt_ms: u32) {
|
||||||
|
if let Some(entry) = self.entries.iter_mut().find(|e| e.addr == addr) {
|
||||||
|
entry.rtt_ms = Some(rtt_ms);
|
||||||
|
entry.last_probed = Some(Instant::now());
|
||||||
|
entry.reachable = true;
|
||||||
|
}
|
||||||
|
self.sort();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Mark a relay as unreachable.
|
||||||
|
pub fn mark_unreachable(&mut self, addr: SocketAddr) {
|
||||||
|
if let Some(entry) = self.entries.iter_mut().find(|e| e.addr == addr) {
|
||||||
|
entry.reachable = false;
|
||||||
|
entry.last_probed = Some(Instant::now());
|
||||||
|
}
|
||||||
|
self.sort();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the preferred (lowest-latency, reachable) relay.
|
||||||
|
pub fn preferred(&self) -> Option<&RelayEntry> {
|
||||||
|
self.entries
|
||||||
|
.iter()
|
||||||
|
.find(|e| e.reachable && e.rtt_ms.is_some())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get all entries, sorted by RTT.
|
||||||
|
pub fn entries(&self) -> &[RelayEntry] {
|
||||||
|
&self.entries
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Populate from a `RegisterPresenceAck.available_relays` list.
|
||||||
|
/// Each entry is "name|addr" format.
|
||||||
|
pub fn populate_from_ack(&mut self, relays: &[String], relay_region: Option<&str>) {
|
||||||
|
for entry_str in relays {
|
||||||
|
if let Some((name, addr_str)) = entry_str.split_once('|') {
|
||||||
|
if let Ok(addr) = addr_str.parse::<SocketAddr>() {
|
||||||
|
self.upsert(name, addr, None);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// If the ack included a region for the current relay, we
|
||||||
|
// could tag it — but we'd need to know which relay we're
|
||||||
|
// connected to. Left for the caller to handle.
|
||||||
|
let _ = relay_region;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if any entry has a stale probe (older than `max_age`).
|
||||||
|
pub fn needs_reprobe(&self, max_age: Duration) -> bool {
|
||||||
|
self.entries.iter().any(|e| match e.last_probed {
|
||||||
|
None => true,
|
||||||
|
Some(t) => t.elapsed() > max_age,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get entries that need reprobing.
|
||||||
|
pub fn stale_entries(&self, max_age: Duration) -> Vec<(String, SocketAddr)> {
|
||||||
|
self.entries
|
||||||
|
.iter()
|
||||||
|
.filter(|e| match e.last_probed {
|
||||||
|
None => true,
|
||||||
|
Some(t) => t.elapsed() > max_age,
|
||||||
|
})
|
||||||
|
.map(|e| (e.name.clone(), e.addr))
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sort(&mut self) {
|
||||||
|
self.entries.sort_by_key(|e| {
|
||||||
|
if e.reachable {
|
||||||
|
e.rtt_ms.unwrap_or(u32::MAX)
|
||||||
|
} else {
|
||||||
|
u32::MAX
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Tests ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn preferred_returns_lowest_rtt() {
|
||||||
|
let mut map = RelayMap::new();
|
||||||
|
let a1: SocketAddr = "10.0.0.1:4433".parse().unwrap();
|
||||||
|
let a2: SocketAddr = "10.0.0.2:4433".parse().unwrap();
|
||||||
|
let a3: SocketAddr = "10.0.0.3:4433".parse().unwrap();
|
||||||
|
|
||||||
|
map.upsert("slow", a1, None);
|
||||||
|
map.upsert("fast", a2, None);
|
||||||
|
map.upsert("mid", a3, None);
|
||||||
|
|
||||||
|
map.update_rtt(a1, 200);
|
||||||
|
map.update_rtt(a2, 15);
|
||||||
|
map.update_rtt(a3, 80);
|
||||||
|
|
||||||
|
let pref = map.preferred().unwrap();
|
||||||
|
assert_eq!(pref.addr, a2);
|
||||||
|
assert_eq!(pref.rtt_ms, Some(15));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn unreachable_not_preferred() {
|
||||||
|
let mut map = RelayMap::new();
|
||||||
|
let a1: SocketAddr = "10.0.0.1:4433".parse().unwrap();
|
||||||
|
let a2: SocketAddr = "10.0.0.2:4433".parse().unwrap();
|
||||||
|
|
||||||
|
map.upsert("fast-dead", a1, None);
|
||||||
|
map.upsert("slow-alive", a2, None);
|
||||||
|
|
||||||
|
map.update_rtt(a1, 5);
|
||||||
|
map.update_rtt(a2, 200);
|
||||||
|
map.mark_unreachable(a1);
|
||||||
|
|
||||||
|
let pref = map.preferred().unwrap();
|
||||||
|
assert_eq!(pref.addr, a2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn populate_from_ack() {
|
||||||
|
let mut map = RelayMap::new();
|
||||||
|
map.populate_from_ack(
|
||||||
|
&[
|
||||||
|
"us-east|203.0.113.5:4433".into(),
|
||||||
|
"eu-west|198.51.100.9:4433".into(),
|
||||||
|
],
|
||||||
|
Some("us-east"),
|
||||||
|
);
|
||||||
|
assert_eq!(map.entries().len(), 2);
|
||||||
|
assert_eq!(map.entries()[0].name, "us-east");
|
||||||
|
assert_eq!(map.entries()[1].name, "eu-west");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn upsert_updates_existing() {
|
||||||
|
let mut map = RelayMap::new();
|
||||||
|
let addr: SocketAddr = "10.0.0.1:4433".parse().unwrap();
|
||||||
|
map.upsert("old-name", addr, None);
|
||||||
|
map.upsert("new-name", addr, Some("us-west".into()));
|
||||||
|
assert_eq!(map.entries().len(), 1);
|
||||||
|
assert_eq!(map.entries()[0].name, "new-name");
|
||||||
|
assert_eq!(map.entries()[0].region, Some("us-west".into()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn upsert_preserves_region_when_none() {
|
||||||
|
let mut map = RelayMap::new();
|
||||||
|
let addr: SocketAddr = "10.0.0.1:4433".parse().unwrap();
|
||||||
|
map.upsert("relay", addr, Some("eu-west".into()));
|
||||||
|
map.upsert("relay", addr, None); // region is None
|
||||||
|
// Should keep the original region
|
||||||
|
assert_eq!(map.entries()[0].region, Some("eu-west".into()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn preferred_returns_none_on_empty() {
|
||||||
|
let map = RelayMap::new();
|
||||||
|
assert!(map.preferred().is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn preferred_returns_none_when_all_unreachable() {
|
||||||
|
let mut map = RelayMap::new();
|
||||||
|
let addr: SocketAddr = "10.0.0.1:4433".parse().unwrap();
|
||||||
|
map.upsert("relay", addr, None);
|
||||||
|
// Not update_rtt'd, so reachable=false
|
||||||
|
assert!(map.preferred().is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn needs_reprobe_empty_is_false() {
|
||||||
|
let map = RelayMap::new();
|
||||||
|
// No entries → nothing to reprobe
|
||||||
|
assert!(!map.needs_reprobe(Duration::from_secs(60)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn needs_reprobe_never_probed() {
|
||||||
|
let mut map = RelayMap::new();
|
||||||
|
map.upsert("relay", "10.0.0.1:4433".parse().unwrap(), None);
|
||||||
|
assert!(map.needs_reprobe(Duration::from_secs(60)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn needs_reprobe_fresh_is_false() {
|
||||||
|
let mut map = RelayMap::new();
|
||||||
|
let addr: SocketAddr = "10.0.0.1:4433".parse().unwrap();
|
||||||
|
map.upsert("relay", addr, None);
|
||||||
|
map.update_rtt(addr, 50);
|
||||||
|
// Just probed, so 60s max_age should not trigger
|
||||||
|
assert!(!map.needs_reprobe(Duration::from_secs(60)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn stale_entries_returns_unprobed() {
|
||||||
|
let mut map = RelayMap::new();
|
||||||
|
let a1: SocketAddr = "10.0.0.1:4433".parse().unwrap();
|
||||||
|
let a2: SocketAddr = "10.0.0.2:4433".parse().unwrap();
|
||||||
|
map.upsert("probed", a1, None);
|
||||||
|
map.upsert("stale", a2, None);
|
||||||
|
map.update_rtt(a1, 50);
|
||||||
|
|
||||||
|
let stale = map.stale_entries(Duration::from_secs(60));
|
||||||
|
assert_eq!(stale.len(), 1);
|
||||||
|
assert_eq!(stale[0].1, a2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn sort_stability_with_equal_rtt() {
|
||||||
|
let mut map = RelayMap::new();
|
||||||
|
let a1: SocketAddr = "10.0.0.1:4433".parse().unwrap();
|
||||||
|
let a2: SocketAddr = "10.0.0.2:4433".parse().unwrap();
|
||||||
|
map.upsert("first", a1, None);
|
||||||
|
map.upsert("second", a2, None);
|
||||||
|
map.update_rtt(a1, 50);
|
||||||
|
map.update_rtt(a2, 50);
|
||||||
|
|
||||||
|
// Both have same RTT — sort should be stable (insertion order)
|
||||||
|
assert_eq!(map.entries().len(), 2);
|
||||||
|
// Both are valid preferred relays
|
||||||
|
assert!(map.preferred().is_some());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn populate_from_ack_skips_malformed() {
|
||||||
|
let mut map = RelayMap::new();
|
||||||
|
map.populate_from_ack(
|
||||||
|
&[
|
||||||
|
"good|10.0.0.1:4433".into(),
|
||||||
|
"no-pipe-separator".into(),
|
||||||
|
"bad-addr|not-a-socket-addr".into(),
|
||||||
|
"also-good|10.0.0.2:4433".into(),
|
||||||
|
],
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
assert_eq!(map.entries().len(), 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn mark_unreachable_sorts_to_end() {
|
||||||
|
let mut map = RelayMap::new();
|
||||||
|
let a1: SocketAddr = "10.0.0.1:4433".parse().unwrap();
|
||||||
|
let a2: SocketAddr = "10.0.0.2:4433".parse().unwrap();
|
||||||
|
map.upsert("fast", a1, None);
|
||||||
|
map.upsert("slow", a2, None);
|
||||||
|
map.update_rtt(a1, 10);
|
||||||
|
map.update_rtt(a2, 200);
|
||||||
|
|
||||||
|
assert_eq!(map.preferred().unwrap().addr, a1);
|
||||||
|
|
||||||
|
map.mark_unreachable(a1);
|
||||||
|
assert_eq!(map.preferred().unwrap().addr, a2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn relay_entry_serializes() {
|
||||||
|
let entry = RelayEntry {
|
||||||
|
name: "test".into(),
|
||||||
|
addr: "10.0.0.1:4433".parse().unwrap(),
|
||||||
|
region: Some("us-east".into()),
|
||||||
|
rtt_ms: Some(42),
|
||||||
|
last_probed: Some(Instant::now()),
|
||||||
|
reachable: true,
|
||||||
|
};
|
||||||
|
let json = serde_json::to_string(&entry).unwrap();
|
||||||
|
assert!(json.contains("test"));
|
||||||
|
assert!(json.contains("us-east"));
|
||||||
|
assert!(json.contains("42"));
|
||||||
|
// last_probed is #[serde(skip)]
|
||||||
|
assert!(!json.contains("last_probed"));
|
||||||
|
}
|
||||||
|
}
|
||||||
1445
crates/wzp-client/src/stun.rs
Normal file
1445
crates/wzp-client/src/stun.rs
Normal file
File diff suppressed because it is too large
Load Diff
@@ -72,8 +72,7 @@ fn sine_frame(freq_hz: f32, frame_offset: u64) -> Vec<i16> {
|
|||||||
/// decoder, pushes frames through the pipeline, and collects statistics.
|
/// decoder, pushes frames through the pipeline, and collects statistics.
|
||||||
/// Combinations where `target_depth > max_depth` are skipped.
|
/// Combinations where `target_depth > max_depth` are skipped.
|
||||||
pub fn run_local_sweep(config: &SweepConfig) -> Vec<SweepResult> {
|
pub fn run_local_sweep(config: &SweepConfig) -> Vec<SweepResult> {
|
||||||
let frames_per_config =
|
let frames_per_config = (config.test_duration_secs as u64) * (1000 / FRAME_DURATION_MS as u64);
|
||||||
(config.test_duration_secs as u64) * (1000 / FRAME_DURATION_MS as u64);
|
|
||||||
|
|
||||||
let mut results = Vec::new();
|
let mut results = Vec::new();
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@
|
|||||||
use std::net::{Ipv4Addr, SocketAddr};
|
use std::net::{Ipv4Addr, SocketAddr};
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
use wzp_client::dual_path::{race, PeerCandidates, WinningPath};
|
use wzp_client::dual_path::{PeerCandidates, WinningPath, race};
|
||||||
use wzp_client::reflect::Role;
|
use wzp_client::reflect::Role;
|
||||||
use wzp_transport::{create_endpoint, server_config};
|
use wzp_transport::{create_endpoint, server_config};
|
||||||
|
|
||||||
@@ -113,17 +113,27 @@ async fn dual_path_direct_wins_on_loopback() {
|
|||||||
PeerCandidates {
|
PeerCandidates {
|
||||||
reflexive: Some(acceptor_listen_addr),
|
reflexive: Some(acceptor_listen_addr),
|
||||||
local: Vec::new(),
|
local: Vec::new(),
|
||||||
|
mapped: None,
|
||||||
},
|
},
|
||||||
relay_addr,
|
relay_addr,
|
||||||
"test-room".into(),
|
"test-room".into(),
|
||||||
"call-test".into(),
|
"call-test".into(),
|
||||||
|
None, // own_reflexive: not needed in tests
|
||||||
None, // Phase 5: tests use fresh endpoints (no shared signal)
|
None, // Phase 5: tests use fresh endpoints (no shared signal)
|
||||||
|
None, // Phase 7: no IPv6 endpoint in tests
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.expect("race must succeed");
|
.expect("race must succeed");
|
||||||
|
|
||||||
assert!(result.direct_transport.is_some(), "direct transport should be available");
|
assert!(
|
||||||
assert_eq!(result.local_winner, WinningPath::Direct, "direct should win on loopback");
|
result.direct_transport.is_some(),
|
||||||
|
"direct transport should be available"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
result.local_winner,
|
||||||
|
WinningPath::Direct,
|
||||||
|
"direct should win on loopback"
|
||||||
|
);
|
||||||
|
|
||||||
// Cancel the acceptor accept task so the test finishes.
|
// Cancel the acceptor accept task so the test finishes.
|
||||||
acceptor_accept_task.abort();
|
acceptor_accept_task.abort();
|
||||||
@@ -155,16 +165,22 @@ async fn dual_path_relay_wins_when_direct_is_dead() {
|
|||||||
PeerCandidates {
|
PeerCandidates {
|
||||||
reflexive: Some(dead_peer),
|
reflexive: Some(dead_peer),
|
||||||
local: Vec::new(),
|
local: Vec::new(),
|
||||||
|
mapped: None,
|
||||||
},
|
},
|
||||||
relay_addr,
|
relay_addr,
|
||||||
"test-room".into(),
|
"test-room".into(),
|
||||||
"call-test".into(),
|
"call-test".into(),
|
||||||
|
None, // own_reflexive: not needed in tests
|
||||||
None, // Phase 5: tests use fresh endpoints (no shared signal)
|
None, // Phase 5: tests use fresh endpoints (no shared signal)
|
||||||
|
None, // Phase 7: no IPv6 endpoint in tests
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.expect("race must succeed via relay fallback");
|
.expect("race must succeed via relay fallback");
|
||||||
|
|
||||||
assert!(result.relay_transport.is_some(), "relay transport should be available");
|
assert!(
|
||||||
|
result.relay_transport.is_some(),
|
||||||
|
"relay transport should be available"
|
||||||
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
result.local_winner,
|
result.local_winner,
|
||||||
WinningPath::Relay,
|
WinningPath::Relay,
|
||||||
@@ -193,11 +209,14 @@ async fn dual_path_errors_cleanly_when_both_paths_dead() {
|
|||||||
PeerCandidates {
|
PeerCandidates {
|
||||||
reflexive: Some(dead_peer),
|
reflexive: Some(dead_peer),
|
||||||
local: Vec::new(),
|
local: Vec::new(),
|
||||||
|
mapped: None,
|
||||||
},
|
},
|
||||||
dead_relay,
|
dead_relay,
|
||||||
"test-room".into(),
|
"test-room".into(),
|
||||||
"call-test".into(),
|
"call-test".into(),
|
||||||
|
None, // own_reflexive: not needed in tests
|
||||||
None, // Phase 5: tests use fresh endpoints (no shared signal)
|
None, // Phase 5: tests use fresh endpoints (no shared signal)
|
||||||
|
None, // Phase 7: no IPv6 endpoint in tests
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
let elapsed = start.elapsed();
|
let elapsed = start.elapsed();
|
||||||
|
|||||||
@@ -6,12 +6,12 @@
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use tokio::sync::mpsc;
|
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
|
||||||
use wzp_proto::packet::MediaPacket;
|
use wzp_proto::packet::MediaPacket;
|
||||||
use wzp_proto::traits::{MediaTransport, PathQuality};
|
use wzp_proto::traits::{MediaTransport, PathQuality};
|
||||||
use wzp_proto::{SignalMessage, TransportError};
|
use wzp_proto::{SignalMessage, TransportError, default_signal_version};
|
||||||
|
|
||||||
/// A mock transport backed by two mpsc channels (one per direction).
|
/// A mock transport backed by two mpsc channels (one per direction).
|
||||||
///
|
///
|
||||||
@@ -83,11 +83,15 @@ async fn full_handshake_both_sides_derive_same_session() {
|
|||||||
|
|
||||||
// Run client and relay handshakes concurrently.
|
// Run client and relay handshakes concurrently.
|
||||||
let (client_result, relay_result) = tokio::join!(
|
let (client_result, relay_result) = tokio::join!(
|
||||||
wzp_client::handshake::perform_handshake(client_transport_clone.as_ref(), &client_seed, None),
|
wzp_client::handshake::perform_handshake(
|
||||||
|
client_transport_clone.as_ref(),
|
||||||
|
&client_seed,
|
||||||
|
None
|
||||||
|
),
|
||||||
wzp_relay::handshake::accept_handshake(relay_transport_clone.as_ref(), &relay_seed),
|
wzp_relay::handshake::accept_handshake(relay_transport_clone.as_ref(), &relay_seed),
|
||||||
);
|
);
|
||||||
|
|
||||||
let mut client_session = client_result.expect("client handshake should succeed");
|
let client_hs = client_result.expect("client handshake should succeed");
|
||||||
let (mut relay_session, chosen_profile, _caller_fp, _caller_alias) =
|
let (mut relay_session, chosen_profile, _caller_fp, _caller_alias) =
|
||||||
relay_result.expect("relay handshake should succeed");
|
relay_result.expect("relay handshake should succeed");
|
||||||
|
|
||||||
@@ -95,31 +99,53 @@ async fn full_handshake_both_sides_derive_same_session() {
|
|||||||
assert_eq!(chosen_profile, wzp_proto::QualityProfile::GOOD);
|
assert_eq!(chosen_profile, wzp_proto::QualityProfile::GOOD);
|
||||||
|
|
||||||
// Verify both sides can communicate: client encrypts, relay decrypts.
|
// Verify both sides can communicate: client encrypts, relay decrypts.
|
||||||
let header = b"test-header";
|
// encrypt/decrypt derive nonces from MediaHeader.seq, so we need valid headers.
|
||||||
|
use wzp_proto::packet::MediaHeader;
|
||||||
|
use wzp_proto::{CodecId, MediaType};
|
||||||
|
let make_hdr = |seq: u32| {
|
||||||
|
let h = MediaHeader {
|
||||||
|
version: 2,
|
||||||
|
flags: 0,
|
||||||
|
media_type: MediaType::Audio,
|
||||||
|
codec_id: CodecId::Opus24k,
|
||||||
|
stream_id: 0,
|
||||||
|
fec_ratio: 0,
|
||||||
|
seq,
|
||||||
|
timestamp: seq.wrapping_mul(20),
|
||||||
|
fec_block: 0,
|
||||||
|
};
|
||||||
|
let mut b = Vec::new();
|
||||||
|
h.write_to(&mut b);
|
||||||
|
b
|
||||||
|
};
|
||||||
|
|
||||||
|
let header = make_hdr(0);
|
||||||
let plaintext = b"hello from client to relay";
|
let plaintext = b"hello from client to relay";
|
||||||
|
|
||||||
|
let mut client_session = client_hs.session;
|
||||||
let mut ciphertext = Vec::new();
|
let mut ciphertext = Vec::new();
|
||||||
client_session
|
client_session
|
||||||
.encrypt(header, plaintext, &mut ciphertext)
|
.encrypt(&header, plaintext, &mut ciphertext)
|
||||||
.expect("client encrypt should succeed");
|
.expect("client encrypt should succeed");
|
||||||
|
|
||||||
let mut decrypted = Vec::new();
|
let mut decrypted = Vec::new();
|
||||||
relay_session
|
relay_session
|
||||||
.decrypt(header, &ciphertext, &mut decrypted)
|
.decrypt(&header, &ciphertext, &mut decrypted)
|
||||||
.expect("relay decrypt should succeed");
|
.expect("relay decrypt should succeed");
|
||||||
|
|
||||||
assert_eq!(&decrypted[..], plaintext);
|
assert_eq!(&decrypted[..], plaintext);
|
||||||
|
|
||||||
// Verify reverse direction: relay encrypts, client decrypts.
|
// Verify reverse direction: relay encrypts, client decrypts.
|
||||||
|
let header2 = make_hdr(0); // relay's send_seq starts at 0
|
||||||
let plaintext2 = b"hello from relay to client";
|
let plaintext2 = b"hello from relay to client";
|
||||||
let mut ciphertext2 = Vec::new();
|
let mut ciphertext2 = Vec::new();
|
||||||
relay_session
|
relay_session
|
||||||
.encrypt(header, plaintext2, &mut ciphertext2)
|
.encrypt(&header2, plaintext2, &mut ciphertext2)
|
||||||
.expect("relay encrypt should succeed");
|
.expect("relay encrypt should succeed");
|
||||||
|
|
||||||
let mut decrypted2 = Vec::new();
|
let mut decrypted2 = Vec::new();
|
||||||
client_session
|
client_session
|
||||||
.decrypt(header, &ciphertext2, &mut decrypted2)
|
.decrypt(&header2, &ciphertext2, &mut decrypted2)
|
||||||
.expect("client decrypt should succeed");
|
.expect("client decrypt should succeed");
|
||||||
|
|
||||||
assert_eq!(&decrypted2[..], plaintext2);
|
assert_eq!(&decrypted2[..], plaintext2);
|
||||||
@@ -147,11 +173,15 @@ async fn handshake_rejects_tampered_signature() {
|
|||||||
let bad_signature = kx.sign(b"wrong-data-intentionally");
|
let bad_signature = kx.sign(b"wrong-data-intentionally");
|
||||||
|
|
||||||
let offer = SignalMessage::CallOffer {
|
let offer = SignalMessage::CallOffer {
|
||||||
|
version: default_signal_version(),
|
||||||
identity_pub,
|
identity_pub,
|
||||||
ephemeral_pub,
|
ephemeral_pub,
|
||||||
signature: bad_signature,
|
signature: bad_signature,
|
||||||
supported_profiles: vec![wzp_proto::QualityProfile::GOOD],
|
supported_profiles: vec![wzp_proto::QualityProfile::GOOD],
|
||||||
alias: None,
|
alias: None,
|
||||||
|
protocol_version: 2,
|
||||||
|
supported_versions: vec![2],
|
||||||
|
video_codecs: vec![],
|
||||||
};
|
};
|
||||||
client_transport_clone
|
client_transport_clone
|
||||||
.send_signal(&offer)
|
.send_signal(&offer)
|
||||||
@@ -175,3 +205,42 @@ async fn handshake_rejects_tampered_signature() {
|
|||||||
Ok(_) => panic!("relay should reject tampered signature"),
|
Ok(_) => panic!("relay should reject tampered signature"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn client_receives_protocol_version_mismatch() {
|
||||||
|
let (client_transport, relay_transport) = MockTransport::pair();
|
||||||
|
|
||||||
|
let client_seed = [0xAA_u8; 32];
|
||||||
|
|
||||||
|
// Spawn a fake relay that sends ProtocolVersionMismatch.
|
||||||
|
let relay_clone = Arc::clone(&relay_transport);
|
||||||
|
tokio::spawn(async move {
|
||||||
|
// Wait for the client's CallOffer.
|
||||||
|
let offer = relay_clone.recv_signal().await.unwrap().unwrap();
|
||||||
|
assert!(matches!(offer, SignalMessage::CallOffer { .. }));
|
||||||
|
|
||||||
|
// Respond with ProtocolVersionMismatch.
|
||||||
|
let mismatch = SignalMessage::Hangup {
|
||||||
|
version: default_signal_version(),
|
||||||
|
reason: wzp_proto::HangupReason::ProtocolVersionMismatch {
|
||||||
|
server_supported: vec![3],
|
||||||
|
},
|
||||||
|
call_id: None,
|
||||||
|
};
|
||||||
|
relay_clone.send_signal(&mismatch).await.unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
let result =
|
||||||
|
wzp_client::handshake::perform_handshake(client_transport.as_ref(), &client_seed, None)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Err(wzp_client::handshake::HandshakeError::ProtocolVersionMismatch {
|
||||||
|
server_supported,
|
||||||
|
}) => {
|
||||||
|
assert_eq!(server_supported, vec![3]);
|
||||||
|
}
|
||||||
|
Err(other) => panic!("expected ProtocolVersionMismatch, got: {other:?}"),
|
||||||
|
Ok(_) => panic!("expected handshake to fail with ProtocolVersionMismatch"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -83,8 +83,12 @@ fn long_session_no_drift() {
|
|||||||
println!(
|
println!(
|
||||||
"long_session_no_drift: decoded={frames_decoded}/{TOTAL_FRAMES}, \
|
"long_session_no_drift: decoded={frames_decoded}/{TOTAL_FRAMES}, \
|
||||||
underruns={}, overruns={}, depth={}, max_depth={}, late={}, lost={}",
|
underruns={}, overruns={}, depth={}, max_depth={}, late={}, lost={}",
|
||||||
stats.underruns, stats.overruns, stats.current_depth, stats.max_depth_seen,
|
stats.underruns,
|
||||||
stats.packets_late, stats.packets_lost,
|
stats.overruns,
|
||||||
|
stats.current_depth,
|
||||||
|
stats.max_depth_seen,
|
||||||
|
stats.packets_late,
|
||||||
|
stats.packets_lost,
|
||||||
);
|
);
|
||||||
|
|
||||||
// With 1 decode per tick over 3000 ticks, we expect ~3000 decoded frames
|
// With 1 decode per tick over 3000 ticks, we expect ~3000 decoded frames
|
||||||
@@ -123,7 +127,7 @@ fn long_session_with_simulated_loss() {
|
|||||||
|
|
||||||
for (j, pkt) in batch.into_iter().enumerate() {
|
for (j, pkt) in batch.into_iter().enumerate() {
|
||||||
// Drop every 20th *source* (non-repair) packet to simulate ~5% loss.
|
// Drop every 20th *source* (non-repair) packet to simulate ~5% loss.
|
||||||
if !pkt.header.is_repair && i % 20 == 0 && j == 0 {
|
if !pkt.header.is_repair() && i % 20 == 0 && j == 0 {
|
||||||
continue; // drop this packet
|
continue; // drop this packet
|
||||||
}
|
}
|
||||||
decoder.ingest(pkt);
|
decoder.ingest(pkt);
|
||||||
@@ -139,8 +143,12 @@ fn long_session_with_simulated_loss() {
|
|||||||
println!(
|
println!(
|
||||||
"long_session_with_simulated_loss: decoded={frames_decoded}/{TOTAL_FRAMES}, \
|
"long_session_with_simulated_loss: decoded={frames_decoded}/{TOTAL_FRAMES}, \
|
||||||
underruns={}, overruns={}, depth={}, max_depth={}, late={}, lost={}",
|
underruns={}, overruns={}, depth={}, max_depth={}, late={}, lost={}",
|
||||||
stats.underruns, stats.overruns, stats.current_depth, stats.max_depth_seen,
|
stats.underruns,
|
||||||
stats.packets_late, stats.packets_lost,
|
stats.overruns,
|
||||||
|
stats.current_depth,
|
||||||
|
stats.max_depth_seen,
|
||||||
|
stats.packets_late,
|
||||||
|
stats.packets_lost,
|
||||||
);
|
);
|
||||||
|
|
||||||
// With 5% artificial loss + FEC recovery + PLC, we should still get >90% decoded.
|
// With 5% artificial loss + FEC recovery + PLC, we should still get >90% decoded.
|
||||||
@@ -150,6 +158,65 @@ fn long_session_with_simulated_loss() {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Verify that `MediaHeader::timestamp` continues monotonically across
|
||||||
|
/// rekey boundaries. Rekey is a crypto-layer operation (key material
|
||||||
|
/// rotation) and must not reset or interfere with framing state.
|
||||||
|
///
|
||||||
|
/// We simulate a 3000-frame session with two conceptual rekeys at frames
|
||||||
|
/// 1000 and 2000. The encoder's timestamp counter must advance
|
||||||
|
/// monotonically throughout.
|
||||||
|
#[test]
|
||||||
|
fn rekey_timestamp_monotonic() {
|
||||||
|
let config = test_config();
|
||||||
|
let mut encoder = CallEncoder::new(&config);
|
||||||
|
|
||||||
|
let mut timestamps = Vec::new();
|
||||||
|
|
||||||
|
// Phase 1: before first rekey
|
||||||
|
for i in 0..1000 {
|
||||||
|
let pcm = sine_frame(i);
|
||||||
|
let packets = encoder.encode_frame(&pcm).expect("encode");
|
||||||
|
for pkt in packets {
|
||||||
|
timestamps.push(pkt.header.timestamp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Phase 2: between first and second rekey
|
||||||
|
for i in 1000..2000 {
|
||||||
|
let pcm = sine_frame(i);
|
||||||
|
let packets = encoder.encode_frame(&pcm).expect("encode");
|
||||||
|
for pkt in packets {
|
||||||
|
timestamps.push(pkt.header.timestamp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Phase 3: after second rekey
|
||||||
|
for i in 2000..3000 {
|
||||||
|
let pcm = sine_frame(i);
|
||||||
|
let packets = encoder.encode_frame(&pcm).expect("encode");
|
||||||
|
for pkt in packets {
|
||||||
|
timestamps.push(pkt.header.timestamp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assert strict monotonicity (non-decreasing) across all three phases.
|
||||||
|
for window in timestamps.windows(2) {
|
||||||
|
assert!(
|
||||||
|
window[1] >= window[0],
|
||||||
|
"timestamp not monotonic across rekey boundary: {} -> {}",
|
||||||
|
window[0],
|
||||||
|
window[1]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sanity: we should have collected at least 3000 timestamps.
|
||||||
|
assert!(
|
||||||
|
timestamps.len() >= 3000,
|
||||||
|
"expected >= 3000 timestamps, got {}",
|
||||||
|
timestamps.len()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
/// Verify that the jitter buffer's decoded-frame count is consistent with its
|
/// Verify that the jitter buffer's decoded-frame count is consistent with its
|
||||||
/// own internal statistics over a long session.
|
/// own internal statistics over a long session.
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
@@ -116,6 +116,14 @@ impl AudioEncoder for AdaptiveEncoder {
|
|||||||
fn set_dtx(&mut self, enabled: bool) {
|
fn set_dtx(&mut self, enabled: bool) {
|
||||||
self.opus.set_dtx(enabled);
|
self.opus.set_dtx(enabled);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn set_expected_loss(&mut self, loss_pct: u8) {
|
||||||
|
self.opus.set_expected_loss(loss_pct);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_dred_duration(&mut self, frames: u8) {
|
||||||
|
self.opus.set_dred_duration(frames);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ─── AdaptiveDecoder ─────────────────────────────────────────────────────────
|
// ─── AdaptiveDecoder ─────────────────────────────────────────────────────────
|
||||||
|
|||||||
@@ -114,11 +114,7 @@ impl EchoCanceller {
|
|||||||
/// Number of delayed samples available to release.
|
/// Number of delayed samples available to release.
|
||||||
fn delay_available(&self) -> usize {
|
fn delay_available(&self) -> usize {
|
||||||
let buffered = self.delay_write - self.delay_read;
|
let buffered = self.delay_write - self.delay_read;
|
||||||
if buffered > self.delay_samples {
|
buffered.saturating_sub(self.delay_samples)
|
||||||
buffered - self.delay_samples
|
|
||||||
} else {
|
|
||||||
0
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Process a near-end (microphone) frame, removing the estimated echo.
|
/// Process a near-end (microphone) frame, removing the estimated echo.
|
||||||
@@ -161,8 +157,8 @@ impl EchoCanceller {
|
|||||||
let mut sum_near_sq: f64 = 0.0;
|
let mut sum_near_sq: f64 = 0.0;
|
||||||
let mut sum_err_sq: f64 = 0.0;
|
let mut sum_err_sq: f64 = 0.0;
|
||||||
|
|
||||||
for i in 0..n {
|
for (i, sample) in nearend.iter_mut().enumerate() {
|
||||||
let near_f = nearend[i] as f32;
|
let near_f = *sample as f32;
|
||||||
|
|
||||||
// Position of far-end "now" for this near-end sample.
|
// Position of far-end "now" for this near-end sample.
|
||||||
let base = (self.far_pos + fl * ((n / fl) + 2) + i - n) % fl;
|
let base = (self.far_pos + fl * ((n / fl) + 2) + i - n) % fl;
|
||||||
@@ -190,7 +186,7 @@ impl EchoCanceller {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let out = error.clamp(-32768.0, 32767.0);
|
let out = error.clamp(-32768.0, 32767.0);
|
||||||
nearend[i] = out as i16;
|
*sample = out as i16;
|
||||||
|
|
||||||
sum_near_sq += (near_f as f64).powi(2);
|
sum_near_sq += (near_f as f64).powi(2);
|
||||||
sum_err_sq += (out as f64).powi(2);
|
sum_err_sq += (out as f64).powi(2);
|
||||||
@@ -325,7 +321,10 @@ mod tests {
|
|||||||
// Feed 960 samples (= delay amount). No samples released yet.
|
// Feed 960 samples (= delay amount). No samples released yet.
|
||||||
aec.feed_farend(&vec![1i16; 960]);
|
aec.feed_farend(&vec![1i16; 960]);
|
||||||
// far_buf should still be all zeros (nothing released).
|
// far_buf should still be all zeros (nothing released).
|
||||||
assert!(aec.far_buf.iter().all(|&s| s == 0.0), "nothing should be released yet");
|
assert!(
|
||||||
|
aec.far_buf.iter().all(|&s| s == 0.0),
|
||||||
|
"nothing should be released yet"
|
||||||
|
);
|
||||||
|
|
||||||
// Feed 480 more. 480 should be released to far_buf.
|
// Feed 480 more. 480 should be released to far_buf.
|
||||||
aec.feed_farend(&vec![2i16; 480]);
|
aec.feed_farend(&vec![2i16; 480]);
|
||||||
|
|||||||
@@ -211,9 +211,6 @@ mod tests {
|
|||||||
fn agc_gain_db_at_unity() {
|
fn agc_gain_db_at_unity() {
|
||||||
let agc = AutoGainControl::new();
|
let agc = AutoGainControl::new();
|
||||||
let db = agc.current_gain_db();
|
let db = agc.current_gain_db();
|
||||||
assert!(
|
assert!(db.abs() < 0.01, "expected ~0 dB at unity gain, got {db}");
|
||||||
db.abs() < 0.01,
|
|
||||||
"expected ~0 dB at unity gain, got {db}"
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ impl Codec2Decoder {
|
|||||||
|
|
||||||
/// Number of compressed bytes per frame.
|
/// Number of compressed bytes per frame.
|
||||||
fn bytes_per_frame(&self) -> usize {
|
fn bytes_per_frame(&self) -> usize {
|
||||||
(self.inner.bits_per_frame() + 7) / 8
|
self.inner.bits_per_frame().div_ceil(8)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ impl Codec2Encoder {
|
|||||||
|
|
||||||
/// Number of compressed bytes per frame.
|
/// Number of compressed bytes per frame.
|
||||||
fn bytes_per_frame(&self) -> usize {
|
fn bytes_per_frame(&self) -> usize {
|
||||||
(self.inner.bits_per_frame() + 7) / 8
|
self.inner.bits_per_frame().div_ceil(8)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ impl NoiseSupressor {
|
|||||||
|
|
||||||
// f32 → i16 with clamping
|
// f32 → i16 with clamping
|
||||||
for (i, &val) in output.iter().enumerate() {
|
for (i, &val) in output.iter().enumerate() {
|
||||||
let clamped = val.max(-32768.0).min(32767.0);
|
let clamped = val.clamp(-32768.0, 32767.0);
|
||||||
pcm[offset + i] = clamped as i16;
|
pcm[offset + i] = clamped as i16;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -99,7 +99,11 @@ mod tests {
|
|||||||
}
|
}
|
||||||
let original_len = pcm.len();
|
let original_len = pcm.len();
|
||||||
ns.process(&mut pcm);
|
ns.process(&mut pcm);
|
||||||
assert_eq!(pcm.len(), original_len, "output length must match input length");
|
assert_eq!(
|
||||||
|
pcm.len(),
|
||||||
|
original_len,
|
||||||
|
"output length must match input length"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
@@ -71,9 +71,8 @@ impl DecoderHandle {
|
|||||||
"opus_decoder_create failed: err={error}"
|
"opus_decoder_create failed: err={error}"
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
let inner = NonNull::new(ptr).ok_or_else(|| {
|
let inner = NonNull::new(ptr)
|
||||||
CodecError::DecodeFailed("opus_decoder_create returned null".into())
|
.ok_or_else(|| CodecError::DecodeFailed("opus_decoder_create returned null".into()))?;
|
||||||
})?;
|
|
||||||
Ok(Self { inner })
|
Ok(Self { inner })
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -257,11 +256,7 @@ impl DredDecoderHandle {
|
|||||||
/// The `dred_end` output is the silence gap at the tail of the DRED
|
/// The `dred_end` output is the silence gap at the tail of the DRED
|
||||||
/// window; we subtract it from the total offset to give callers the
|
/// window; we subtract it from the total offset to give callers the
|
||||||
/// truly usable sample count.
|
/// truly usable sample count.
|
||||||
pub fn parse_into(
|
pub fn parse_into(&mut self, state: &mut DredState, packet: &[u8]) -> Result<i32, CodecError> {
|
||||||
&mut self,
|
|
||||||
state: &mut DredState,
|
|
||||||
packet: &[u8],
|
|
||||||
) -> Result<i32, CodecError> {
|
|
||||||
if packet.is_empty() {
|
if packet.is_empty() {
|
||||||
state.samples_available = 0;
|
state.samples_available = 0;
|
||||||
return Ok(0);
|
return Ok(0);
|
||||||
@@ -545,7 +540,10 @@ mod tests {
|
|||||||
// to our sine wave because we fed a cold decoder only one warmup
|
// to our sine wave because we fed a cold decoder only one warmup
|
||||||
// frame, but it should still produce non-silent speech-like output
|
// frame, but it should still produce non-silent speech-like output
|
||||||
// since the DRED state was parsed from real speech content.
|
// since the DRED state was parsed from real speech content.
|
||||||
let energy: u64 = recon_pcm.iter().map(|&s| (s as i32).unsigned_abs() as u64).sum();
|
let energy: u64 = recon_pcm
|
||||||
|
.iter()
|
||||||
|
.map(|&s| (s as i32).unsigned_abs() as u64)
|
||||||
|
.sum();
|
||||||
assert!(
|
assert!(
|
||||||
energy > 0,
|
energy > 0,
|
||||||
"reconstructed audio has zero total energy — DRED reconstruction produced silence"
|
"reconstructed audio has zero total energy — DRED reconstruction produced silence"
|
||||||
|
|||||||
@@ -53,10 +53,7 @@ pub fn set_dred_verbose_logs(enabled: bool) {
|
|||||||
/// The returned encoder accepts 48 kHz mono PCM regardless of the active
|
/// The returned encoder accepts 48 kHz mono PCM regardless of the active
|
||||||
/// codec; resampling is handled internally when Codec2 is selected.
|
/// codec; resampling is handled internally when Codec2 is selected.
|
||||||
pub fn create_encoder(profile: QualityProfile) -> Box<dyn AudioEncoder> {
|
pub fn create_encoder(profile: QualityProfile) -> Box<dyn AudioEncoder> {
|
||||||
Box::new(
|
Box::new(AdaptiveEncoder::new(profile).expect("failed to create adaptive encoder"))
|
||||||
AdaptiveEncoder::new(profile)
|
|
||||||
.expect("failed to create adaptive encoder"),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create an adaptive decoder starting at the given quality profile.
|
/// Create an adaptive decoder starting at the given quality profile.
|
||||||
@@ -64,10 +61,7 @@ pub fn create_encoder(profile: QualityProfile) -> Box<dyn AudioEncoder> {
|
|||||||
/// The returned decoder always produces 48 kHz mono PCM; upsampling from
|
/// The returned decoder always produces 48 kHz mono PCM; upsampling from
|
||||||
/// Codec2's native 8 kHz is handled internally.
|
/// Codec2's native 8 kHz is handled internally.
|
||||||
pub fn create_decoder(profile: QualityProfile) -> Box<dyn AudioDecoder> {
|
pub fn create_decoder(profile: QualityProfile) -> Box<dyn AudioDecoder> {
|
||||||
Box::new(
|
Box::new(AdaptiveDecoder::new(profile).expect("failed to create adaptive decoder"))
|
||||||
AdaptiveDecoder::new(profile)
|
|
||||||
.expect("failed to create adaptive decoder"),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -82,6 +76,10 @@ mod codec2_tests {
|
|||||||
fec_ratio: 0.5,
|
fec_ratio: 0.5,
|
||||||
frame_duration_ms: 20,
|
frame_duration_ms: 20,
|
||||||
frames_per_block: 5,
|
frames_per_block: 5,
|
||||||
|
priority_mode: wzp_proto::PriorityMode::AudioFirst,
|
||||||
|
video_bitrate_kbps: None,
|
||||||
|
video_resolution: None,
|
||||||
|
video_fps: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -210,7 +208,10 @@ mod codec2_tests {
|
|||||||
|
|
||||||
let mut pcm_out_c2 = vec![0i16; 1920];
|
let mut pcm_out_c2 = vec![0i16; 1920];
|
||||||
let samples_c2 = dec.decode(&encoded_c2[..n_c2], &mut pcm_out_c2).unwrap();
|
let samples_c2 = dec.decode(&encoded_c2[..n_c2], &mut pcm_out_c2).unwrap();
|
||||||
assert_eq!(samples_c2, 1920, "should get 1920 samples at 48kHz after upsample");
|
assert_eq!(
|
||||||
|
samples_c2, 1920,
|
||||||
|
"should get 1920 samples at 48kHz after upsample"
|
||||||
|
);
|
||||||
|
|
||||||
// Step 3: Switch back to Opus.
|
// Step 3: Switch back to Opus.
|
||||||
enc.set_profile(QualityProfile::GOOD).unwrap();
|
enc.set_profile(QualityProfile::GOOD).unwrap();
|
||||||
|
|||||||
@@ -14,8 +14,9 @@
|
|||||||
//! networks; short window keeps decoder CPU modest.
|
//! networks; short window keeps decoder CPU modest.
|
||||||
//! - Normal tiers (Opus 16k/24k): 200 ms — balanced baseline covering common
|
//! - Normal tiers (Opus 16k/24k): 200 ms — balanced baseline covering common
|
||||||
//! VoIP loss patterns (20–150 ms bursts from wifi roam, transient congestion).
|
//! VoIP loss patterns (20–150 ms bursts from wifi roam, transient congestion).
|
||||||
//! - Degraded tier (Opus 6k): 500 ms — users on 6k are by definition on a
|
//! - Degraded tier (Opus 6k): 1040 ms — users on 6k are by definition on a
|
||||||
//! bad link; longer DRED buys maximum burst resilience where it matters.
|
//! bad link; the maximum libopus DRED window buys the best burst resilience
|
||||||
|
//! where it matters. The RDO-VAE naturally degrades quality at longer offsets.
|
||||||
//!
|
//!
|
||||||
//! # Why the 15% packet loss floor
|
//! # Why the 15% packet loss floor
|
||||||
//!
|
//!
|
||||||
@@ -78,10 +79,19 @@ pub fn dred_duration_for(codec: CodecId) -> u8 {
|
|||||||
CodecId::Opus32k | CodecId::Opus48k | CodecId::Opus64k => 10,
|
CodecId::Opus32k | CodecId::Opus48k | CodecId::Opus64k => 10,
|
||||||
// Normal tiers — balanced baseline.
|
// Normal tiers — balanced baseline.
|
||||||
CodecId::Opus16k | CodecId::Opus24k => 20,
|
CodecId::Opus16k | CodecId::Opus24k => 20,
|
||||||
// Degraded tier — maximum burst resilience.
|
// Degraded tier — maximum burst resilience. 104 × 10 ms = 1040 ms,
|
||||||
CodecId::Opus6k => 50,
|
// the highest value libopus 1.5 supports. Users on 6k are on a bad
|
||||||
// Non-Opus (Codec2 / CN): DRED is N/A.
|
// link by definition; the RDO-VAE naturally degrades quality at longer
|
||||||
CodecId::Codec2_1200 | CodecId::Codec2_3200 | CodecId::ComfortNoise => 0,
|
// offsets, so the extra window costs only ~1-2 kbps additional overhead
|
||||||
|
// while buying substantially better burst resilience (up from 500 ms).
|
||||||
|
CodecId::Opus6k => 104,
|
||||||
|
// Non-Opus (Codec2 / CN / video): DRED is N/A.
|
||||||
|
CodecId::Codec2_1200
|
||||||
|
| CodecId::Codec2_3200
|
||||||
|
| CodecId::ComfortNoise
|
||||||
|
| CodecId::H264Baseline
|
||||||
|
| CodecId::H265Main
|
||||||
|
| CodecId::Av1Main => 0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -91,7 +101,7 @@ pub fn dred_duration_for(codec: CodecId) -> u8 {
|
|||||||
/// mode; unset or empty leaves DRED enabled.
|
/// mode; unset or empty leaves DRED enabled.
|
||||||
fn read_legacy_fec_env() -> bool {
|
fn read_legacy_fec_env() -> bool {
|
||||||
match std::env::var(LEGACY_FEC_ENV) {
|
match std::env::var(LEGACY_FEC_ENV) {
|
||||||
Ok(v) => !v.is_empty() && v != "0" && v.to_ascii_lowercase() != "false",
|
Ok(v) => !v.is_empty() && v != "0" && !v.eq_ignore_ascii_case("false"),
|
||||||
Err(_) => false,
|
Err(_) => false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -242,7 +252,7 @@ impl OpusEncoder {
|
|||||||
let clamped = if self.legacy_fec_mode {
|
let clamped = if self.legacy_fec_mode {
|
||||||
loss_pct.min(100)
|
loss_pct.min(100)
|
||||||
} else {
|
} else {
|
||||||
loss_pct.max(DRED_LOSS_FLOOR_PCT).min(100)
|
loss_pct.clamp(DRED_LOSS_FLOOR_PCT, 100)
|
||||||
};
|
};
|
||||||
let _ = self.inner.set_packet_loss(clamped);
|
let _ = self.inner.set_packet_loss(clamped);
|
||||||
}
|
}
|
||||||
@@ -327,13 +337,25 @@ impl AudioEncoder for OpusEncoder {
|
|||||||
);
|
);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
let mode = if enabled { InbandFec::Mode1 } else { InbandFec::Off };
|
let mode = if enabled {
|
||||||
|
InbandFec::Mode1
|
||||||
|
} else {
|
||||||
|
InbandFec::Off
|
||||||
|
};
|
||||||
let _ = self.inner.set_inband_fec(mode);
|
let _ = self.inner.set_inband_fec(mode);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn set_dtx(&mut self, enabled: bool) {
|
fn set_dtx(&mut self, enabled: bool) {
|
||||||
let _ = self.inner.set_dtx(enabled);
|
let _ = self.inner.set_dtx(enabled);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn set_expected_loss(&mut self, loss_pct: u8) {
|
||||||
|
OpusEncoder::set_expected_loss(self, loss_pct);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_dred_duration(&mut self, frames: u8) {
|
||||||
|
OpusEncoder::set_dred_duration(self, frames);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -389,8 +411,8 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn dred_duration_for_degraded_tier_is_500ms() {
|
fn dred_duration_for_degraded_tier_is_1040ms() {
|
||||||
assert_eq!(dred_duration_for(CodecId::Opus6k), 50);
|
assert_eq!(dred_duration_for(CodecId::Opus6k), 104);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ fn build_fir_kernel() -> [f64; FIR_TAPS] {
|
|||||||
let fc = CUTOFF_HZ / SAMPLE_RATE; // normalised cutoff (0..0.5)
|
let fc = CUTOFF_HZ / SAMPLE_RATE; // normalised cutoff (0..0.5)
|
||||||
let beta_denom = bessel_i0(KAISER_BETA);
|
let beta_denom = bessel_i0(KAISER_BETA);
|
||||||
|
|
||||||
for i in 0..FIR_TAPS {
|
for (i, slot) in kernel.iter_mut().enumerate() {
|
||||||
// Sinc
|
// Sinc
|
||||||
let n = i as f64 - m / 2.0;
|
let n = i as f64 - m / 2.0;
|
||||||
let sinc = if n.abs() < 1e-12 {
|
let sinc = if n.abs() < 1e-12 {
|
||||||
@@ -61,7 +61,7 @@ fn build_fir_kernel() -> [f64; FIR_TAPS] {
|
|||||||
let t = 2.0 * i as f64 / m - 1.0; // range [-1, 1]
|
let t = 2.0 * i as f64 / m - 1.0; // range [-1, 1]
|
||||||
let kaiser = bessel_i0(KAISER_BETA * (1.0 - t * t).max(0.0).sqrt()) / beta_denom;
|
let kaiser = bessel_i0(KAISER_BETA * (1.0 - t * t).max(0.0).sqrt()) / beta_denom;
|
||||||
|
|
||||||
kernel[i] = sinc * kaiser;
|
*slot = sinc * kaiser;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Normalise to unity DC gain.
|
// Normalise to unity DC gain.
|
||||||
@@ -129,8 +129,7 @@ impl Downsampler48to8 {
|
|||||||
|
|
||||||
// Update history: keep the last (FIR_TAPS - 1) samples from work.
|
// Update history: keep the last (FIR_TAPS - 1) samples from work.
|
||||||
if work.len() >= hist_len {
|
if work.len() >= hist_len {
|
||||||
self.history
|
self.history.copy_from_slice(&work[work.len() - hist_len..]);
|
||||||
.copy_from_slice(&work[work.len() - hist_len..]);
|
|
||||||
} else {
|
} else {
|
||||||
// Input was shorter than history — shift.
|
// Input was shorter than history — shift.
|
||||||
let shift = hist_len - work.len();
|
let shift = hist_len - work.len();
|
||||||
@@ -181,9 +180,7 @@ impl Upsampler8to48 {
|
|||||||
work.extend_from_slice(&self.history);
|
work.extend_from_slice(&self.history);
|
||||||
for &s in input {
|
for &s in input {
|
||||||
work.push(s as f64);
|
work.push(s as f64);
|
||||||
for _ in 1..RATIO {
|
work.resize(work.len() + (RATIO - 1), 0.0f64);
|
||||||
work.push(0.0);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let out_len = stuffed_len;
|
let out_len = stuffed_len;
|
||||||
@@ -209,8 +206,7 @@ impl Upsampler8to48 {
|
|||||||
|
|
||||||
// Update history.
|
// Update history.
|
||||||
if work.len() >= hist_len {
|
if work.len() >= hist_len {
|
||||||
self.history
|
self.history.copy_from_slice(&work[work.len() - hist_len..]);
|
||||||
.copy_from_slice(&work[work.len() - hist_len..]);
|
|
||||||
} else {
|
} else {
|
||||||
let shift = hist_len - work.len();
|
let shift = hist_len - work.len();
|
||||||
self.history.copy_within(shift.., 0);
|
self.history.copy_within(shift.., 0);
|
||||||
|
|||||||
@@ -151,7 +151,10 @@ mod tests {
|
|||||||
for _ in 0..4 {
|
for _ in 0..4 {
|
||||||
det.is_silent(&silence);
|
det.is_silent(&silence);
|
||||||
}
|
}
|
||||||
assert!(det.is_silent(&silence), "should be suppressing after hangover");
|
assert!(
|
||||||
|
det.is_silent(&silence),
|
||||||
|
"should be suppressing after hangover"
|
||||||
|
);
|
||||||
|
|
||||||
// Speech arrives — should immediately stop suppressing.
|
// Speech arrives — should immediately stop suppressing.
|
||||||
assert!(!det.is_silent(&speech));
|
assert!(!det.is_silent(&speech));
|
||||||
@@ -165,10 +168,16 @@ mod tests {
|
|||||||
cn.generate(&mut pcm);
|
cn.generate(&mut pcm);
|
||||||
|
|
||||||
// At least some samples should be non-zero.
|
// At least some samples should be non-zero.
|
||||||
assert!(pcm.iter().any(|&s| s != 0), "CN output should not be all zeros");
|
assert!(
|
||||||
|
pcm.iter().any(|&s| s != 0),
|
||||||
|
"CN output should not be all zeros"
|
||||||
|
);
|
||||||
|
|
||||||
// All samples should be within [-50, 50].
|
// All samples should be within [-50, 50].
|
||||||
assert!(pcm.iter().all(|&s| s.abs() <= 50), "CN samples out of range");
|
assert!(
|
||||||
|
pcm.iter().all(|&s| s.abs() <= 50),
|
||||||
|
"CN samples out of range"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -179,11 +188,17 @@ mod tests {
|
|||||||
// Constant value: RMS of [v, v, v, ...] = |v|.
|
// Constant value: RMS of [v, v, v, ...] = |v|.
|
||||||
let pcm = vec![100i16; 100];
|
let pcm = vec![100i16; 100];
|
||||||
let rms = SilenceDetector::rms(&pcm);
|
let rms = SilenceDetector::rms(&pcm);
|
||||||
assert!((rms - 100.0).abs() < 0.01, "RMS of constant 100 should be 100, got {rms}");
|
assert!(
|
||||||
|
(rms - 100.0).abs() < 0.01,
|
||||||
|
"RMS of constant 100 should be 100, got {rms}"
|
||||||
|
);
|
||||||
|
|
||||||
// Known pattern: [3, 4] → sqrt((9+16)/2) = sqrt(12.5) ≈ 3.5355
|
// Known pattern: [3, 4] → sqrt((9+16)/2) = sqrt(12.5) ≈ 3.5355
|
||||||
let rms2 = SilenceDetector::rms(&[3, 4]);
|
let rms2 = SilenceDetector::rms(&[3, 4]);
|
||||||
assert!((rms2 - 3.5355).abs() < 0.01, "RMS of [3,4] should be ~3.5355, got {rms2}");
|
assert!(
|
||||||
|
(rms2 - 3.5355).abs() < 0.01,
|
||||||
|
"RMS of [3,4] should be ~3.5355, got {rms2}"
|
||||||
|
);
|
||||||
|
|
||||||
// Empty buffer → 0.
|
// Empty buffer → 0.
|
||||||
assert_eq!(SilenceDetector::rms(&[]), 0.0);
|
assert_eq!(SilenceDetector::rms(&[]), 0.0);
|
||||||
|
|||||||
@@ -1,21 +1,20 @@
|
|||||||
//! Sliding window replay protection.
|
//! Sliding window replay protection.
|
||||||
//!
|
//!
|
||||||
//! Tracks seen sequence numbers using a bitmap. Window size is 1024 packets.
|
//! Tracks seen sequence numbers using a bitmap. Window size is configurable
|
||||||
//! Sequence numbers that are too old (more than WINDOW_SIZE behind the highest
|
//! at construction time. Sequence numbers that are too old (more than
|
||||||
//! seen) are rejected.
|
//! `window_size` behind the highest seen) are rejected.
|
||||||
|
|
||||||
use wzp_proto::CryptoError;
|
use wzp_proto::CryptoError;
|
||||||
|
|
||||||
/// Window size in packets.
|
|
||||||
const WINDOW_SIZE: u16 = 1024;
|
|
||||||
|
|
||||||
/// Sliding window anti-replay detector.
|
/// Sliding window anti-replay detector.
|
||||||
///
|
///
|
||||||
/// Uses a bitmap to track which sequence numbers have been seen within
|
/// Uses a bitmap to track which sequence numbers have been seen within
|
||||||
/// the current window. Handles u16 wrapping correctly.
|
/// the current window. Handles `u32` wrapping correctly.
|
||||||
pub struct AntiReplayWindow {
|
pub struct AntiReplayWindow {
|
||||||
|
/// Window size in packets.
|
||||||
|
window_size: u32,
|
||||||
/// Highest sequence number seen so far.
|
/// Highest sequence number seen so far.
|
||||||
highest: u16,
|
highest: u32,
|
||||||
/// Bitmap of seen packets. Bit i corresponds to (highest - i).
|
/// Bitmap of seen packets. Bit i corresponds to (highest - i).
|
||||||
bitmap: Vec<u64>,
|
bitmap: Vec<u64>,
|
||||||
/// Whether any packet has been received yet.
|
/// Whether any packet has been received yet.
|
||||||
@@ -23,21 +22,26 @@ pub struct AntiReplayWindow {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl AntiReplayWindow {
|
impl AntiReplayWindow {
|
||||||
/// Number of u64 words needed for the bitmap.
|
/// Create a new anti-replay window with the default size of 1024 packets.
|
||||||
const BITMAP_WORDS: usize = (WINDOW_SIZE as usize + 63) / 64;
|
|
||||||
|
|
||||||
/// Create a new anti-replay window.
|
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
|
Self::with_window(1024)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new anti-replay window with a custom size.
|
||||||
|
pub fn with_window(size: usize) -> Self {
|
||||||
|
let window_size = size as u32;
|
||||||
|
let bitmap_words = (size + 63) / 64;
|
||||||
Self {
|
Self {
|
||||||
|
window_size,
|
||||||
highest: 0,
|
highest: 0,
|
||||||
bitmap: vec![0u64; Self::BITMAP_WORDS],
|
bitmap: vec![0u64; bitmap_words],
|
||||||
initialized: false,
|
initialized: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Check if a sequence number is valid (not a replay, not too old).
|
/// Check if a sequence number is valid (not a replay, not too old).
|
||||||
/// If valid, marks it as seen.
|
/// If valid, marks it as seen.
|
||||||
pub fn check_and_update(&mut self, seq: u16) -> Result<(), CryptoError> {
|
pub fn check_and_update(&mut self, seq: u32) -> Result<(), CryptoError> {
|
||||||
if !self.initialized {
|
if !self.initialized {
|
||||||
self.initialized = true;
|
self.initialized = true;
|
||||||
self.highest = seq;
|
self.highest = seq;
|
||||||
@@ -52,17 +56,17 @@ impl AntiReplayWindow {
|
|||||||
return Err(CryptoError::ReplayDetected { seq });
|
return Err(CryptoError::ReplayDetected { seq });
|
||||||
}
|
}
|
||||||
|
|
||||||
if diff < 0x8000 {
|
if diff < 0x8000_0000 {
|
||||||
// seq is ahead of highest (wrapping-aware: diff in [1, 0x7FFF])
|
// seq is ahead of highest (wrapping-aware: diff in [1, 0x7FFF_FFFF])
|
||||||
let shift = diff as usize;
|
let shift = diff as usize;
|
||||||
self.advance_window(shift);
|
self.advance_window(shift);
|
||||||
self.highest = seq;
|
self.highest = seq;
|
||||||
self.set_bit(0);
|
self.set_bit(0);
|
||||||
Ok(())
|
Ok(())
|
||||||
} else {
|
} else {
|
||||||
// seq is behind highest (wrapping-aware: diff in [0x8000, 0xFFFF])
|
// seq is behind highest (wrapping-aware: diff in [0x8000_0000, 0xFFFF_FFFF])
|
||||||
let behind = self.highest.wrapping_sub(seq) as usize;
|
let behind = self.highest.wrapping_sub(seq) as usize;
|
||||||
if behind >= WINDOW_SIZE as usize {
|
if behind >= self.window_size as usize {
|
||||||
return Err(CryptoError::ReplayDetected { seq });
|
return Err(CryptoError::ReplayDetected { seq });
|
||||||
}
|
}
|
||||||
if self.get_bit(behind) {
|
if self.get_bit(behind) {
|
||||||
@@ -75,7 +79,8 @@ impl AntiReplayWindow {
|
|||||||
|
|
||||||
/// Advance the window by `shift` positions (shift left = new bits at position 0).
|
/// Advance the window by `shift` positions (shift left = new bits at position 0).
|
||||||
fn advance_window(&mut self, shift: usize) {
|
fn advance_window(&mut self, shift: usize) {
|
||||||
if shift >= WINDOW_SIZE as usize {
|
let window_size = self.window_size as usize;
|
||||||
|
if shift >= window_size {
|
||||||
for word in &mut self.bitmap {
|
for word in &mut self.bitmap {
|
||||||
*word = 0;
|
*word = 0;
|
||||||
}
|
}
|
||||||
@@ -156,7 +161,11 @@ mod tests {
|
|||||||
fn sequential_accepted() {
|
fn sequential_accepted() {
|
||||||
let mut w = AntiReplayWindow::new();
|
let mut w = AntiReplayWindow::new();
|
||||||
for i in 0..200 {
|
for i in 0..200 {
|
||||||
assert!(w.check_and_update(i).is_ok(), "seq {} should be accepted", i);
|
assert!(
|
||||||
|
w.check_and_update(i).is_ok(),
|
||||||
|
"seq {} should be accepted",
|
||||||
|
i
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -183,11 +192,11 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn wrapping_works() {
|
fn wrapping_works() {
|
||||||
let mut w = AntiReplayWindow::new();
|
let mut w = AntiReplayWindow::new();
|
||||||
assert!(w.check_and_update(65530).is_ok());
|
assert!(w.check_and_update(0xFFFF_FFF0).is_ok());
|
||||||
assert!(w.check_and_update(65535).is_ok());
|
assert!(w.check_and_update(0xFFFF_FFFF).is_ok());
|
||||||
assert!(w.check_and_update(0).is_ok()); // wrapped
|
assert!(w.check_and_update(0).is_ok()); // wrapped
|
||||||
assert!(w.check_and_update(1).is_ok());
|
assert!(w.check_and_update(1).is_ok());
|
||||||
assert!(w.check_and_update(65535).is_err()); // duplicate
|
assert!(w.check_and_update(0xFFFF_FFFF).is_err()); // duplicate
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -201,4 +210,53 @@ mod tests {
|
|||||||
// Now 0 is 1024 behind 1024, which is at the boundary limit
|
// Now 0 is 1024 behind 1024, which is at the boundary limit
|
||||||
assert!(w.check_and_update(0).is_err()); // already seen or too old
|
assert!(w.check_and_update(0).is_err()); // already seen or too old
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn custom_window_size() {
|
||||||
|
let mut w = AntiReplayWindow::with_window(64);
|
||||||
|
for i in 0..64 {
|
||||||
|
assert!(w.check_and_update(i).is_ok());
|
||||||
|
}
|
||||||
|
// seq 0 is now exactly at the boundary (64 behind 64)
|
||||||
|
assert!(w.check_and_update(0).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn video_burst_200_with_one_reorder() {
|
||||||
|
let mut w = AntiReplayWindow::with_window(1024);
|
||||||
|
// Simulate a 200-packet burst
|
||||||
|
for i in 0..200 {
|
||||||
|
assert!(
|
||||||
|
w.check_and_update(i).is_ok(),
|
||||||
|
"seq {} should be accepted",
|
||||||
|
i
|
||||||
|
);
|
||||||
|
}
|
||||||
|
// One packet reordered (arrives late)
|
||||||
|
assert!(w.check_and_update(50).is_err(), "seq 50 is a duplicate");
|
||||||
|
// But a packet just behind the window should still be ok
|
||||||
|
assert!(w.check_and_update(199).is_err(), "seq 199 is a duplicate");
|
||||||
|
// Continue the burst
|
||||||
|
for i in 200..400 {
|
||||||
|
assert!(
|
||||||
|
w.check_and_update(i).is_ok(),
|
||||||
|
"seq {} should be accepted",
|
||||||
|
i
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn u32_high_range_works() {
|
||||||
|
let mut w = AntiReplayWindow::with_window(64);
|
||||||
|
let base = 1000u32;
|
||||||
|
assert!(w.check_and_update(base).is_ok());
|
||||||
|
assert!(w.check_and_update(base + 1).is_ok());
|
||||||
|
// 65 behind highest (base+1) is outside the 64-packet window
|
||||||
|
assert!(w.check_and_update(base.wrapping_sub(64)).is_err());
|
||||||
|
// 63 behind is inside
|
||||||
|
assert!(w.check_and_update(base.wrapping_sub(62)).is_ok());
|
||||||
|
// base itself is now a duplicate
|
||||||
|
assert!(w.check_and_update(base).is_err());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ use ed25519_dalek::{Signer, SigningKey, Verifier, VerifyingKey};
|
|||||||
use hkdf::Hkdf;
|
use hkdf::Hkdf;
|
||||||
use rand::rngs::OsRng;
|
use rand::rngs::OsRng;
|
||||||
use sha2::{Digest, Sha256};
|
use sha2::{Digest, Sha256};
|
||||||
use x25519_dalek::{PublicKey as X25519PublicKey, StaticSecret};
|
|
||||||
use wzp_proto::{CryptoError, CryptoSession, KeyExchange};
|
use wzp_proto::{CryptoError, CryptoSession, KeyExchange};
|
||||||
|
use x25519_dalek::{PublicKey as X25519PublicKey, StaticSecret};
|
||||||
|
|
||||||
use crate::session::ChaChaSession;
|
use crate::session::ChaChaSession;
|
||||||
|
|
||||||
@@ -18,10 +18,14 @@ use crate::session::ChaChaSession;
|
|||||||
pub struct WarzoneKeyExchange {
|
pub struct WarzoneKeyExchange {
|
||||||
/// Ed25519 signing key (identity).
|
/// Ed25519 signing key (identity).
|
||||||
signing_key: SigningKey,
|
signing_key: SigningKey,
|
||||||
/// X25519 static secret (derived from seed, used for identity encryption).
|
/// X25519 static secret derived from identity seed. Reserved for future
|
||||||
|
/// use in static-key federation authentication (not used in current
|
||||||
|
/// ephemeral-only handshake protocol).
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
x25519_static_secret: StaticSecret,
|
x25519_static_secret: StaticSecret,
|
||||||
/// X25519 static public key.
|
/// X25519 static public key derived from identity seed. Reserved for
|
||||||
|
/// future use in static-key federation authentication (not used in
|
||||||
|
/// current ephemeral-only handshake protocol).
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
x25519_static_public: X25519PublicKey,
|
x25519_static_public: X25519PublicKey,
|
||||||
/// Ephemeral X25519 secret for the current call (set by generate_ephemeral).
|
/// Ephemeral X25519 secret for the current call (set by generate_ephemeral).
|
||||||
@@ -91,11 +95,10 @@ impl KeyExchange for WarzoneKeyExchange {
|
|||||||
&self,
|
&self,
|
||||||
peer_ephemeral_pub: &[u8; 32],
|
peer_ephemeral_pub: &[u8; 32],
|
||||||
) -> Result<Box<dyn CryptoSession>, CryptoError> {
|
) -> Result<Box<dyn CryptoSession>, CryptoError> {
|
||||||
let secret = self
|
let secret = self.ephemeral_secret.as_ref().ok_or_else(|| {
|
||||||
.ephemeral_secret
|
CryptoError::Internal(
|
||||||
.as_ref()
|
"no ephemeral key generated; call generate_ephemeral first".into(),
|
||||||
.ok_or_else(|| {
|
)
|
||||||
CryptoError::Internal("no ephemeral key generated; call generate_ephemeral first".into())
|
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let peer_public = X25519PublicKey::from(*peer_ephemeral_pub);
|
let peer_public = X25519PublicKey::from(*peer_ephemeral_pub);
|
||||||
@@ -206,18 +209,34 @@ mod tests {
|
|||||||
let mut alice_session = alice.derive_session(&bob_eph_pub).unwrap();
|
let mut alice_session = alice.derive_session(&bob_eph_pub).unwrap();
|
||||||
let mut bob_session = bob.derive_session(&alice_eph_pub).unwrap();
|
let mut bob_session = bob.derive_session(&alice_eph_pub).unwrap();
|
||||||
|
|
||||||
// Verify they can communicate: Alice encrypts, Bob decrypts
|
// Verify they can communicate: Alice encrypts, Bob decrypts.
|
||||||
let header = b"call-header";
|
// Use a valid v2 MediaHeader — encrypt/decrypt now derive the nonce from
|
||||||
|
// header.seq and will reject raw byte slices shorter than WIRE_SIZE.
|
||||||
|
use wzp_proto::{CodecId, MediaHeader, MediaType};
|
||||||
|
let header = MediaHeader {
|
||||||
|
version: 2,
|
||||||
|
flags: 0,
|
||||||
|
media_type: MediaType::Audio,
|
||||||
|
codec_id: CodecId::Opus24k,
|
||||||
|
stream_id: 0,
|
||||||
|
fec_ratio: 0,
|
||||||
|
seq: 0,
|
||||||
|
timestamp: 0,
|
||||||
|
fec_block: 0,
|
||||||
|
};
|
||||||
|
let mut header_bytes = Vec::new();
|
||||||
|
header.write_to(&mut header_bytes);
|
||||||
|
|
||||||
let plaintext = b"hello from alice";
|
let plaintext = b"hello from alice";
|
||||||
|
|
||||||
let mut ciphertext = Vec::new();
|
let mut ciphertext = Vec::new();
|
||||||
alice_session
|
alice_session
|
||||||
.encrypt(header, plaintext, &mut ciphertext)
|
.encrypt(&header_bytes, plaintext, &mut ciphertext)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let mut decrypted = Vec::new();
|
let mut decrypted = Vec::new();
|
||||||
bob_session
|
bob_session
|
||||||
.decrypt(header, &ciphertext, &mut decrypted)
|
.decrypt(&header_bytes, &ciphertext, &mut decrypted)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(&decrypted, plaintext);
|
assert_eq!(&decrypted, plaintext);
|
||||||
|
|||||||
@@ -79,7 +79,9 @@ impl Seed {
|
|||||||
///
|
///
|
||||||
/// Mirrors: `warzone-protocol::mnemonic::mnemonic_to_seed`
|
/// Mirrors: `warzone-protocol::mnemonic::mnemonic_to_seed`
|
||||||
pub fn from_mnemonic(words: &str) -> Result<Self, String> {
|
pub fn from_mnemonic(words: &str) -> Result<Self, String> {
|
||||||
let mnemonic: bip39::Mnemonic = words.parse().map_err(|e| format!("invalid mnemonic: {e}"))?;
|
let mnemonic: bip39::Mnemonic = words
|
||||||
|
.parse()
|
||||||
|
.map_err(|e| format!("invalid mnemonic: {e}"))?;
|
||||||
let entropy = mnemonic.to_entropy();
|
let entropy = mnemonic.to_entropy();
|
||||||
if entropy.len() != 32 {
|
if entropy.len() != 32 {
|
||||||
return Err(format!("expected 32 bytes entropy, got {}", entropy.len()));
|
return Err(format!("expected 32 bytes entropy, got {}", entropy.len()));
|
||||||
|
|||||||
@@ -16,8 +16,8 @@ pub mod session;
|
|||||||
|
|
||||||
pub use anti_replay::AntiReplayWindow;
|
pub use anti_replay::AntiReplayWindow;
|
||||||
pub use handshake::WarzoneKeyExchange;
|
pub use handshake::WarzoneKeyExchange;
|
||||||
pub use identity::{hash_room_name, Fingerprint, IdentityKeyPair, PublicIdentity, Seed};
|
pub use identity::{Fingerprint, IdentityKeyPair, PublicIdentity, Seed, hash_room_name};
|
||||||
pub use nonce::{build_nonce, Direction};
|
pub use nonce::{Direction, build_nonce};
|
||||||
pub use rekey::RekeyManager;
|
pub use rekey::RekeyManager;
|
||||||
pub use session::ChaChaSession;
|
pub use session::ChaChaSession;
|
||||||
|
|
||||||
|
|||||||
@@ -36,6 +36,10 @@ impl RekeyManager {
|
|||||||
///
|
///
|
||||||
/// The old key is zeroized after the new key is derived.
|
/// The old key is zeroized after the new key is derived.
|
||||||
/// Returns the new 32-byte symmetric key.
|
/// Returns the new 32-byte symmetric key.
|
||||||
|
///
|
||||||
|
/// NOTE: Rekeying changes **only** the symmetric key material. Sequence
|
||||||
|
/// numbers and timestamps in the media framing layer (e.g. `MediaHeader`)
|
||||||
|
/// are untouched — they continue monotonically across the rekey boundary.
|
||||||
pub fn perform_rekey(
|
pub fn perform_rekey(
|
||||||
&mut self,
|
&mut self,
|
||||||
new_peer_pub: &[u8; 32],
|
new_peer_pub: &[u8; 32],
|
||||||
|
|||||||
@@ -3,12 +3,15 @@
|
|||||||
//! Implements the `CryptoSession` trait for per-call media encryption.
|
//! Implements the `CryptoSession` trait for per-call media encryption.
|
||||||
//! Nonces are derived deterministically from session_id + sequence counter + direction.
|
//! Nonces are derived deterministically from session_id + sequence counter + direction.
|
||||||
|
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
use chacha20poly1305::aead::Aead;
|
use chacha20poly1305::aead::Aead;
|
||||||
use chacha20poly1305::{ChaCha20Poly1305, KeyInit, Nonce};
|
use chacha20poly1305::{ChaCha20Poly1305, KeyInit, Nonce};
|
||||||
use x25519_dalek::{PublicKey, StaticSecret};
|
|
||||||
use rand::rngs::OsRng;
|
use rand::rngs::OsRng;
|
||||||
use wzp_proto::{CryptoError, CryptoSession};
|
use wzp_proto::{CryptoError, CryptoSession, MediaHeader, MediaType};
|
||||||
|
use x25519_dalek::{PublicKey, StaticSecret};
|
||||||
|
|
||||||
|
use crate::anti_replay::AntiReplayWindow;
|
||||||
use crate::nonce::{self, Direction};
|
use crate::nonce::{self, Direction};
|
||||||
use crate::rekey::RekeyManager;
|
use crate::rekey::RekeyManager;
|
||||||
|
|
||||||
@@ -28,6 +31,10 @@ pub struct ChaChaSession {
|
|||||||
pending_rekey_secret: Option<StaticSecret>,
|
pending_rekey_secret: Option<StaticSecret>,
|
||||||
/// Short Authentication String (4-digit code for verbal verification).
|
/// Short Authentication String (4-digit code for verbal verification).
|
||||||
sas_code: Option<u32>,
|
sas_code: Option<u32>,
|
||||||
|
/// Per-stream anti-replay windows, keyed by (stream_id, media_type).
|
||||||
|
anti_replay: HashMap<(u8, MediaType), AntiReplayWindow>,
|
||||||
|
/// Last timestamp seen in encrypt() — used to assert monotonicity across rekeys.
|
||||||
|
last_encrypt_timestamp: Option<u32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ChaChaSession {
|
impl ChaChaSession {
|
||||||
@@ -49,6 +56,8 @@ impl ChaChaSession {
|
|||||||
rekey_mgr: RekeyManager::new(shared_secret),
|
rekey_mgr: RekeyManager::new(shared_secret),
|
||||||
pending_rekey_secret: None,
|
pending_rekey_secret: None,
|
||||||
sas_code: None,
|
sas_code: None,
|
||||||
|
anti_replay: HashMap::new(),
|
||||||
|
last_encrypt_timestamp: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -67,6 +76,27 @@ impl ChaChaSession {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Parse a v2 `MediaHeader` from raw bytes.
|
||||||
|
/// Returns `None` if the buffer is too short or not a valid v2 header.
|
||||||
|
fn parse_header(header_bytes: &[u8]) -> Option<MediaHeader> {
|
||||||
|
if header_bytes.len() < MediaHeader::WIRE_SIZE {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
let mut cursor = std::io::Cursor::new(header_bytes);
|
||||||
|
MediaHeader::read_from(&mut cursor)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return the default anti-replay window size for a given media type.
|
||||||
|
fn default_window_for_media_type(media_type: MediaType) -> AntiReplayWindow {
|
||||||
|
let size = match media_type {
|
||||||
|
MediaType::Audio => 64,
|
||||||
|
MediaType::Video => 1024,
|
||||||
|
MediaType::Data => 256,
|
||||||
|
MediaType::Control => 32,
|
||||||
|
};
|
||||||
|
AntiReplayWindow::with_window(size)
|
||||||
|
}
|
||||||
|
|
||||||
impl CryptoSession for ChaChaSession {
|
impl CryptoSession for ChaChaSession {
|
||||||
fn encrypt(
|
fn encrypt(
|
||||||
&mut self,
|
&mut self,
|
||||||
@@ -74,10 +104,14 @@ impl CryptoSession for ChaChaSession {
|
|||||||
plaintext: &[u8],
|
plaintext: &[u8],
|
||||||
out: &mut Vec<u8>,
|
out: &mut Vec<u8>,
|
||||||
) -> Result<(), CryptoError> {
|
) -> Result<(), CryptoError> {
|
||||||
let nonce_bytes = nonce::build_nonce(&self.session_id, self.send_seq, Direction::Send);
|
// Derive nonce from the wire-level seq in the header, not from an
|
||||||
|
// internal counter. This ensures the receiver can reconstruct the
|
||||||
|
// same nonce using the header it receives, regardless of delivery order.
|
||||||
|
let header = parse_header(header_bytes)
|
||||||
|
.ok_or_else(|| CryptoError::Internal("header too short to derive nonce".into()))?;
|
||||||
|
let nonce_bytes = nonce::build_nonce(&self.session_id, header.seq, Direction::Send);
|
||||||
let nonce = Nonce::from_slice(&nonce_bytes);
|
let nonce = Nonce::from_slice(&nonce_bytes);
|
||||||
|
|
||||||
// Encrypt with AAD
|
|
||||||
use chacha20poly1305::aead::Payload;
|
use chacha20poly1305::aead::Payload;
|
||||||
let payload = Payload {
|
let payload = Payload {
|
||||||
msg: plaintext,
|
msg: plaintext,
|
||||||
@@ -90,7 +124,19 @@ impl CryptoSession for ChaChaSession {
|
|||||||
.map_err(|_| CryptoError::Internal("encryption failed".into()))?;
|
.map_err(|_| CryptoError::Internal("encryption failed".into()))?;
|
||||||
|
|
||||||
out.extend_from_slice(&ciphertext);
|
out.extend_from_slice(&ciphertext);
|
||||||
self.send_seq = self.send_seq.wrapping_add(1);
|
self.send_seq = self.send_seq.wrapping_add(1); // packet counter for rekey trigger only
|
||||||
|
|
||||||
|
// M5: assert timestamp_ms is non-decreasing across calls (including post-rekey).
|
||||||
|
// Timestamps are u32 and wrap at 2^32 ms (~49 days); allow wrapping.
|
||||||
|
debug_assert!(
|
||||||
|
self.last_encrypt_timestamp
|
||||||
|
.map_or(true, |last| header.timestamp.wrapping_sub(last) < u32::MAX / 2),
|
||||||
|
"encrypt: timestamp must not decrease (last={:?}, now={})",
|
||||||
|
self.last_encrypt_timestamp,
|
||||||
|
header.timestamp,
|
||||||
|
);
|
||||||
|
self.last_encrypt_timestamp = Some(header.timestamp);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -100,9 +146,14 @@ impl CryptoSession for ChaChaSession {
|
|||||||
ciphertext: &[u8],
|
ciphertext: &[u8],
|
||||||
out: &mut Vec<u8>,
|
out: &mut Vec<u8>,
|
||||||
) -> Result<(), CryptoError> {
|
) -> Result<(), CryptoError> {
|
||||||
// Use Direction::Send to match the sender's nonce construction.
|
// Parse header before decryption — needed for nonce derivation.
|
||||||
// The recv_seq counter tracks which packet from the peer we're decrypting.
|
// Using header.seq (not recv_seq) means the nonce is always derived
|
||||||
let nonce_bytes = nonce::build_nonce(&self.session_id, self.recv_seq, Direction::Send);
|
// from the same wire field as the sender, surviving out-of-order delivery.
|
||||||
|
// A recv_seq counter diverges from the sender's send_seq on any reorder,
|
||||||
|
// causing every subsequent decryption to fail for the rest of the session.
|
||||||
|
let header = parse_header(header_bytes)
|
||||||
|
.ok_or_else(|| CryptoError::Internal("header too short to derive nonce".into()))?;
|
||||||
|
let nonce_bytes = nonce::build_nonce(&self.session_id, header.seq, Direction::Send);
|
||||||
let nonce = Nonce::from_slice(&nonce_bytes);
|
let nonce = Nonce::from_slice(&nonce_bytes);
|
||||||
|
|
||||||
use chacha20poly1305::aead::Payload;
|
use chacha20poly1305::aead::Payload;
|
||||||
@@ -116,8 +167,21 @@ impl CryptoSession for ChaChaSession {
|
|||||||
.decrypt(nonce, payload)
|
.decrypt(nonce, payload)
|
||||||
.map_err(|_| CryptoError::DecryptionFailed)?;
|
.map_err(|_| CryptoError::DecryptionFailed)?;
|
||||||
|
|
||||||
|
let plaintext_len = plaintext.len();
|
||||||
out.extend_from_slice(&plaintext);
|
out.extend_from_slice(&plaintext);
|
||||||
self.recv_seq = self.recv_seq.wrapping_add(1);
|
self.recv_seq = self.recv_seq.wrapping_add(1); // packet counter for rekey trigger only
|
||||||
|
|
||||||
|
// Anti-replay check: header already parsed above.
|
||||||
|
let window = self
|
||||||
|
.anti_replay
|
||||||
|
.entry((header.stream_id, header.media_type))
|
||||||
|
.or_insert_with(|| default_window_for_media_type(header.media_type));
|
||||||
|
if let Err(e) = window.check_and_update(header.seq) {
|
||||||
|
// Roll back the plaintext we just appended.
|
||||||
|
out.truncate(out.len() - plaintext_len);
|
||||||
|
return Err(e);
|
||||||
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -135,10 +199,14 @@ impl CryptoSession for ChaChaSession {
|
|||||||
.ok_or_else(|| CryptoError::RekeyFailed("no pending rekey".into()))?;
|
.ok_or_else(|| CryptoError::RekeyFailed("no pending rekey".into()))?;
|
||||||
|
|
||||||
let total_packets = self.send_seq as u64 + self.recv_seq as u64;
|
let total_packets = self.send_seq as u64 + self.recv_seq as u64;
|
||||||
let new_key = self.rekey_mgr.perform_rekey(peer_ephemeral_pub, secret, total_packets);
|
let new_key = self
|
||||||
|
.rekey_mgr
|
||||||
|
.perform_rekey(peer_ephemeral_pub, secret, total_packets);
|
||||||
self.install_key(new_key);
|
self.install_key(new_key);
|
||||||
|
|
||||||
// Reset sequence counters after rekey for nonce uniqueness
|
// Reset sequence counters after rekey for nonce uniqueness.
|
||||||
|
// last_encrypt_timestamp is intentionally NOT reset — spec requires
|
||||||
|
// timestamp_ms to be monotonic across rekeys.
|
||||||
self.send_seq = 0;
|
self.send_seq = 0;
|
||||||
self.recv_seq = 0;
|
self.recv_seq = 0;
|
||||||
|
|
||||||
@@ -153,24 +221,42 @@ impl CryptoSession for ChaChaSession {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use wzp_proto::{CodecId, MediaType};
|
||||||
|
|
||||||
fn make_session_pair() -> (ChaChaSession, ChaChaSession) {
|
fn make_session_pair() -> (ChaChaSession, ChaChaSession) {
|
||||||
let key = [0x42u8; 32];
|
let key = [0x42u8; 32];
|
||||||
(ChaChaSession::new(key), ChaChaSession::new(key))
|
(ChaChaSession::new(key), ChaChaSession::new(key))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Build a minimal valid v2 MediaHeader serialised to bytes.
|
||||||
|
fn make_header_bytes(seq: u32) -> Vec<u8> {
|
||||||
|
let header = MediaHeader {
|
||||||
|
version: 2,
|
||||||
|
flags: 0,
|
||||||
|
media_type: MediaType::Audio,
|
||||||
|
codec_id: CodecId::Opus24k,
|
||||||
|
stream_id: 0,
|
||||||
|
fec_ratio: 0,
|
||||||
|
seq,
|
||||||
|
timestamp: seq.wrapping_mul(20),
|
||||||
|
fec_block: 0,
|
||||||
|
};
|
||||||
|
let mut bytes = Vec::new();
|
||||||
|
header.write_to(&mut bytes);
|
||||||
|
bytes
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn encrypt_decrypt_roundtrip() {
|
fn encrypt_decrypt_roundtrip() {
|
||||||
let (mut alice, mut bob) = make_session_pair();
|
let (mut alice, mut bob) = make_session_pair();
|
||||||
let header = b"test-header";
|
let header = make_header_bytes(0);
|
||||||
let plaintext = b"hello warzone";
|
let plaintext = b"hello warzone";
|
||||||
|
|
||||||
let mut ciphertext = Vec::new();
|
let mut ciphertext = Vec::new();
|
||||||
alice.encrypt(header, plaintext, &mut ciphertext).unwrap();
|
alice.encrypt(&header, plaintext, &mut ciphertext).unwrap();
|
||||||
|
|
||||||
// Bob decrypts (his recv matches Alice's send)
|
|
||||||
let mut decrypted = Vec::new();
|
let mut decrypted = Vec::new();
|
||||||
bob.decrypt(header, &ciphertext, &mut decrypted).unwrap();
|
bob.decrypt(&header, &ciphertext, &mut decrypted).unwrap();
|
||||||
|
|
||||||
assert_eq!(&decrypted, plaintext);
|
assert_eq!(&decrypted, plaintext);
|
||||||
}
|
}
|
||||||
@@ -178,14 +264,18 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn decrypt_wrong_aad_fails() {
|
fn decrypt_wrong_aad_fails() {
|
||||||
let (mut alice, mut bob) = make_session_pair();
|
let (mut alice, mut bob) = make_session_pair();
|
||||||
let header = b"correct-header";
|
let correct_header = make_header_bytes(0);
|
||||||
|
// Different seq → different nonce AND different AAD bytes: decryption must fail.
|
||||||
|
let wrong_header = make_header_bytes(1);
|
||||||
let plaintext = b"secret data";
|
let plaintext = b"secret data";
|
||||||
|
|
||||||
let mut ciphertext = Vec::new();
|
let mut ciphertext = Vec::new();
|
||||||
alice.encrypt(header, plaintext, &mut ciphertext).unwrap();
|
alice
|
||||||
|
.encrypt(&correct_header, plaintext, &mut ciphertext)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let mut decrypted = Vec::new();
|
let mut decrypted = Vec::new();
|
||||||
let result = bob.decrypt(b"wrong-header", &ciphertext, &mut decrypted);
|
let result = bob.decrypt(&wrong_header, &ciphertext, &mut decrypted);
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -194,29 +284,29 @@ mod tests {
|
|||||||
let mut alice = ChaChaSession::new([0xAA; 32]);
|
let mut alice = ChaChaSession::new([0xAA; 32]);
|
||||||
let mut eve = ChaChaSession::new([0xBB; 32]);
|
let mut eve = ChaChaSession::new([0xBB; 32]);
|
||||||
|
|
||||||
let header = b"hdr";
|
let header = make_header_bytes(0);
|
||||||
let plaintext = b"secret";
|
let plaintext = b"secret";
|
||||||
|
|
||||||
let mut ciphertext = Vec::new();
|
let mut ciphertext = Vec::new();
|
||||||
alice.encrypt(header, plaintext, &mut ciphertext).unwrap();
|
alice.encrypt(&header, plaintext, &mut ciphertext).unwrap();
|
||||||
|
|
||||||
let mut decrypted = Vec::new();
|
let mut decrypted = Vec::new();
|
||||||
let result = eve.decrypt(header, &ciphertext, &mut decrypted);
|
let result = eve.decrypt(&header, &ciphertext, &mut decrypted);
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn multiple_packets_roundtrip() {
|
fn multiple_packets_roundtrip() {
|
||||||
let (mut alice, mut bob) = make_session_pair();
|
let (mut alice, mut bob) = make_session_pair();
|
||||||
let header = b"hdr";
|
|
||||||
|
|
||||||
for i in 0..100 {
|
for i in 0..100u32 {
|
||||||
|
let header = make_header_bytes(i);
|
||||||
let msg = format!("message {}", i);
|
let msg = format!("message {}", i);
|
||||||
let mut ct = Vec::new();
|
let mut ct = Vec::new();
|
||||||
alice.encrypt(header, msg.as_bytes(), &mut ct).unwrap();
|
alice.encrypt(&header, msg.as_bytes(), &mut ct).unwrap();
|
||||||
|
|
||||||
let mut pt = Vec::new();
|
let mut pt = Vec::new();
|
||||||
bob.decrypt(header, &ct, &mut pt).unwrap();
|
bob.decrypt(&header, &ct, &mut pt).unwrap();
|
||||||
assert_eq!(pt, msg.as_bytes());
|
assert_eq!(pt, msg.as_bytes());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -235,4 +325,140 @@ mod tests {
|
|||||||
// Session is now rekeyed - counters reset
|
// Session is now rekeyed - counters reset
|
||||||
assert_eq!(alice.send_seq, 0);
|
assert_eq!(alice.send_seq, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn decrypt_survives_out_of_order_delivery() {
|
||||||
|
// Regression test for nonce derivation using recv_seq instead of
|
||||||
|
// MediaHeader.seq. If nonces are tied to a local counter, any reorder
|
||||||
|
// causes the counter to diverge from the sender's seq and every
|
||||||
|
// subsequent packet fails decryption permanently.
|
||||||
|
use wzp_proto::{CodecId, MediaType};
|
||||||
|
|
||||||
|
let key = [0x55u8; 32];
|
||||||
|
let mut alice = ChaChaSession::new(key);
|
||||||
|
let mut bob = ChaChaSession::new(key);
|
||||||
|
|
||||||
|
let plaintext = b"audio payload";
|
||||||
|
|
||||||
|
// Encrypt 5 packets in order (seqs 10, 11, 12, 13, 14).
|
||||||
|
let seqs = [10u32, 11, 12, 13, 14];
|
||||||
|
let mut ciphertexts: Vec<(Vec<u8>, Vec<u8>)> = Vec::new();
|
||||||
|
for &seq in &seqs {
|
||||||
|
let header = MediaHeader {
|
||||||
|
version: 2,
|
||||||
|
flags: 0,
|
||||||
|
media_type: MediaType::Audio,
|
||||||
|
codec_id: CodecId::Opus24k,
|
||||||
|
stream_id: 0,
|
||||||
|
fec_ratio: 0,
|
||||||
|
seq,
|
||||||
|
timestamp: seq * 20,
|
||||||
|
fec_block: 0,
|
||||||
|
};
|
||||||
|
let mut header_bytes = Vec::new();
|
||||||
|
header.write_to(&mut header_bytes);
|
||||||
|
let mut ct = Vec::new();
|
||||||
|
alice.encrypt(&header_bytes, plaintext, &mut ct).unwrap();
|
||||||
|
ciphertexts.push((header_bytes, ct));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bob receives them out of order: 0, 2, 1, 4, 3
|
||||||
|
let delivery_order = [0usize, 2, 1, 4, 3];
|
||||||
|
for &idx in &delivery_order {
|
||||||
|
let (ref hdr, ref ct) = ciphertexts[idx];
|
||||||
|
let mut pt = Vec::new();
|
||||||
|
let result = bob.decrypt(hdr, ct, &mut pt);
|
||||||
|
assert!(
|
||||||
|
result.is_ok(),
|
||||||
|
"out-of-order packet (original idx={idx}, seq={}) must decrypt successfully",
|
||||||
|
seqs[idx]
|
||||||
|
);
|
||||||
|
assert_eq!(&pt, plaintext);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn per_stream_anti_replay_rejects_duplicate() {
|
||||||
|
use wzp_proto::{CodecId, MediaType};
|
||||||
|
|
||||||
|
let (mut alice, mut bob) = make_session_pair();
|
||||||
|
let header = MediaHeader {
|
||||||
|
version: 2,
|
||||||
|
flags: 0,
|
||||||
|
media_type: MediaType::Audio,
|
||||||
|
codec_id: CodecId::Opus24k,
|
||||||
|
stream_id: 0,
|
||||||
|
fec_ratio: 10,
|
||||||
|
seq: 42,
|
||||||
|
timestamp: 1000,
|
||||||
|
fec_block: 0,
|
||||||
|
};
|
||||||
|
let mut header_bytes = Vec::new();
|
||||||
|
header.write_to(&mut header_bytes);
|
||||||
|
|
||||||
|
let plaintext = b"audio frame";
|
||||||
|
|
||||||
|
// First packet decrypts successfully
|
||||||
|
let mut ct = Vec::new();
|
||||||
|
alice.encrypt(&header_bytes, plaintext, &mut ct).unwrap();
|
||||||
|
let mut pt = Vec::new();
|
||||||
|
bob.decrypt(&header_bytes, &ct, &mut pt).unwrap();
|
||||||
|
assert_eq!(&pt, plaintext);
|
||||||
|
|
||||||
|
// Exact duplicate is rejected by anti-replay
|
||||||
|
let mut pt2 = Vec::new();
|
||||||
|
let result = bob.decrypt(&header_bytes, &ct, &mut pt2);
|
||||||
|
assert!(
|
||||||
|
result.is_err(),
|
||||||
|
"duplicate packet with same seq must be rejected"
|
||||||
|
);
|
||||||
|
assert!(pt2.is_empty(), "plaintext must be rolled back on replay");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn per_stream_anti_replay_video_burst_200_with_reorder() {
|
||||||
|
use wzp_proto::{CodecId, MediaType};
|
||||||
|
|
||||||
|
let (mut alice, mut bob) = make_session_pair();
|
||||||
|
let header = MediaHeader {
|
||||||
|
version: 2,
|
||||||
|
flags: 0,
|
||||||
|
media_type: MediaType::Video,
|
||||||
|
codec_id: CodecId::Opus24k,
|
||||||
|
stream_id: 1,
|
||||||
|
fec_ratio: 10,
|
||||||
|
seq: 0,
|
||||||
|
timestamp: 0,
|
||||||
|
fec_block: 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
let plaintext = b"video frame";
|
||||||
|
|
||||||
|
// Send 200 packets in order
|
||||||
|
for i in 0..200 {
|
||||||
|
let mut h = header;
|
||||||
|
h.seq = i;
|
||||||
|
let mut header_bytes = Vec::new();
|
||||||
|
h.write_to(&mut header_bytes);
|
||||||
|
|
||||||
|
let mut ct = Vec::new();
|
||||||
|
alice.encrypt(&header_bytes, plaintext, &mut ct).unwrap();
|
||||||
|
|
||||||
|
let mut pt = Vec::new();
|
||||||
|
bob.decrypt(&header_bytes, &ct, &mut pt).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Re-send packet 50 — should be rejected as replay
|
||||||
|
let mut h = header;
|
||||||
|
h.seq = 50;
|
||||||
|
let mut header_bytes = Vec::new();
|
||||||
|
h.write_to(&mut header_bytes);
|
||||||
|
|
||||||
|
let mut ct = Vec::new();
|
||||||
|
alice.encrypt(&header_bytes, plaintext, &mut ct).unwrap();
|
||||||
|
|
||||||
|
let mut pt = Vec::new();
|
||||||
|
let result = bob.decrypt(&header_bytes, &ct, &mut pt);
|
||||||
|
assert!(result.is_err(), "reordered duplicate must be rejected");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,7 +6,7 @@
|
|||||||
//! 3. Auth: WZP auth module request/response matches FC's /v1/auth/validate contract
|
//! 3. Auth: WZP auth module request/response matches FC's /v1/auth/validate contract
|
||||||
//! 4. Mnemonic: BIP39 interop between both implementations
|
//! 4. Mnemonic: BIP39 interop between both implementations
|
||||||
|
|
||||||
use wzp_proto::KeyExchange;
|
use wzp_proto::{KeyExchange, default_signal_version};
|
||||||
|
|
||||||
// ─── Identity Compatibility (WZP-FC-8) ──────────────────────────────────────
|
// ─── Identity Compatibility (WZP-FC-8) ──────────────────────────────────────
|
||||||
|
|
||||||
@@ -52,7 +52,10 @@ fn wzp_identity_module_matches_featherchat() {
|
|||||||
assert_eq!(wzp_pub.signing.as_bytes(), fc_pub.signing.as_bytes());
|
assert_eq!(wzp_pub.signing.as_bytes(), fc_pub.signing.as_bytes());
|
||||||
assert_eq!(wzp_pub.encryption.as_bytes(), fc_pub.encryption.as_bytes());
|
assert_eq!(wzp_pub.encryption.as_bytes(), fc_pub.encryption.as_bytes());
|
||||||
assert_eq!(wzp_pub.fingerprint.0, fc_pub.fingerprint.0);
|
assert_eq!(wzp_pub.fingerprint.0, fc_pub.fingerprint.0);
|
||||||
assert_eq!(wzp_pub.fingerprint.to_string(), fc_pub.fingerprint.to_string());
|
assert_eq!(
|
||||||
|
wzp_pub.fingerprint.to_string(),
|
||||||
|
fc_pub.fingerprint.to_string()
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -111,11 +114,15 @@ fn mnemonic_strings_identical() {
|
|||||||
fn wzp_signal_serializes_into_fc_callsignal_payload() {
|
fn wzp_signal_serializes_into_fc_callsignal_payload() {
|
||||||
// WZP creates a CallOffer SignalMessage
|
// WZP creates a CallOffer SignalMessage
|
||||||
let offer = wzp_proto::SignalMessage::CallOffer {
|
let offer = wzp_proto::SignalMessage::CallOffer {
|
||||||
|
version: default_signal_version(),
|
||||||
identity_pub: [1u8; 32],
|
identity_pub: [1u8; 32],
|
||||||
ephemeral_pub: [2u8; 32],
|
ephemeral_pub: [2u8; 32],
|
||||||
signature: vec![3u8; 64],
|
signature: vec![3u8; 64],
|
||||||
supported_profiles: vec![wzp_proto::QualityProfile::GOOD],
|
supported_profiles: vec![wzp_proto::QualityProfile::GOOD],
|
||||||
alias: None,
|
alias: None,
|
||||||
|
protocol_version: 2,
|
||||||
|
supported_versions: vec![2],
|
||||||
|
video_codecs: vec![],
|
||||||
};
|
};
|
||||||
|
|
||||||
// Encode as featherChat CallSignal payload
|
// Encode as featherChat CallSignal payload
|
||||||
@@ -148,16 +155,25 @@ fn wzp_signal_serializes_into_fc_callsignal_payload() {
|
|||||||
// And deserializes back
|
// And deserializes back
|
||||||
let decoded: warzone_protocol::message::WireMessage = bincode::deserialize(&encoded).unwrap();
|
let decoded: warzone_protocol::message::WireMessage = bincode::deserialize(&encoded).unwrap();
|
||||||
if let warzone_protocol::message::WireMessage::CallSignal {
|
if let warzone_protocol::message::WireMessage::CallSignal {
|
||||||
id, payload: p, signal_type, ..
|
id,
|
||||||
|
payload: p,
|
||||||
|
signal_type,
|
||||||
|
..
|
||||||
} = decoded
|
} = decoded
|
||||||
{
|
{
|
||||||
assert_eq!(id, "call-123");
|
assert_eq!(id, "call-123");
|
||||||
assert!(matches!(signal_type, warzone_protocol::message::CallSignalType::Offer));
|
assert!(matches!(
|
||||||
|
signal_type,
|
||||||
|
warzone_protocol::message::CallSignalType::Offer
|
||||||
|
));
|
||||||
|
|
||||||
// Decode the WZP payload back
|
// Decode the WZP payload back
|
||||||
let wzp_payload = wzp_client::featherchat::decode_call_payload(&p).unwrap();
|
let wzp_payload = wzp_client::featherchat::decode_call_payload(&p).unwrap();
|
||||||
assert_eq!(wzp_payload.relay_addr.unwrap(), "relay.example.com:4433");
|
assert_eq!(wzp_payload.relay_addr.unwrap(), "relay.example.com:4433");
|
||||||
assert!(matches!(wzp_payload.signal, wzp_proto::SignalMessage::CallOffer { .. }));
|
assert!(matches!(
|
||||||
|
wzp_payload.signal,
|
||||||
|
wzp_proto::SignalMessage::CallOffer { .. }
|
||||||
|
));
|
||||||
} else {
|
} else {
|
||||||
panic!("expected CallSignal");
|
panic!("expected CallSignal");
|
||||||
}
|
}
|
||||||
@@ -166,10 +182,12 @@ fn wzp_signal_serializes_into_fc_callsignal_payload() {
|
|||||||
#[test]
|
#[test]
|
||||||
fn wzp_answer_round_trips_through_fc_callsignal() {
|
fn wzp_answer_round_trips_through_fc_callsignal() {
|
||||||
let answer = wzp_proto::SignalMessage::CallAnswer {
|
let answer = wzp_proto::SignalMessage::CallAnswer {
|
||||||
|
version: default_signal_version(),
|
||||||
identity_pub: [10u8; 32],
|
identity_pub: [10u8; 32],
|
||||||
ephemeral_pub: [20u8; 32],
|
ephemeral_pub: [20u8; 32],
|
||||||
signature: vec![30u8; 64],
|
signature: vec![30u8; 64],
|
||||||
chosen_profile: wzp_proto::QualityProfile::DEGRADED,
|
chosen_profile: wzp_proto::QualityProfile::DEGRADED,
|
||||||
|
video_codec: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let payload = wzp_client::featherchat::encode_call_payload(&answer, None, None);
|
let payload = wzp_client::featherchat::encode_call_payload(&answer, None, None);
|
||||||
@@ -198,13 +216,17 @@ fn wzp_answer_round_trips_through_fc_callsignal() {
|
|||||||
#[test]
|
#[test]
|
||||||
fn wzp_hangup_round_trips_through_fc_callsignal() {
|
fn wzp_hangup_round_trips_through_fc_callsignal() {
|
||||||
let hangup = wzp_proto::SignalMessage::Hangup {
|
let hangup = wzp_proto::SignalMessage::Hangup {
|
||||||
|
version: default_signal_version(),
|
||||||
reason: wzp_proto::HangupReason::Normal,
|
reason: wzp_proto::HangupReason::Normal,
|
||||||
call_id: None,
|
call_id: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let payload = wzp_client::featherchat::encode_call_payload(&hangup, None, None);
|
let payload = wzp_client::featherchat::encode_call_payload(&hangup, None, None);
|
||||||
let signal_type = wzp_client::featherchat::signal_to_call_type(&hangup);
|
let signal_type = wzp_client::featherchat::signal_to_call_type(&hangup);
|
||||||
assert!(matches!(signal_type, wzp_client::featherchat::CallSignalType::Hangup));
|
assert!(matches!(
|
||||||
|
signal_type,
|
||||||
|
wzp_client::featherchat::CallSignalType::Hangup
|
||||||
|
));
|
||||||
|
|
||||||
let fc_msg = warzone_protocol::message::WireMessage::CallSignal {
|
let fc_msg = warzone_protocol::message::WireMessage::CallSignal {
|
||||||
id: "call-789".to_string(),
|
id: "call-789".to_string(),
|
||||||
@@ -219,7 +241,10 @@ fn wzp_hangup_round_trips_through_fc_callsignal() {
|
|||||||
|
|
||||||
if let warzone_protocol::message::WireMessage::CallSignal { payload, .. } = decoded {
|
if let warzone_protocol::message::WireMessage::CallSignal { payload, .. } = decoded {
|
||||||
let wzp = wzp_client::featherchat::decode_call_payload(&payload).unwrap();
|
let wzp = wzp_client::featherchat::decode_call_payload(&payload).unwrap();
|
||||||
assert!(matches!(wzp.signal, wzp_proto::SignalMessage::Hangup { .. }));
|
assert!(matches!(
|
||||||
|
wzp.signal,
|
||||||
|
wzp_proto::SignalMessage::Hangup { .. }
|
||||||
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -252,8 +277,7 @@ fn auth_validate_response_matches_wzp_expectations() {
|
|||||||
"eth_address": null
|
"eth_address": null
|
||||||
});
|
});
|
||||||
|
|
||||||
let wzp_resp: wzp_relay::auth::ValidateResponse =
|
let wzp_resp: wzp_relay::auth::ValidateResponse = serde_json::from_value(fc_response).unwrap();
|
||||||
serde_json::from_value(fc_response).unwrap();
|
|
||||||
assert!(wzp_resp.valid);
|
assert!(wzp_resp.valid);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
wzp_resp.fingerprint.unwrap(),
|
wzp_resp.fingerprint.unwrap(),
|
||||||
@@ -265,8 +289,7 @@ fn auth_validate_response_matches_wzp_expectations() {
|
|||||||
#[test]
|
#[test]
|
||||||
fn auth_invalid_response_matches() {
|
fn auth_invalid_response_matches() {
|
||||||
let fc_response = serde_json::json!({ "valid": false });
|
let fc_response = serde_json::json!({ "valid": false });
|
||||||
let wzp_resp: wzp_relay::auth::ValidateResponse =
|
let wzp_resp: wzp_relay::auth::ValidateResponse = serde_json::from_value(fc_response).unwrap();
|
||||||
serde_json::from_value(fc_response).unwrap();
|
|
||||||
assert!(!wzp_resp.valid);
|
assert!(!wzp_resp.valid);
|
||||||
assert!(wzp_resp.fingerprint.is_none());
|
assert!(wzp_resp.fingerprint.is_none());
|
||||||
}
|
}
|
||||||
@@ -280,28 +303,39 @@ fn all_signal_types_map_correctly() {
|
|||||||
let cases: Vec<(wzp_proto::SignalMessage, &str)> = vec![
|
let cases: Vec<(wzp_proto::SignalMessage, &str)> = vec![
|
||||||
(
|
(
|
||||||
wzp_proto::SignalMessage::CallOffer {
|
wzp_proto::SignalMessage::CallOffer {
|
||||||
identity_pub: [0; 32], ephemeral_pub: [0; 32],
|
version: default_signal_version(),
|
||||||
signature: vec![], supported_profiles: vec![],
|
identity_pub: [0; 32],
|
||||||
|
ephemeral_pub: [0; 32],
|
||||||
|
signature: vec![],
|
||||||
|
supported_profiles: vec![],
|
||||||
alias: None,
|
alias: None,
|
||||||
|
protocol_version: 2,
|
||||||
|
supported_versions: vec![2],
|
||||||
|
video_codecs: vec![],
|
||||||
},
|
},
|
||||||
"Offer",
|
"Offer",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
wzp_proto::SignalMessage::CallAnswer {
|
wzp_proto::SignalMessage::CallAnswer {
|
||||||
identity_pub: [0; 32], ephemeral_pub: [0; 32],
|
version: default_signal_version(),
|
||||||
|
identity_pub: [0; 32],
|
||||||
|
ephemeral_pub: [0; 32],
|
||||||
signature: vec![],
|
signature: vec![],
|
||||||
chosen_profile: wzp_proto::QualityProfile::GOOD,
|
chosen_profile: wzp_proto::QualityProfile::GOOD,
|
||||||
|
video_codec: None,
|
||||||
},
|
},
|
||||||
"Answer",
|
"Answer",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
wzp_proto::SignalMessage::IceCandidate {
|
wzp_proto::SignalMessage::IceCandidate {
|
||||||
|
version: default_signal_version(),
|
||||||
candidate: "candidate:1".to_string(),
|
candidate: "candidate:1".to_string(),
|
||||||
},
|
},
|
||||||
"IceCandidate",
|
"IceCandidate",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
wzp_proto::SignalMessage::Hangup {
|
wzp_proto::SignalMessage::Hangup {
|
||||||
|
version: default_signal_version(),
|
||||||
reason: wzp_proto::HangupReason::Normal,
|
reason: wzp_proto::HangupReason::Normal,
|
||||||
call_id: None,
|
call_id: None,
|
||||||
},
|
},
|
||||||
@@ -312,7 +346,10 @@ fn all_signal_types_map_correctly() {
|
|||||||
for (signal, expected_name) in cases {
|
for (signal, expected_name) in cases {
|
||||||
let ct = signal_to_call_type(&signal);
|
let ct = signal_to_call_type(&signal);
|
||||||
let name = format!("{ct:?}");
|
let name = format!("{ct:?}");
|
||||||
assert_eq!(name, expected_name, "signal type mapping for {expected_name}");
|
assert_eq!(
|
||||||
|
name, expected_name,
|
||||||
|
"signal type mapping for {expected_name}"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -426,8 +463,7 @@ fn auth_response_with_eth_address() {
|
|||||||
"alias": "vitalik",
|
"alias": "vitalik",
|
||||||
"eth_address": "0x1234567890abcdef1234567890abcdef12345678"
|
"eth_address": "0x1234567890abcdef1234567890abcdef12345678"
|
||||||
});
|
});
|
||||||
let resp: wzp_relay::auth::ValidateResponse =
|
let resp: wzp_relay::auth::ValidateResponse = serde_json::from_value(with_eth).unwrap();
|
||||||
serde_json::from_value(with_eth).unwrap();
|
|
||||||
assert!(resp.valid);
|
assert!(resp.valid);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
resp.fingerprint.unwrap(),
|
resp.fingerprint.unwrap(),
|
||||||
@@ -442,8 +478,7 @@ fn auth_response_with_eth_address() {
|
|||||||
"alias": "anon",
|
"alias": "anon",
|
||||||
"eth_address": null
|
"eth_address": null
|
||||||
});
|
});
|
||||||
let resp2: wzp_relay::auth::ValidateResponse =
|
let resp2: wzp_relay::auth::ValidateResponse = serde_json::from_value(with_null_eth).unwrap();
|
||||||
serde_json::from_value(with_null_eth).unwrap();
|
|
||||||
assert!(resp2.valid);
|
assert!(resp2.valid);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
resp2.fingerprint.unwrap(),
|
resp2.fingerprint.unwrap(),
|
||||||
@@ -454,15 +489,15 @@ fn auth_response_with_eth_address() {
|
|||||||
let without_eth = serde_json::json!({
|
let without_eth = serde_json::json!({
|
||||||
"valid": false
|
"valid": false
|
||||||
});
|
});
|
||||||
let resp3: wzp_relay::auth::ValidateResponse =
|
let resp3: wzp_relay::auth::ValidateResponse = serde_json::from_value(without_eth).unwrap();
|
||||||
serde_json::from_value(without_eth).unwrap();
|
|
||||||
assert!(!resp3.valid);
|
assert!(!resp3.valid);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// WZP-S-7: SignalMessage::AuthToken { token } exists and round-trips via serde.
|
/// WZP-S-7: SignalMessage::AuthToken { version: default_signal_version(), token } exists and round-trips via serde.
|
||||||
#[test]
|
#[test]
|
||||||
fn wzp_proto_has_auth_token_variant() {
|
fn wzp_proto_has_auth_token_variant() {
|
||||||
let msg = wzp_proto::SignalMessage::AuthToken {
|
let msg = wzp_proto::SignalMessage::AuthToken {
|
||||||
|
version: default_signal_version(),
|
||||||
token: "fc-bearer-token-xyz".to_string(),
|
token: "fc-bearer-token-xyz".to_string(),
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -473,7 +508,7 @@ fn wzp_proto_has_auth_token_variant() {
|
|||||||
|
|
||||||
// Deserialize back
|
// Deserialize back
|
||||||
let decoded: wzp_proto::SignalMessage = serde_json::from_str(&json).unwrap();
|
let decoded: wzp_proto::SignalMessage = serde_json::from_str(&json).unwrap();
|
||||||
if let wzp_proto::SignalMessage::AuthToken { token } = decoded {
|
if let wzp_proto::SignalMessage::AuthToken { token, .. } = decoded {
|
||||||
assert_eq!(token, "fc-bearer-token-xyz");
|
assert_eq!(token, "fc-bearer-token-xyz");
|
||||||
} else {
|
} else {
|
||||||
panic!("expected AuthToken variant, got: {decoded:?}");
|
panic!("expected AuthToken variant, got: {decoded:?}");
|
||||||
@@ -496,7 +531,11 @@ fn all_fc_call_signal_types_representable() {
|
|||||||
(CallSignalType::Busy, "Busy"),
|
(CallSignalType::Busy, "Busy"),
|
||||||
];
|
];
|
||||||
|
|
||||||
assert_eq!(variants.len(), 7, "featherChat defines exactly 7 call signal types");
|
assert_eq!(
|
||||||
|
variants.len(),
|
||||||
|
7,
|
||||||
|
"featherChat defines exactly 7 call signal types"
|
||||||
|
);
|
||||||
|
|
||||||
for (variant, expected_name) in &variants {
|
for (variant, expected_name) in &variants {
|
||||||
let name = format!("{variant:?}");
|
let name = format!("{variant:?}");
|
||||||
@@ -550,10 +589,7 @@ fn hash_room_name_used_as_sni_is_valid() {
|
|||||||
#[test]
|
#[test]
|
||||||
fn wzp_proto_cargo_toml_is_standalone() {
|
fn wzp_proto_cargo_toml_is_standalone() {
|
||||||
// Try both paths (run from workspace root or from crate directory)
|
// Try both paths (run from workspace root or from crate directory)
|
||||||
let candidates = [
|
let candidates = ["crates/wzp-proto/Cargo.toml", "../wzp-proto/Cargo.toml"];
|
||||||
"crates/wzp-proto/Cargo.toml",
|
|
||||||
"../wzp-proto/Cargo.toml",
|
|
||||||
];
|
|
||||||
|
|
||||||
let contents = candidates
|
let contents = candidates
|
||||||
.iter()
|
.iter()
|
||||||
|
|||||||
@@ -13,11 +13,17 @@ pub struct AdaptiveFec {
|
|||||||
pub repair_ratio: f32,
|
pub repair_ratio: f32,
|
||||||
/// Symbol size in bytes.
|
/// Symbol size in bytes.
|
||||||
pub symbol_size: u16,
|
pub symbol_size: u16,
|
||||||
|
/// Repair ratio to use when the block contains a keyframe.
|
||||||
|
/// Default 0.5 (50% overhead) — keyframes are critical and worth
|
||||||
|
/// the extra bandwidth.
|
||||||
|
pub keyframe_repair_ratio: f32,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AdaptiveFec {
|
impl AdaptiveFec {
|
||||||
/// Default symbol size for adaptive configuration.
|
/// Default symbol size for adaptive configuration.
|
||||||
const DEFAULT_SYMBOL_SIZE: u16 = 256;
|
const DEFAULT_SYMBOL_SIZE: u16 = 256;
|
||||||
|
/// Default keyframe repair ratio (PRD-video-v1 T4.5).
|
||||||
|
const DEFAULT_KEYFRAME_REPAIR_RATIO: f32 = 0.5;
|
||||||
|
|
||||||
/// Create an adaptive FEC configuration from a quality profile.
|
/// Create an adaptive FEC configuration from a quality profile.
|
||||||
///
|
///
|
||||||
@@ -30,12 +36,15 @@ impl AdaptiveFec {
|
|||||||
frames_per_block: profile.frames_per_block as usize,
|
frames_per_block: profile.frames_per_block as usize,
|
||||||
repair_ratio: profile.fec_ratio,
|
repair_ratio: profile.fec_ratio,
|
||||||
symbol_size: Self::DEFAULT_SYMBOL_SIZE,
|
symbol_size: Self::DEFAULT_SYMBOL_SIZE,
|
||||||
|
keyframe_repair_ratio: Self::DEFAULT_KEYFRAME_REPAIR_RATIO,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Build a configured FEC encoder from this adaptive configuration.
|
/// Build a configured FEC encoder from this adaptive configuration.
|
||||||
pub fn build_encoder(&self) -> RaptorQFecEncoder {
|
pub fn build_encoder(&self) -> RaptorQFecEncoder {
|
||||||
RaptorQFecEncoder::new(self.frames_per_block, self.symbol_size)
|
let mut enc = RaptorQFecEncoder::new(self.frames_per_block, self.symbol_size);
|
||||||
|
enc.set_keyframe_ratio(self.keyframe_repair_ratio);
|
||||||
|
enc
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the repair ratio for use with `FecEncoder::generate_repair()`.
|
/// Get the repair ratio for use with `FecEncoder::generate_repair()`.
|
||||||
@@ -59,6 +68,7 @@ mod tests {
|
|||||||
let cfg = AdaptiveFec::from_profile(&QualityProfile::GOOD);
|
let cfg = AdaptiveFec::from_profile(&QualityProfile::GOOD);
|
||||||
assert_eq!(cfg.frames_per_block, 5);
|
assert_eq!(cfg.frames_per_block, 5);
|
||||||
assert!((cfg.repair_ratio - 0.2).abs() < f32::EPSILON);
|
assert!((cfg.repair_ratio - 0.2).abs() < f32::EPSILON);
|
||||||
|
assert!((cfg.keyframe_repair_ratio - 0.5).abs() < f32::EPSILON);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
@@ -29,9 +29,9 @@ pub enum DecoderBlockState {
|
|||||||
/// Manages encoder-side block tracking.
|
/// Manages encoder-side block tracking.
|
||||||
pub struct EncoderBlockManager {
|
pub struct EncoderBlockManager {
|
||||||
/// Current block ID being built.
|
/// Current block ID being built.
|
||||||
current_id: u8,
|
current_id: u16,
|
||||||
/// State of known blocks.
|
/// State of known blocks.
|
||||||
blocks: HashMap<u8, EncoderBlockState>,
|
blocks: HashMap<u16, EncoderBlockState>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl EncoderBlockManager {
|
impl EncoderBlockManager {
|
||||||
@@ -45,7 +45,7 @@ impl EncoderBlockManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Get the next block ID (advances the current building block).
|
/// Get the next block ID (advances the current building block).
|
||||||
pub fn next_block_id(&mut self) -> u8 {
|
pub fn next_block_id(&mut self) -> u16 {
|
||||||
let old = self.current_id;
|
let old = self.current_id;
|
||||||
// Mark old block as pending.
|
// Mark old block as pending.
|
||||||
self.blocks.insert(old, EncoderBlockState::Pending);
|
self.blocks.insert(old, EncoderBlockState::Pending);
|
||||||
@@ -57,23 +57,23 @@ impl EncoderBlockManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Current block ID being built.
|
/// Current block ID being built.
|
||||||
pub fn current_id(&self) -> u8 {
|
pub fn current_id(&self) -> u16 {
|
||||||
self.current_id
|
self.current_id
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Mark a block as fully sent.
|
/// Mark a block as fully sent.
|
||||||
pub fn mark_sent(&mut self, block_id: u8) {
|
pub fn mark_sent(&mut self, block_id: u16) {
|
||||||
self.blocks.insert(block_id, EncoderBlockState::Sent);
|
self.blocks.insert(block_id, EncoderBlockState::Sent);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Mark a block as acknowledged by the peer.
|
/// Mark a block as acknowledged by the peer.
|
||||||
pub fn mark_acknowledged(&mut self, block_id: u8) {
|
pub fn mark_acknowledged(&mut self, block_id: u16) {
|
||||||
self.blocks
|
self.blocks
|
||||||
.insert(block_id, EncoderBlockState::Acknowledged);
|
.insert(block_id, EncoderBlockState::Acknowledged);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the state of a block.
|
/// Get the state of a block.
|
||||||
pub fn state(&self, block_id: u8) -> Option<EncoderBlockState> {
|
pub fn state(&self, block_id: u16) -> Option<EncoderBlockState> {
|
||||||
self.blocks.get(&block_id).copied()
|
self.blocks.get(&block_id).copied()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -93,9 +93,9 @@ impl Default for EncoderBlockManager {
|
|||||||
/// Manages decoder-side block tracking.
|
/// Manages decoder-side block tracking.
|
||||||
pub struct DecoderBlockManager {
|
pub struct DecoderBlockManager {
|
||||||
/// State of known blocks.
|
/// State of known blocks.
|
||||||
blocks: HashMap<u8, DecoderBlockState>,
|
blocks: HashMap<u16, DecoderBlockState>,
|
||||||
/// Set of completed block IDs.
|
/// Set of completed block IDs.
|
||||||
completed: HashSet<u8>,
|
completed: HashSet<u16>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DecoderBlockManager {
|
impl DecoderBlockManager {
|
||||||
@@ -107,43 +107,43 @@ impl DecoderBlockManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Register that we are receiving symbols for a block.
|
/// Register that we are receiving symbols for a block.
|
||||||
pub fn touch(&mut self, block_id: u8) {
|
pub fn touch(&mut self, block_id: u16) {
|
||||||
self.blocks
|
self.blocks
|
||||||
.entry(block_id)
|
.entry(block_id)
|
||||||
.or_insert(DecoderBlockState::Assembling);
|
.or_insert(DecoderBlockState::Assembling);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Mark a block as successfully decoded.
|
/// Mark a block as successfully decoded.
|
||||||
pub fn mark_complete(&mut self, block_id: u8) {
|
pub fn mark_complete(&mut self, block_id: u16) {
|
||||||
self.blocks.insert(block_id, DecoderBlockState::Complete);
|
self.blocks.insert(block_id, DecoderBlockState::Complete);
|
||||||
self.completed.insert(block_id);
|
self.completed.insert(block_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Mark a block as expired.
|
/// Mark a block as expired.
|
||||||
pub fn mark_expired(&mut self, block_id: u8) {
|
pub fn mark_expired(&mut self, block_id: u16) {
|
||||||
self.blocks.insert(block_id, DecoderBlockState::Expired);
|
self.blocks.insert(block_id, DecoderBlockState::Expired);
|
||||||
self.completed.remove(&block_id);
|
self.completed.remove(&block_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Check if a block has been fully decoded.
|
/// Check if a block has been fully decoded.
|
||||||
pub fn is_block_complete(&self, block_id: u8) -> bool {
|
pub fn is_block_complete(&self, block_id: u16) -> bool {
|
||||||
self.completed.contains(&block_id)
|
self.completed.contains(&block_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the state of a block.
|
/// Get the state of a block.
|
||||||
pub fn state(&self, block_id: u8) -> Option<DecoderBlockState> {
|
pub fn state(&self, block_id: u16) -> Option<DecoderBlockState> {
|
||||||
self.blocks.get(&block_id).copied()
|
self.blocks.get(&block_id).copied()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Expire all blocks older than the given block_id (using wrapping distance).
|
/// Expire all blocks older than the given block_id (using wrapping distance).
|
||||||
pub fn expire_before(&mut self, block_id: u8) {
|
pub fn expire_before(&mut self, block_id: u16) {
|
||||||
let to_expire: Vec<u8> = self
|
let to_expire: Vec<u16> = self
|
||||||
.blocks
|
.blocks
|
||||||
.keys()
|
.keys()
|
||||||
.copied()
|
.copied()
|
||||||
.filter(|&id| {
|
.filter(|&id| {
|
||||||
let distance = block_id.wrapping_sub(id);
|
let distance = block_id.wrapping_sub(id);
|
||||||
distance > 0 && distance <= 128
|
distance > 0 && distance <= 32768
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
@@ -207,7 +207,7 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn decoder_expire_before() {
|
fn decoder_expire_before() {
|
||||||
let mut mgr = DecoderBlockManager::new();
|
let mut mgr = DecoderBlockManager::new();
|
||||||
for i in 0..5u8 {
|
for i in 0..5u16 {
|
||||||
mgr.touch(i);
|
mgr.touch(i);
|
||||||
}
|
}
|
||||||
mgr.mark_complete(1);
|
mgr.mark_complete(1);
|
||||||
@@ -231,11 +231,11 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn next_block_id_wraps() {
|
fn next_block_id_wraps() {
|
||||||
let mut mgr = EncoderBlockManager::new();
|
let mut mgr = EncoderBlockManager::new();
|
||||||
// Start at 0, advance to 255 then wrap
|
// Start at 0, advance to u16::MAX then wrap
|
||||||
for _ in 0..255 {
|
for _ in 0..65535 {
|
||||||
mgr.next_block_id();
|
mgr.next_block_id();
|
||||||
}
|
}
|
||||||
assert_eq!(mgr.current_id(), 255);
|
assert_eq!(mgr.current_id(), u16::MAX);
|
||||||
let next = mgr.next_block_id();
|
let next = mgr.next_block_id();
|
||||||
assert_eq!(next, 0);
|
assert_eq!(next, 0);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,8 +4,8 @@ use std::collections::HashMap;
|
|||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
|
|
||||||
use raptorq::{EncodingPacket, ObjectTransmissionInformation, PayloadId, SourceBlockDecoder};
|
use raptorq::{EncodingPacket, ObjectTransmissionInformation, PayloadId, SourceBlockDecoder};
|
||||||
use wzp_proto::error::FecError;
|
|
||||||
use wzp_proto::FecDecoder;
|
use wzp_proto::FecDecoder;
|
||||||
|
use wzp_proto::error::FecError;
|
||||||
|
|
||||||
/// Length prefix size (u16 little-endian), must match encoder.
|
/// Length prefix size (u16 little-endian), must match encoder.
|
||||||
const LEN_PREFIX: usize = 2;
|
const LEN_PREFIX: usize = 2;
|
||||||
@@ -32,7 +32,7 @@ struct BlockState {
|
|||||||
/// RaptorQ-based FEC decoder that handles multiple concurrent blocks.
|
/// RaptorQ-based FEC decoder that handles multiple concurrent blocks.
|
||||||
pub struct RaptorQFecDecoder {
|
pub struct RaptorQFecDecoder {
|
||||||
/// Per-block decoder state, keyed by block_id.
|
/// Per-block decoder state, keyed by block_id.
|
||||||
blocks: HashMap<u8, BlockState>,
|
blocks: HashMap<u16, BlockState>,
|
||||||
/// Symbol size (must match encoder).
|
/// Symbol size (must match encoder).
|
||||||
symbol_size: u16,
|
symbol_size: u16,
|
||||||
/// Number of source symbols per block (from encoder config).
|
/// Number of source symbols per block (from encoder config).
|
||||||
@@ -57,7 +57,7 @@ impl RaptorQFecDecoder {
|
|||||||
Self::new(frames_per_block, 256)
|
Self::new(frames_per_block, 256)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_or_create_block(&mut self, block_id: u8) -> &mut BlockState {
|
fn get_or_create_block(&mut self, block_id: u16) -> &mut BlockState {
|
||||||
self.blocks.entry(block_id).or_insert_with(|| BlockState {
|
self.blocks.entry(block_id).or_insert_with(|| BlockState {
|
||||||
num_source_symbols: Some(self.frames_per_block),
|
num_source_symbols: Some(self.frames_per_block),
|
||||||
packets: Vec::new(),
|
packets: Vec::new(),
|
||||||
@@ -72,8 +72,8 @@ impl RaptorQFecDecoder {
|
|||||||
impl FecDecoder for RaptorQFecDecoder {
|
impl FecDecoder for RaptorQFecDecoder {
|
||||||
fn add_symbol(
|
fn add_symbol(
|
||||||
&mut self,
|
&mut self,
|
||||||
block_id: u8,
|
block_id: u16,
|
||||||
symbol_index: u8,
|
symbol_index: u16,
|
||||||
_is_repair: bool,
|
_is_repair: bool,
|
||||||
data: &[u8],
|
data: &[u8],
|
||||||
) -> Result<(), FecError> {
|
) -> Result<(), FecError> {
|
||||||
@@ -104,13 +104,13 @@ impl FecDecoder for RaptorQFecDecoder {
|
|||||||
padded[..len].copy_from_slice(&data[..len]);
|
padded[..len].copy_from_slice(&data[..len]);
|
||||||
|
|
||||||
let esi = symbol_index as u32;
|
let esi = symbol_index as u32;
|
||||||
let packet = EncodingPacket::new(PayloadId::new(block_id, esi), padded);
|
let packet = EncodingPacket::new(PayloadId::new((block_id & 0xFF) as u8, esi), padded);
|
||||||
block.packets.push(packet);
|
block.packets.push(packet);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn try_decode(&mut self, block_id: u8) -> Result<Option<Vec<Vec<u8>>>, FecError> {
|
fn try_decode(&mut self, block_id: u16) -> Result<Option<Vec<Vec<u8>>>, FecError> {
|
||||||
let frames_per_block = self.frames_per_block;
|
let frames_per_block = self.frames_per_block;
|
||||||
let block = match self.blocks.get_mut(&block_id) {
|
let block = match self.blocks.get_mut(&block_id) {
|
||||||
Some(b) => b,
|
Some(b) => b,
|
||||||
@@ -125,7 +125,7 @@ impl FecDecoder for RaptorQFecDecoder {
|
|||||||
let block_length = (num_source as u64) * (block.symbol_size as u64);
|
let block_length = (num_source as u64) * (block.symbol_size as u64);
|
||||||
|
|
||||||
let config = ObjectTransmissionInformation::with_defaults(block_length, block.symbol_size);
|
let config = ObjectTransmissionInformation::with_defaults(block_length, block.symbol_size);
|
||||||
let mut decoder = SourceBlockDecoder::new(block_id, &config, block_length);
|
let mut decoder = SourceBlockDecoder::new((block_id & 0xFF) as u8, &config, block_length);
|
||||||
|
|
||||||
let decoded = decoder.decode(block.packets.clone());
|
let decoded = decoder.decode(block.packets.clone());
|
||||||
|
|
||||||
@@ -140,10 +140,7 @@ impl FecDecoder for RaptorQFecDecoder {
|
|||||||
frames.push(Vec::new());
|
frames.push(Vec::new());
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
let payload_len = u16::from_le_bytes([
|
let payload_len = u16::from_le_bytes([data[offset], data[offset + 1]]) as usize;
|
||||||
data[offset],
|
|
||||||
data[offset + 1],
|
|
||||||
]) as usize;
|
|
||||||
let payload_start = offset + LEN_PREFIX;
|
let payload_start = offset + LEN_PREFIX;
|
||||||
let payload_end = (payload_start + payload_len).min(data.len());
|
let payload_end = (payload_start + payload_len).min(data.len());
|
||||||
frames.push(data[payload_start..payload_end].to_vec());
|
frames.push(data[payload_start..payload_end].to_vec());
|
||||||
@@ -159,15 +156,15 @@ impl FecDecoder for RaptorQFecDecoder {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn expire_before(&mut self, block_id: u8) {
|
fn expire_before(&mut self, block_id: u16) {
|
||||||
// Remove blocks with IDs "older" than block_id.
|
// Remove blocks with IDs "older" than block_id.
|
||||||
// With wrapping u8 IDs, we consider a block old if its distance
|
// With wrapping u16 IDs, we consider a block old if its distance
|
||||||
// (in the forward direction) to block_id is > 128.
|
// (in the forward direction) to block_id is > 32768.
|
||||||
self.blocks.retain(|&id, _| {
|
self.blocks.retain(|&id, _| {
|
||||||
let distance = block_id.wrapping_sub(id);
|
let distance = block_id.wrapping_sub(id);
|
||||||
// If distance is 0 or > 128, the block is current or "ahead" — keep it.
|
// If distance is 0 or > 32768, the block is current or "ahead" — keep it.
|
||||||
// If distance is 1..=128, the block is behind — remove it.
|
// If distance is 1..=32768, the block is behind — remove it.
|
||||||
distance == 0 || distance > 128
|
distance == 0 || distance > 32768
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -198,9 +195,7 @@ mod tests {
|
|||||||
|
|
||||||
// Feed all source symbols (using the length-prefixed padded data).
|
// Feed all source symbols (using the length-prefixed padded data).
|
||||||
for (i, pkt) in source_pkts.iter().enumerate() {
|
for (i, pkt) in source_pkts.iter().enumerate() {
|
||||||
decoder
|
decoder.add_symbol(0, i as u16, false, pkt.data()).unwrap();
|
||||||
.add_symbol(0, i as u8, false, pkt.data())
|
|
||||||
.unwrap();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let result = decoder.try_decode(0).unwrap();
|
let result = decoder.try_decode(0).unwrap();
|
||||||
@@ -233,7 +228,11 @@ mod tests {
|
|||||||
let config = ObjectTransmissionInformation::new(block_len, SYMBOL_SIZE, 1, 1, 1);
|
let config = ObjectTransmissionInformation::new(block_len, SYMBOL_SIZE, 1, 1, 1);
|
||||||
let mut dec = SourceBlockDecoder::new(0, &config, block_len);
|
let mut dec = SourceBlockDecoder::new(0, &config, block_len);
|
||||||
let decoded = dec.decode(all);
|
let decoded = dec.decode(all);
|
||||||
assert!(decoded.is_some(), "Should recover with {:.0}% loss", drop_fraction * 100.0);
|
assert!(
|
||||||
|
decoded.is_some(),
|
||||||
|
"Should recover with {:.0}% loss",
|
||||||
|
drop_fraction * 100.0
|
||||||
|
);
|
||||||
|
|
||||||
let data = decoded.unwrap();
|
let data = decoded.unwrap();
|
||||||
let ss = SYMBOL_SIZE as usize;
|
let ss = SYMBOL_SIZE as usize;
|
||||||
@@ -245,22 +244,28 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn decode_with_30pct_loss() { run_loss_test(FRAMES_PER_BLOCK, 0.5, 0.3); }
|
fn decode_with_30pct_loss() {
|
||||||
|
run_loss_test(FRAMES_PER_BLOCK, 0.5, 0.3);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn decode_with_50pct_loss() { run_loss_test(FRAMES_PER_BLOCK, 1.0, 0.5); }
|
fn decode_with_50pct_loss() {
|
||||||
|
run_loss_test(FRAMES_PER_BLOCK, 1.0, 0.5);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn decode_with_70pct_source_loss_heavy_repair() { run_loss_test(8, 2.0, 0.5); }
|
fn decode_with_70pct_source_loss_heavy_repair() {
|
||||||
|
run_loss_test(8, 2.0, 0.5);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn expire_removes_old_blocks() {
|
fn expire_removes_old_blocks() {
|
||||||
let mut decoder = RaptorQFecDecoder::new(FRAMES_PER_BLOCK, SYMBOL_SIZE);
|
let mut decoder = RaptorQFecDecoder::new(FRAMES_PER_BLOCK, SYMBOL_SIZE);
|
||||||
|
|
||||||
// Add symbols to blocks 0, 1, 2
|
// Add symbols to blocks 0, 1, 2
|
||||||
for block_id in 0..3u8 {
|
for block_id in 0..3u16 {
|
||||||
decoder
|
decoder
|
||||||
.add_symbol(block_id, 0, false, &[block_id; 50])
|
.add_symbol(block_id, 0, false, &[block_id as u8; 50])
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -288,10 +293,10 @@ mod tests {
|
|||||||
// Interleave symbols from block 0 and block 1
|
// Interleave symbols from block 0 and block 1
|
||||||
for i in 0..FRAMES_PER_BLOCK {
|
for i in 0..FRAMES_PER_BLOCK {
|
||||||
decoder
|
decoder
|
||||||
.add_symbol(0, i as u8, false, pkts_a[i].data())
|
.add_symbol(0, i as u16, false, pkts_a[i].data())
|
||||||
.unwrap();
|
.unwrap();
|
||||||
decoder
|
decoder
|
||||||
.add_symbol(1, i as u8, false, pkts_b[i].data())
|
.add_symbol(1, i as u16, false, pkts_b[i].data())
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
//! RaptorQ FEC encoder — accumulates source symbols into blocks and generates repair symbols.
|
//! RaptorQ FEC encoder — accumulates source symbols into blocks and generates repair symbols.
|
||||||
|
|
||||||
use raptorq::{EncodingPacket, ObjectTransmissionInformation, PayloadId, SourceBlockEncoder};
|
use raptorq::{EncodingPacket, ObjectTransmissionInformation, PayloadId, SourceBlockEncoder};
|
||||||
use wzp_proto::error::FecError;
|
|
||||||
use wzp_proto::FecEncoder;
|
use wzp_proto::FecEncoder;
|
||||||
|
use wzp_proto::error::FecError;
|
||||||
|
|
||||||
/// Maximum symbol size in bytes. Audio frames are typically < 200 bytes,
|
/// Maximum symbol size in bytes. Audio frames are typically < 200 bytes,
|
||||||
/// but we pad to a uniform size within a block.
|
/// but we pad to a uniform size within a block.
|
||||||
@@ -15,14 +15,19 @@ const LEN_PREFIX: usize = 2;
|
|||||||
/// RaptorQ-based FEC encoder that groups audio frames into blocks
|
/// RaptorQ-based FEC encoder that groups audio frames into blocks
|
||||||
/// and generates fountain-code repair symbols.
|
/// and generates fountain-code repair symbols.
|
||||||
pub struct RaptorQFecEncoder {
|
pub struct RaptorQFecEncoder {
|
||||||
/// Current block ID (wraps at u8).
|
/// Current block ID (wraps at u16).
|
||||||
block_id: u8,
|
block_id: u16,
|
||||||
/// Maximum source symbols per block.
|
/// Maximum source symbols per block.
|
||||||
frames_per_block: usize,
|
frames_per_block: usize,
|
||||||
/// Accumulated source symbols for the current block.
|
/// Accumulated source symbols for the current block.
|
||||||
source_symbols: Vec<Vec<u8>>,
|
source_symbols: Vec<Vec<u8>>,
|
||||||
/// Symbol size used for encoding (all symbols padded to this size).
|
/// Symbol size used for encoding (all symbols padded to this size).
|
||||||
symbol_size: u16,
|
symbol_size: u16,
|
||||||
|
/// True if at least one source symbol in the current block is a keyframe.
|
||||||
|
has_keyframe: bool,
|
||||||
|
/// Repair ratio to use when the block contains a keyframe.
|
||||||
|
/// If zero, the nominal ratio passed to [`generate_repair`] is used.
|
||||||
|
keyframe_ratio: f32,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RaptorQFecEncoder {
|
impl RaptorQFecEncoder {
|
||||||
@@ -36,9 +41,26 @@ impl RaptorQFecEncoder {
|
|||||||
frames_per_block,
|
frames_per_block,
|
||||||
source_symbols: Vec::with_capacity(frames_per_block),
|
source_symbols: Vec::with_capacity(frames_per_block),
|
||||||
symbol_size,
|
symbol_size,
|
||||||
|
has_keyframe: false,
|
||||||
|
keyframe_ratio: 0.0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Set the repair ratio to use for blocks that contain at least one
|
||||||
|
/// keyframe source symbol.
|
||||||
|
///
|
||||||
|
/// When `keyframe_ratio > 0.0` and [`has_keyframe`](Self::has_keyframe)
|
||||||
|
/// is true, [`generate_repair`](FecEncoder::generate_repair) uses this
|
||||||
|
/// ratio instead of the nominal ratio passed by the caller.
|
||||||
|
pub fn set_keyframe_ratio(&mut self, ratio: f32) {
|
||||||
|
self.keyframe_ratio = ratio.max(0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns true if the current block contains a keyframe source symbol.
|
||||||
|
pub fn has_keyframe(&self) -> bool {
|
||||||
|
self.has_keyframe
|
||||||
|
}
|
||||||
|
|
||||||
/// Create with default symbol size (256 bytes).
|
/// Create with default symbol size (256 bytes).
|
||||||
pub fn with_defaults(frames_per_block: usize) -> Self {
|
pub fn with_defaults(frames_per_block: usize) -> Self {
|
||||||
Self::new(frames_per_block, DEFAULT_MAX_SYMBOL_SIZE)
|
Self::new(frames_per_block, DEFAULT_MAX_SYMBOL_SIZE)
|
||||||
@@ -54,8 +76,7 @@ impl RaptorQFecEncoder {
|
|||||||
let payload_len = sym.len().min(max_payload);
|
let payload_len = sym.len().min(max_payload);
|
||||||
let offset = i * ss;
|
let offset = i * ss;
|
||||||
// Write 2-byte little-endian length prefix.
|
// Write 2-byte little-endian length prefix.
|
||||||
data[offset..offset + LEN_PREFIX]
|
data[offset..offset + LEN_PREFIX].copy_from_slice(&(payload_len as u16).to_le_bytes());
|
||||||
.copy_from_slice(&(payload_len as u16).to_le_bytes());
|
|
||||||
// Write payload after prefix.
|
// Write payload after prefix.
|
||||||
data[offset + LEN_PREFIX..offset + LEN_PREFIX + payload_len]
|
data[offset + LEN_PREFIX..offset + LEN_PREFIX + payload_len]
|
||||||
.copy_from_slice(&sym[..payload_len]);
|
.copy_from_slice(&sym[..payload_len]);
|
||||||
@@ -75,17 +96,36 @@ impl FecEncoder for RaptorQFecEncoder {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn generate_repair(&mut self, ratio: f32) -> Result<Vec<(u8, Vec<u8>)>, FecError> {
|
fn add_source_symbol_with_keyframe(
|
||||||
|
&mut self,
|
||||||
|
data: &[u8],
|
||||||
|
is_keyframe: bool,
|
||||||
|
) -> Result<(), FecError> {
|
||||||
|
self.add_source_symbol(data)?;
|
||||||
|
if is_keyframe {
|
||||||
|
self.has_keyframe = true;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn generate_repair(&mut self, ratio: f32) -> Result<Vec<(u16, Vec<u8>)>, FecError> {
|
||||||
if self.source_symbols.is_empty() {
|
if self.source_symbols.is_empty() {
|
||||||
return Ok(vec![]);
|
return Ok(vec![]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let effective_ratio = if self.has_keyframe && self.keyframe_ratio > 0.0 {
|
||||||
|
self.keyframe_ratio
|
||||||
|
} else {
|
||||||
|
ratio
|
||||||
|
};
|
||||||
|
|
||||||
let block_data = self.build_block_data();
|
let block_data = self.build_block_data();
|
||||||
let config = ObjectTransmissionInformation::with_defaults(block_data.len() as u64, self.symbol_size);
|
let config =
|
||||||
let encoder = SourceBlockEncoder::new(self.block_id, &config, &block_data);
|
ObjectTransmissionInformation::with_defaults(block_data.len() as u64, self.symbol_size);
|
||||||
|
let encoder = SourceBlockEncoder::new((self.block_id & 0xFF) as u8, &config, &block_data);
|
||||||
|
|
||||||
let num_source = self.source_symbols.len() as u32;
|
let num_source = self.source_symbols.len() as u32;
|
||||||
let num_repair = ((num_source as f32) * ratio).ceil() as u32;
|
let num_repair = ((num_source as f32) * effective_ratio).ceil() as u32;
|
||||||
if num_repair == 0 {
|
if num_repair == 0 {
|
||||||
return Ok(vec![]);
|
return Ok(vec![]);
|
||||||
}
|
}
|
||||||
@@ -93,11 +133,11 @@ impl FecEncoder for RaptorQFecEncoder {
|
|||||||
// Generate repair packets starting from offset 0 (ESIs begin at num_source).
|
// Generate repair packets starting from offset 0 (ESIs begin at num_source).
|
||||||
let repair_packets: Vec<EncodingPacket> = encoder.repair_packets(0, num_repair);
|
let repair_packets: Vec<EncodingPacket> = encoder.repair_packets(0, num_repair);
|
||||||
|
|
||||||
let result: Vec<(u8, Vec<u8>)> = repair_packets
|
let result: Vec<(u16, Vec<u8>)> = repair_packets
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.map(|(i, pkt): (usize, EncodingPacket)| {
|
.map(|(i, pkt): (usize, EncodingPacket)| {
|
||||||
let idx = (num_source as u8).wrapping_add(i as u8);
|
let idx = (num_source as u16).wrapping_add(i as u16);
|
||||||
(idx, pkt.data().to_vec())
|
(idx, pkt.data().to_vec())
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
@@ -105,14 +145,15 @@ impl FecEncoder for RaptorQFecEncoder {
|
|||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn finalize_block(&mut self) -> Result<u8, FecError> {
|
fn finalize_block(&mut self) -> Result<u16, FecError> {
|
||||||
let completed = self.block_id;
|
let completed = self.block_id;
|
||||||
self.block_id = self.block_id.wrapping_add(1);
|
self.block_id = self.block_id.wrapping_add(1);
|
||||||
self.source_symbols.clear();
|
self.source_symbols.clear();
|
||||||
|
self.has_keyframe = false;
|
||||||
Ok(completed)
|
Ok(completed)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn current_block_id(&self) -> u8 {
|
fn current_block_id(&self) -> u16 {
|
||||||
self.block_id
|
self.block_id
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -130,8 +171,7 @@ fn build_prefixed_block_data(symbols: &[Vec<u8>], symbol_size: u16) -> Vec<u8> {
|
|||||||
let max_payload = ss - LEN_PREFIX;
|
let max_payload = ss - LEN_PREFIX;
|
||||||
let payload_len = sym.len().min(max_payload);
|
let payload_len = sym.len().min(max_payload);
|
||||||
let offset = i * ss;
|
let offset = i * ss;
|
||||||
data[offset..offset + LEN_PREFIX]
|
data[offset..offset + LEN_PREFIX].copy_from_slice(&(payload_len as u16).to_le_bytes());
|
||||||
.copy_from_slice(&(payload_len as u16).to_le_bytes());
|
|
||||||
data[offset + LEN_PREFIX..offset + LEN_PREFIX + payload_len]
|
data[offset + LEN_PREFIX..offset + LEN_PREFIX + payload_len]
|
||||||
.copy_from_slice(&sym[..payload_len]);
|
.copy_from_slice(&sym[..payload_len]);
|
||||||
}
|
}
|
||||||
@@ -141,7 +181,7 @@ fn build_prefixed_block_data(symbols: &[Vec<u8>], symbol_size: u16) -> Vec<u8> {
|
|||||||
/// Helper: build source `EncodingPacket`s for a given block. Useful for
|
/// Helper: build source `EncodingPacket`s for a given block. Useful for
|
||||||
/// the decoder tests and interleaving.
|
/// the decoder tests and interleaving.
|
||||||
pub fn source_packets_for_block(
|
pub fn source_packets_for_block(
|
||||||
block_id: u8,
|
block_id: u16,
|
||||||
symbols: &[Vec<u8>],
|
symbols: &[Vec<u8>],
|
||||||
symbol_size: u16,
|
symbol_size: u16,
|
||||||
) -> Vec<EncodingPacket> {
|
) -> Vec<EncodingPacket> {
|
||||||
@@ -151,21 +191,21 @@ pub fn source_packets_for_block(
|
|||||||
.map(|i| {
|
.map(|i| {
|
||||||
let offset = i * ss;
|
let offset = i * ss;
|
||||||
let sym_data = data[offset..offset + ss].to_vec();
|
let sym_data = data[offset..offset + ss].to_vec();
|
||||||
EncodingPacket::new(PayloadId::new(block_id, i as u32), sym_data)
|
EncodingPacket::new(PayloadId::new((block_id & 0xFF) as u8, i as u32), sym_data)
|
||||||
})
|
})
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Helper: generate repair packets for the given source symbols.
|
/// Helper: generate repair packets for the given source symbols.
|
||||||
pub fn repair_packets_for_block(
|
pub fn repair_packets_for_block(
|
||||||
block_id: u8,
|
block_id: u16,
|
||||||
symbols: &[Vec<u8>],
|
symbols: &[Vec<u8>],
|
||||||
symbol_size: u16,
|
symbol_size: u16,
|
||||||
ratio: f32,
|
ratio: f32,
|
||||||
) -> Vec<EncodingPacket> {
|
) -> Vec<EncodingPacket> {
|
||||||
let data = build_prefixed_block_data(symbols, symbol_size);
|
let data = build_prefixed_block_data(symbols, symbol_size);
|
||||||
let config = ObjectTransmissionInformation::with_defaults(data.len() as u64, symbol_size);
|
let config = ObjectTransmissionInformation::with_defaults(data.len() as u64, symbol_size);
|
||||||
let encoder = SourceBlockEncoder::new(block_id, &config, &data);
|
let encoder = SourceBlockEncoder::new((block_id & 0xFF) as u8, &config, &data);
|
||||||
let num_source = symbols.len() as u32;
|
let num_source = symbols.len() as u32;
|
||||||
let num_repair = ((num_source as f32) * ratio).ceil() as u32;
|
let num_repair = ((num_source as f32) * ratio).ceil() as u32;
|
||||||
encoder.repair_packets(0, num_repair)
|
encoder.repair_packets(0, num_repair)
|
||||||
@@ -201,14 +241,70 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn block_id_wraps() {
|
fn block_id_wraps_u16() {
|
||||||
let mut enc = RaptorQFecEncoder::with_defaults(1);
|
let mut enc = RaptorQFecEncoder::with_defaults(1);
|
||||||
for expected in 0..=255u8 {
|
// Advance 300 blocks and verify no panic + monotonic increment.
|
||||||
|
for expected in 0..300u16 {
|
||||||
assert_eq!(enc.current_block_id(), expected);
|
assert_eq!(enc.current_block_id(), expected);
|
||||||
enc.add_source_symbol(&[expected; 10]).unwrap();
|
enc.add_source_symbol(&[0u8; 10]).unwrap();
|
||||||
enc.finalize_block().unwrap();
|
enc.finalize_block().unwrap();
|
||||||
}
|
}
|
||||||
// After 256 blocks, wraps back to 0
|
// Explicitly test wrap at u16 boundary.
|
||||||
assert_eq!(enc.current_block_id(), 0);
|
let mut enc2 = RaptorQFecEncoder::with_defaults(1);
|
||||||
|
enc2.block_id = u16::MAX;
|
||||||
|
enc2.add_source_symbol(&[0u8; 10]).unwrap();
|
||||||
|
let id = enc2.finalize_block().unwrap();
|
||||||
|
assert_eq!(id, u16::MAX);
|
||||||
|
assert_eq!(enc2.current_block_id(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn keyframe_boost_uses_higher_ratio() {
|
||||||
|
// Non-keyframe block with nominal ratio 0.2 → ceil(5 * 0.2) = 1 repair.
|
||||||
|
let mut enc_normal = RaptorQFecEncoder::with_defaults(5);
|
||||||
|
enc_normal.set_keyframe_ratio(0.8);
|
||||||
|
for i in 0..5 {
|
||||||
|
enc_normal
|
||||||
|
.add_source_symbol_with_keyframe(&[i as u8; 100], false)
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
let normal_repair = enc_normal.generate_repair(0.2).unwrap();
|
||||||
|
assert_eq!(normal_repair.len(), 1);
|
||||||
|
|
||||||
|
// Keyframe block with same nominal ratio but boost to 0.8 → ceil(5 * 0.8) = 4 repairs.
|
||||||
|
let mut enc_key = RaptorQFecEncoder::with_defaults(5);
|
||||||
|
enc_key.set_keyframe_ratio(0.8);
|
||||||
|
for i in 0..5 {
|
||||||
|
enc_key
|
||||||
|
.add_source_symbol_with_keyframe(&[i as u8; 100], i == 2)
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
let keyframe_repair = enc_key.generate_repair(0.2).unwrap();
|
||||||
|
assert_eq!(keyframe_repair.len(), 4);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn non_keyframe_block_uses_nominal_ratio() {
|
||||||
|
let mut enc = RaptorQFecEncoder::with_defaults(5);
|
||||||
|
enc.set_keyframe_ratio(0.8);
|
||||||
|
|
||||||
|
for i in 0..5 {
|
||||||
|
enc.add_source_symbol_with_keyframe(&[i as u8; 100], false)
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
let repair = enc.generate_repair(0.2).unwrap();
|
||||||
|
assert_eq!(repair.len(), 1); // ceil(5 * 0.2) = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn finalize_clears_keyframe_flag() {
|
||||||
|
let mut enc = RaptorQFecEncoder::with_defaults(2);
|
||||||
|
enc.add_source_symbol_with_keyframe(&[0u8; 10], true)
|
||||||
|
.unwrap();
|
||||||
|
assert!(enc.has_keyframe());
|
||||||
|
|
||||||
|
enc.finalize_block().unwrap();
|
||||||
|
assert!(!enc.has_keyframe());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
//! rather than one block fatally.
|
//! rather than one block fatally.
|
||||||
|
|
||||||
/// A symbol ready for transmission: (block_id, symbol_index, is_repair, data).
|
/// A symbol ready for transmission: (block_id, symbol_index, is_repair, data).
|
||||||
pub type Symbol = (u8, u8, bool, Vec<u8>);
|
pub type Symbol = (u16, u16, bool, Vec<u8>);
|
||||||
|
|
||||||
/// Temporal interleaver that mixes symbols across multiple FEC blocks.
|
/// Temporal interleaver that mixes symbols across multiple FEC blocks.
|
||||||
pub struct Interleaver {
|
pub struct Interleaver {
|
||||||
@@ -64,13 +64,13 @@ mod tests {
|
|||||||
let interleaver = Interleaver::with_default_depth();
|
let interleaver = Interleaver::with_default_depth();
|
||||||
|
|
||||||
let block_a: Vec<Symbol> = (0..3)
|
let block_a: Vec<Symbol> = (0..3)
|
||||||
.map(|i| (0u8, i as u8, false, vec![0xA0 + i as u8]))
|
.map(|i| (0u16, i as u16, false, vec![0xA0 + i as u8]))
|
||||||
.collect();
|
.collect();
|
||||||
let block_b: Vec<Symbol> = (0..3)
|
let block_b: Vec<Symbol> = (0..3)
|
||||||
.map(|i| (1u8, i as u8, false, vec![0xB0 + i as u8]))
|
.map(|i| (1u16, i as u16, false, vec![0xB0 + i as u8]))
|
||||||
.collect();
|
.collect();
|
||||||
let block_c: Vec<Symbol> = (0..3)
|
let block_c: Vec<Symbol> = (0..3)
|
||||||
.map(|i| (2u8, i as u8, false, vec![0xC0 + i as u8]))
|
.map(|i| (2u16, i as u16, false, vec![0xC0 + i as u8]))
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let result = interleaver.interleave(&[block_a, block_b, block_c]);
|
let result = interleaver.interleave(&[block_a, block_b, block_c]);
|
||||||
@@ -96,10 +96,10 @@ mod tests {
|
|||||||
let interleaver = Interleaver::new(2);
|
let interleaver = Interleaver::new(2);
|
||||||
|
|
||||||
let block_a: Vec<Symbol> = (0..3)
|
let block_a: Vec<Symbol> = (0..3)
|
||||||
.map(|i| (0u8, i as u8, false, vec![0xA0 + i as u8]))
|
.map(|i| (0u16, i as u16, false, vec![0xA0 + i as u8]))
|
||||||
.collect();
|
.collect();
|
||||||
let block_b: Vec<Symbol> = (0..1)
|
let block_b: Vec<Symbol> = (0..1)
|
||||||
.map(|i| (1u8, i as u8, false, vec![0xB0 + i as u8]))
|
.map(|i| (1u16, i as u16, false, vec![0xB0 + i as u8]))
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let result = interleaver.interleave(&[block_a, block_b]);
|
let result = interleaver.interleave(&[block_a, block_b]);
|
||||||
@@ -128,7 +128,7 @@ mod tests {
|
|||||||
let blocks: Vec<Vec<Symbol>> = (0..3)
|
let blocks: Vec<Vec<Symbol>> = (0..3)
|
||||||
.map(|b| {
|
.map(|b| {
|
||||||
(0..6)
|
(0..6)
|
||||||
.map(|i| (b as u8, i as u8, false, vec![b as u8; 10]))
|
.map(|i| (b as u16, i as u16, false, vec![b as u8; 10]))
|
||||||
.collect()
|
.collect()
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
@@ -146,7 +146,10 @@ mod tests {
|
|||||||
|
|
||||||
// Each block should lose exactly 2 (6 losses / 3 blocks)
|
// Each block should lose exactly 2 (6 losses / 3 blocks)
|
||||||
for &loss in &losses_per_block {
|
for &loss in &losses_per_block {
|
||||||
assert_eq!(loss, 2, "Each block should lose at most 2 symbols from a burst of 6");
|
assert_eq!(
|
||||||
|
loss, 2,
|
||||||
|
"Each block should lose at most 2 symbols from a burst of 6"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,7 +16,9 @@ pub mod encoder;
|
|||||||
pub mod interleave;
|
pub mod interleave;
|
||||||
|
|
||||||
pub use adaptive::AdaptiveFec;
|
pub use adaptive::AdaptiveFec;
|
||||||
pub use block_manager::{DecoderBlockManager, DecoderBlockState, EncoderBlockManager, EncoderBlockState};
|
pub use block_manager::{
|
||||||
|
DecoderBlockManager, DecoderBlockState, EncoderBlockManager, EncoderBlockState,
|
||||||
|
};
|
||||||
pub use decoder::RaptorQFecDecoder;
|
pub use decoder::RaptorQFecDecoder;
|
||||||
pub use encoder::RaptorQFecEncoder;
|
pub use encoder::RaptorQFecEncoder;
|
||||||
pub use interleave::Interleaver;
|
pub use interleave::Interleaver;
|
||||||
@@ -24,9 +26,7 @@ pub use interleave::Interleaver;
|
|||||||
pub use wzp_proto::{FecDecoder, FecEncoder, QualityProfile};
|
pub use wzp_proto::{FecDecoder, FecEncoder, QualityProfile};
|
||||||
|
|
||||||
/// Create an encoder/decoder pair configured for the given quality profile.
|
/// Create an encoder/decoder pair configured for the given quality profile.
|
||||||
pub fn create_fec_pair(
|
pub fn create_fec_pair(profile: &QualityProfile) -> (RaptorQFecEncoder, RaptorQFecDecoder) {
|
||||||
profile: &QualityProfile,
|
|
||||||
) -> (RaptorQFecEncoder, RaptorQFecDecoder) {
|
|
||||||
let cfg = AdaptiveFec::from_profile(profile);
|
let cfg = AdaptiveFec::from_profile(profile);
|
||||||
let encoder = cfg.build_encoder();
|
let encoder = cfg.build_encoder();
|
||||||
let decoder = RaptorQFecDecoder::new(cfg.frames_per_block, cfg.symbol_size);
|
let decoder = RaptorQFecDecoder::new(cfg.frames_per_block, cfg.symbol_size);
|
||||||
|
|||||||
@@ -24,7 +24,10 @@ fn main() {
|
|||||||
let oboe_dir = fetch_oboe();
|
let oboe_dir = fetch_oboe();
|
||||||
match oboe_dir {
|
match oboe_dir {
|
||||||
Some(oboe_path) => {
|
Some(oboe_path) => {
|
||||||
println!("cargo:warning=wzp-native: building with Oboe from {:?}", oboe_path);
|
println!(
|
||||||
|
"cargo:warning=wzp-native: building with Oboe from {:?}",
|
||||||
|
oboe_path
|
||||||
|
);
|
||||||
let mut build = cc::Build::new();
|
let mut build = cc::Build::new();
|
||||||
build
|
build
|
||||||
.cpp(true)
|
.cpp(true)
|
||||||
@@ -96,7 +99,12 @@ fn fetch_oboe() -> Option<PathBuf> {
|
|||||||
let out_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap());
|
let out_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap());
|
||||||
let oboe_dir = out_dir.join("oboe");
|
let oboe_dir = out_dir.join("oboe");
|
||||||
|
|
||||||
if oboe_dir.join("include").join("oboe").join("Oboe.h").exists() {
|
if oboe_dir
|
||||||
|
.join("include")
|
||||||
|
.join("oboe")
|
||||||
|
.join("Oboe.h")
|
||||||
|
.exists()
|
||||||
|
{
|
||||||
return Some(oboe_dir);
|
return Some(oboe_dir);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -111,7 +119,14 @@ fn fetch_oboe() -> Option<PathBuf> {
|
|||||||
.status();
|
.status();
|
||||||
|
|
||||||
match status {
|
match status {
|
||||||
Ok(s) if s.success() && oboe_dir.join("include").join("oboe").join("Oboe.h").exists() => {
|
Ok(s)
|
||||||
|
if s.success()
|
||||||
|
&& oboe_dir
|
||||||
|
.join("include")
|
||||||
|
.join("oboe")
|
||||||
|
.join("Oboe.h")
|
||||||
|
.exists() =>
|
||||||
|
{
|
||||||
Some(oboe_dir)
|
Some(oboe_dir)
|
||||||
}
|
}
|
||||||
_ => None,
|
_ => None,
|
||||||
|
|||||||
@@ -8,6 +8,8 @@
|
|||||||
#include <android/log.h>
|
#include <android/log.h>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <atomic>
|
#include <atomic>
|
||||||
|
#include <chrono>
|
||||||
|
#include <thread>
|
||||||
|
|
||||||
#define LOG_TAG "wzp-oboe"
|
#define LOG_TAG "wzp-oboe"
|
||||||
#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, __VA_ARGS__)
|
#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, __VA_ARGS__)
|
||||||
@@ -388,6 +390,52 @@ int wzp_oboe_start(const WzpOboeConfig* config, const WzpOboeRings* rings) {
|
|||||||
return -5;
|
return -5;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Log initial stream states right after requestStart() returns.
|
||||||
|
// On well-behaved HALs both will already be Started; on others
|
||||||
|
// (Nothing A059) they may still be in Starting state.
|
||||||
|
LOGI("requestStart returned: capture_state=%d playout_state=%d",
|
||||||
|
(int)g_capture_stream->getState(),
|
||||||
|
(int)g_playout_stream->getState());
|
||||||
|
|
||||||
|
// Poll until both streams report Started state, up to 2s timeout.
|
||||||
|
// Some Android HALs (Nothing A059) delay transitioning from Starting
|
||||||
|
// to Started; proceeding before the transition completes causes the
|
||||||
|
// first capture/playout callbacks to be dropped silently.
|
||||||
|
{
|
||||||
|
auto deadline = std::chrono::steady_clock::now() + std::chrono::milliseconds(2000);
|
||||||
|
int poll_count = 0;
|
||||||
|
bool streams_started = false;
|
||||||
|
while (std::chrono::steady_clock::now() < deadline) {
|
||||||
|
auto cap_state = g_capture_stream->getState();
|
||||||
|
auto play_state = g_playout_stream->getState();
|
||||||
|
if (cap_state == oboe::StreamState::Started &&
|
||||||
|
play_state == oboe::StreamState::Started) {
|
||||||
|
LOGI("both streams Started after %d polls", poll_count);
|
||||||
|
streams_started = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
poll_count++;
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(10));
|
||||||
|
}
|
||||||
|
// Log final state even on timeout (helps diagnose HAL quirks)
|
||||||
|
LOGI("stream states after poll: capture=%d playout=%d (polls=%d)",
|
||||||
|
(int)g_capture_stream->getState(),
|
||||||
|
(int)g_playout_stream->getState(),
|
||||||
|
poll_count);
|
||||||
|
if (!streams_started) {
|
||||||
|
LOGE("Timed out waiting for Oboe streams to reach Started state");
|
||||||
|
g_running.store(false, std::memory_order_release);
|
||||||
|
g_rings_valid.store(false, std::memory_order_release);
|
||||||
|
g_capture_stream->requestStop();
|
||||||
|
g_playout_stream->requestStop();
|
||||||
|
g_capture_stream->close();
|
||||||
|
g_playout_stream->close();
|
||||||
|
g_capture_stream.reset();
|
||||||
|
g_playout_stream.reset();
|
||||||
|
return -6;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
LOGI("Oboe started: sr=%d burst=%d ch=%d",
|
LOGI("Oboe started: sr=%d burst=%d ch=%d",
|
||||||
config->sample_rate, config->frames_per_burst, config->channel_count);
|
config->sample_rate, config->frames_per_burst, config->channel_count);
|
||||||
return 0;
|
return 0;
|
||||||
|
|||||||
@@ -26,6 +26,11 @@ pub extern "C" fn wzp_native_version() -> i32 {
|
|||||||
|
|
||||||
/// Writes a NUL-terminated string into `out` (capped at `cap`) and
|
/// Writes a NUL-terminated string into `out` (capped at `cap`) and
|
||||||
/// returns bytes written excluding the NUL.
|
/// returns bytes written excluding the NUL.
|
||||||
|
///
|
||||||
|
/// # Safety
|
||||||
|
/// `out` must be a valid pointer to at least `cap` contiguous bytes of
|
||||||
|
/// writable memory. Passing a null pointer or zero capacity is safe
|
||||||
|
/// (returns 0), but a dangling non-null pointer is undefined behaviour.
|
||||||
#[unsafe(no_mangle)]
|
#[unsafe(no_mangle)]
|
||||||
pub unsafe extern "C" fn wzp_native_hello(out: *mut u8, cap: usize) -> usize {
|
pub unsafe extern "C" fn wzp_native_hello(out: *mut u8, cap: usize) -> usize {
|
||||||
const MSG: &[u8] = b"hello from wzp-native\0";
|
const MSG: &[u8] = b"hello from wzp-native\0";
|
||||||
@@ -111,7 +116,11 @@ impl RingBuffer {
|
|||||||
let w = self.write_idx.load(Ordering::Acquire);
|
let w = self.write_idx.load(Ordering::Acquire);
|
||||||
let r = self.read_idx.load(Ordering::Relaxed);
|
let r = self.read_idx.load(Ordering::Relaxed);
|
||||||
let avail = w - r;
|
let avail = w - r;
|
||||||
if avail < 0 { (avail + self.capacity as i32) as usize } else { avail as usize }
|
if avail < 0 {
|
||||||
|
(avail + self.capacity as i32) as usize
|
||||||
|
} else {
|
||||||
|
avail as usize
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn available_write(&self) -> usize {
|
fn available_write(&self) -> usize {
|
||||||
@@ -127,9 +136,13 @@ impl RingBuffer {
|
|||||||
let cap = self.capacity;
|
let cap = self.capacity;
|
||||||
let buf_ptr = self.buf.as_ptr() as *mut i16;
|
let buf_ptr = self.buf.as_ptr() as *mut i16;
|
||||||
for sample in &data[..count] {
|
for sample in &data[..count] {
|
||||||
unsafe { *buf_ptr.add(w) = *sample; }
|
unsafe {
|
||||||
|
*buf_ptr.add(w) = *sample;
|
||||||
|
}
|
||||||
w += 1;
|
w += 1;
|
||||||
if w >= cap { w = 0; }
|
if w >= cap {
|
||||||
|
w = 0;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
self.write_idx.store(w as i32, Ordering::Release);
|
self.write_idx.store(w as i32, Ordering::Release);
|
||||||
count
|
count
|
||||||
@@ -144,9 +157,13 @@ impl RingBuffer {
|
|||||||
let cap = self.capacity;
|
let cap = self.capacity;
|
||||||
let buf_ptr = self.buf.as_ptr();
|
let buf_ptr = self.buf.as_ptr();
|
||||||
for slot in &mut out[..count] {
|
for slot in &mut out[..count] {
|
||||||
unsafe { *slot = *buf_ptr.add(r); }
|
unsafe {
|
||||||
|
*slot = *buf_ptr.add(r);
|
||||||
|
}
|
||||||
r += 1;
|
r += 1;
|
||||||
if r >= cap { r = 0; }
|
if r >= cap {
|
||||||
|
r = 0;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
self.read_idx.store(r as i32, Ordering::Release);
|
self.read_idx.store(r as i32, Ordering::Release);
|
||||||
count
|
count
|
||||||
@@ -264,9 +281,20 @@ pub extern "C" fn wzp_native_audio_stop() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Number of capture samples available to read without blocking.
|
||||||
|
#[unsafe(no_mangle)]
|
||||||
|
pub extern "C" fn wzp_native_audio_capture_available() -> usize {
|
||||||
|
backend().capture.available_read()
|
||||||
|
}
|
||||||
|
|
||||||
/// Read captured PCM samples from the capture ring. Returns the number
|
/// Read captured PCM samples from the capture ring. Returns the number
|
||||||
/// of `i16` samples actually copied into `out` (may be less than
|
/// of `i16` samples actually copied into `out` (may be less than
|
||||||
/// `out_len` if the ring is empty).
|
/// `out_len` if the ring is empty).
|
||||||
|
///
|
||||||
|
/// # Safety
|
||||||
|
/// `out` must be a valid pointer to `out_len` contiguous `i16` values.
|
||||||
|
/// The caller must ensure no other thread writes to the same buffer
|
||||||
|
/// concurrently. Passing a null pointer or zero length is safe (returns 0).
|
||||||
#[unsafe(no_mangle)]
|
#[unsafe(no_mangle)]
|
||||||
pub unsafe extern "C" fn wzp_native_audio_read_capture(out: *mut i16, out_len: usize) -> usize {
|
pub unsafe extern "C" fn wzp_native_audio_read_capture(out: *mut i16, out_len: usize) -> usize {
|
||||||
if out.is_null() || out_len == 0 {
|
if out.is_null() || out_len == 0 {
|
||||||
@@ -280,6 +308,12 @@ pub unsafe extern "C" fn wzp_native_audio_read_capture(out: *mut i16, out_len: u
|
|||||||
/// samples actually enqueued (may be less than `in_len` if the ring
|
/// samples actually enqueued (may be less than `in_len` if the ring
|
||||||
/// is nearly full — in practice the caller should pace to 20 ms
|
/// is nearly full — in practice the caller should pace to 20 ms
|
||||||
/// frames and spin briefly if the ring is full).
|
/// frames and spin briefly if the ring is full).
|
||||||
|
///
|
||||||
|
/// # Safety
|
||||||
|
/// `input` must be a valid pointer to `in_len` contiguous `i16` values
|
||||||
|
/// that remain valid for the duration of the call. Passing a null pointer
|
||||||
|
/// or zero length is safe (returns 0). The caller must not free or mutate
|
||||||
|
/// the buffer while this function is executing.
|
||||||
#[unsafe(no_mangle)]
|
#[unsafe(no_mangle)]
|
||||||
pub unsafe extern "C" fn wzp_native_audio_write_playout(input: *const i16, in_len: usize) -> usize {
|
pub unsafe extern "C" fn wzp_native_audio_write_playout(input: *const i16, in_len: usize) -> usize {
|
||||||
if input.is_null() || in_len == 0 {
|
if input.is_null() || in_len == 0 {
|
||||||
@@ -294,17 +328,27 @@ pub unsafe extern "C" fn wzp_native_audio_write_playout(input: *const i16, in_le
|
|||||||
// has stopped firing → restart the streams. This is the
|
// has stopped firing → restart the streams. This is the
|
||||||
// self-healing behavior that makes rejoin work: teardown +
|
// self-healing behavior that makes rejoin work: teardown +
|
||||||
// rebuild clears whatever HAL state locked up the callback.
|
// rebuild clears whatever HAL state locked up the callback.
|
||||||
let current_read_idx = b.playout.read_idx.load(std::sync::atomic::Ordering::Relaxed);
|
let current_read_idx = b
|
||||||
let last_read_idx = b.playout_last_read_idx.load(std::sync::atomic::Ordering::Relaxed);
|
.playout
|
||||||
|
.read_idx
|
||||||
|
.load(std::sync::atomic::Ordering::Relaxed);
|
||||||
|
let last_read_idx = b
|
||||||
|
.playout_last_read_idx
|
||||||
|
.load(std::sync::atomic::Ordering::Relaxed);
|
||||||
if current_read_idx == last_read_idx {
|
if current_read_idx == last_read_idx {
|
||||||
let stall = b.playout_stall_writes.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
let stall = b
|
||||||
|
.playout_stall_writes
|
||||||
|
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||||
if stall >= 50 {
|
if stall >= 50 {
|
||||||
// Callback hasn't drained anything in ~1 second.
|
// Callback hasn't drained anything in ~1 second.
|
||||||
// Force a stream restart.
|
// Force a stream restart.
|
||||||
unsafe {
|
unsafe {
|
||||||
android_log("playout STALL detected (50 writes, read_idx unchanged) — restarting Oboe streams");
|
android_log(
|
||||||
|
"playout STALL detected (50 writes, read_idx unchanged) — restarting Oboe streams",
|
||||||
|
);
|
||||||
}
|
}
|
||||||
b.playout_stall_writes.store(0, std::sync::atomic::Ordering::Relaxed);
|
b.playout_stall_writes
|
||||||
|
.store(0, std::sync::atomic::Ordering::Relaxed);
|
||||||
// Release the started lock, stop, re-start.
|
// Release the started lock, stop, re-start.
|
||||||
// This is the same logic as the Rust-side
|
// This is the same logic as the Rust-side
|
||||||
// audio_stop() + audio_start() but done inline
|
// audio_stop() + audio_start() but done inline
|
||||||
@@ -319,10 +363,18 @@ pub unsafe extern "C" fn wzp_native_audio_write_playout(input: *const i16, in_le
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Clear the rings so the restart doesn't read stale data
|
// Clear the rings so the restart doesn't read stale data
|
||||||
b.playout.write_idx.store(0, std::sync::atomic::Ordering::Relaxed);
|
b.playout
|
||||||
b.playout.read_idx.store(0, std::sync::atomic::Ordering::Relaxed);
|
.write_idx
|
||||||
b.capture.write_idx.store(0, std::sync::atomic::Ordering::Relaxed);
|
.store(0, std::sync::atomic::Ordering::Relaxed);
|
||||||
b.capture.read_idx.store(0, std::sync::atomic::Ordering::Relaxed);
|
b.playout
|
||||||
|
.read_idx
|
||||||
|
.store(0, std::sync::atomic::Ordering::Relaxed);
|
||||||
|
b.capture
|
||||||
|
.write_idx
|
||||||
|
.store(0, std::sync::atomic::Ordering::Relaxed);
|
||||||
|
b.capture
|
||||||
|
.read_idx
|
||||||
|
.store(0, std::sync::atomic::Ordering::Relaxed);
|
||||||
// Re-start (stall detector — always non-BT mode)
|
// Re-start (stall detector — always non-BT mode)
|
||||||
let config = WzpOboeConfig {
|
let config = WzpOboeConfig {
|
||||||
sample_rate: 48_000,
|
sample_rate: 48_000,
|
||||||
@@ -345,30 +397,49 @@ pub unsafe extern "C" fn wzp_native_audio_write_playout(input: *const i16, in_le
|
|||||||
if let Ok(mut started) = b.started.lock() {
|
if let Ok(mut started) = b.started.lock() {
|
||||||
*started = true;
|
*started = true;
|
||||||
}
|
}
|
||||||
unsafe { android_log("playout restart OK — Oboe streams rebuilt"); }
|
unsafe {
|
||||||
} else {
|
android_log("playout restart OK — Oboe streams rebuilt");
|
||||||
unsafe { android_log(&format!("playout restart FAILED: {ret}")); }
|
|
||||||
}
|
}
|
||||||
b.playout_last_read_idx.store(0, std::sync::atomic::Ordering::Relaxed);
|
} else {
|
||||||
|
unsafe {
|
||||||
|
android_log(&format!("playout restart FAILED: {ret}"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
b.playout_last_read_idx
|
||||||
|
.store(0, std::sync::atomic::Ordering::Relaxed);
|
||||||
return 0; // caller will retry on next frame
|
return 0; // caller will retry on next frame
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// read_idx advanced — callback is alive, reset counter
|
// read_idx advanced — callback is alive, reset counter
|
||||||
b.playout_stall_writes.store(0, std::sync::atomic::Ordering::Relaxed);
|
b.playout_stall_writes
|
||||||
b.playout_last_read_idx.store(current_read_idx, std::sync::atomic::Ordering::Relaxed);
|
.store(0, std::sync::atomic::Ordering::Relaxed);
|
||||||
|
b.playout_last_read_idx
|
||||||
|
.store(current_read_idx, std::sync::atomic::Ordering::Relaxed);
|
||||||
}
|
}
|
||||||
|
|
||||||
let before_w = b.playout.write_idx.load(std::sync::atomic::Ordering::Relaxed);
|
let before_w = b
|
||||||
let before_r = b.playout.read_idx.load(std::sync::atomic::Ordering::Relaxed);
|
.playout
|
||||||
|
.write_idx
|
||||||
|
.load(std::sync::atomic::Ordering::Relaxed);
|
||||||
|
let before_r = b
|
||||||
|
.playout
|
||||||
|
.read_idx
|
||||||
|
.load(std::sync::atomic::Ordering::Relaxed);
|
||||||
let written = b.playout.write(slice);
|
let written = b.playout.write(slice);
|
||||||
// First few writes: log ring state + sample range so we can compare what
|
// First few writes: log ring state + sample range so we can compare what
|
||||||
// engine.rs hands us to what the C++ playout callback reads.
|
// engine.rs hands us to what the C++ playout callback reads.
|
||||||
let first_writes = b.playout_write_log_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
let first_writes = b
|
||||||
|
.playout_write_log_count
|
||||||
|
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||||
if first_writes < 3 || first_writes % 50 == 0 {
|
if first_writes < 3 || first_writes % 50 == 0 {
|
||||||
let (mut lo, mut hi, mut sumsq) = (i16::MAX, i16::MIN, 0i64);
|
let (mut lo, mut hi, mut sumsq) = (i16::MAX, i16::MIN, 0i64);
|
||||||
for &s in slice.iter() {
|
for &s in slice.iter() {
|
||||||
if s < lo { lo = s; }
|
if s < lo {
|
||||||
if s > hi { hi = s; }
|
lo = s;
|
||||||
|
}
|
||||||
|
if s > hi {
|
||||||
|
hi = s;
|
||||||
|
}
|
||||||
sumsq += (s as i64) * (s as i64);
|
sumsq += (s as i64) * (s as i64);
|
||||||
}
|
}
|
||||||
let rms = (sumsq as f64 / slice.len() as f64).sqrt() as i32;
|
let rms = (sumsq as f64 / slice.len() as f64).sqrt() as i32;
|
||||||
@@ -376,7 +447,8 @@ pub unsafe extern "C" fn wzp_native_audio_write_playout(input: *const i16, in_le
|
|||||||
let avail_r_after = b.playout.available_read();
|
let avail_r_after = b.playout.available_read();
|
||||||
let msg = format!(
|
let msg = format!(
|
||||||
"playout WRITE #{first_writes}: in_len={} written={} range=[{lo}..{hi}] rms={rms} before_w={before_w} before_r={before_r} avail_read_after={avail_r_after} avail_write_after={avail_w_after}",
|
"playout WRITE #{first_writes}: in_len={} written={} range=[{lo}..{hi}] rms={rms} before_w={before_w} before_r={before_r} avail_read_after={avail_r_after} avail_write_after={avail_w_after}",
|
||||||
slice.len(), written
|
slice.len(),
|
||||||
|
written
|
||||||
);
|
);
|
||||||
unsafe {
|
unsafe {
|
||||||
android_log(msg.as_str());
|
android_log(msg.as_str());
|
||||||
@@ -400,7 +472,9 @@ unsafe fn android_log(msg: &str) {
|
|||||||
let mut buf = Vec::with_capacity(msg.len() + 1);
|
let mut buf = Vec::with_capacity(msg.len() + 1);
|
||||||
buf.extend_from_slice(msg.as_bytes());
|
buf.extend_from_slice(msg.as_bytes());
|
||||||
buf.push(0);
|
buf.push(0);
|
||||||
unsafe { __android_log_write(4, tag.as_ptr(), buf.as_ptr()); }
|
unsafe {
|
||||||
|
__android_log_write(4, tag.as_ptr(), buf.as_ptr());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(not(target_os = "android"))]
|
#[cfg(not(target_os = "android"))]
|
||||||
|
|||||||
@@ -20,3 +20,4 @@ tracing = "0.1"
|
|||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
tokio = { version = "1", features = ["full"] }
|
tokio = { version = "1", features = ["full"] }
|
||||||
serde_json = "1"
|
serde_json = "1"
|
||||||
|
bincode = "1"
|
||||||
|
|||||||
@@ -7,10 +7,11 @@
|
|||||||
//! Control (GCC).
|
//! Control (GCC).
|
||||||
|
|
||||||
use std::collections::VecDeque;
|
use std::collections::VecDeque;
|
||||||
use std::time::Instant;
|
use std::sync::atomic::{AtomicU64, Ordering::Relaxed};
|
||||||
|
use std::time::{Instant, SystemTime, UNIX_EPOCH};
|
||||||
|
|
||||||
use crate::packet::QualityReport;
|
|
||||||
use crate::QualityProfile;
|
use crate::QualityProfile;
|
||||||
|
use crate::packet::QualityReport;
|
||||||
|
|
||||||
/// Network congestion state derived from delay and loss signals.
|
/// Network congestion state derived from delay and loss signals.
|
||||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||||
@@ -158,6 +159,16 @@ pub struct BandwidthEstimator {
|
|||||||
loss_detector: LossBasedDetector,
|
loss_detector: LossBasedDetector,
|
||||||
/// Last update timestamp.
|
/// Last update timestamp.
|
||||||
last_update: Option<Instant>,
|
last_update: Option<Instant>,
|
||||||
|
|
||||||
|
// ── Transport-feedback BWE (T2.2) ──
|
||||||
|
/// Congestion-window-derived bandwidth estimate in bits per second.
|
||||||
|
cwnd_bps: AtomicU64,
|
||||||
|
/// Peer REMB (Receiver Estimated Maximum Bitrate) in bits per second.
|
||||||
|
peer_remb_bps: AtomicU64,
|
||||||
|
/// EWMA-smoothed bandwidth estimate in bits per second.
|
||||||
|
smoothed_bps: AtomicU64,
|
||||||
|
/// Last time `smoothed_bps` was updated (UNIX epoch millis).
|
||||||
|
last_smoothed_ms: AtomicU64,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Multiplicative decrease factor applied on congestion (15% reduction).
|
/// Multiplicative decrease factor applied on congestion (15% reduction).
|
||||||
@@ -179,6 +190,10 @@ impl BandwidthEstimator {
|
|||||||
delay_detector: DelayBasedDetector::new(),
|
delay_detector: DelayBasedDetector::new(),
|
||||||
loss_detector: LossBasedDetector::new(),
|
loss_detector: LossBasedDetector::new(),
|
||||||
last_update: None,
|
last_update: None,
|
||||||
|
cwnd_bps: AtomicU64::new(0),
|
||||||
|
peer_remb_bps: AtomicU64::new(u64::MAX),
|
||||||
|
smoothed_bps: AtomicU64::new(0),
|
||||||
|
last_smoothed_ms: AtomicU64::new(0),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -250,6 +265,64 @@ impl BandwidthEstimator {
|
|||||||
QualityProfile::CATASTROPHIC
|
QualityProfile::CATASTROPHIC
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── Transport-feedback BWE (T2.2) ──
|
||||||
|
|
||||||
|
/// Update from QUIC path stats.
|
||||||
|
///
|
||||||
|
/// Computes `cwnd_bps = cwnd_bytes * 8 / rtt_s` and feeds it into the
|
||||||
|
/// smoothed estimate.
|
||||||
|
pub fn update_from_path(&self, cwnd_bytes: u64, _bytes_in_flight: u64, rtt_ms: u32) {
|
||||||
|
let rtt_s = rtt_ms.max(1) as f64 / 1000.0;
|
||||||
|
let cwnd_bps = ((cwnd_bytes * 8) as f64 / rtt_s) as u64;
|
||||||
|
self.cwnd_bps.store(cwnd_bps, Relaxed);
|
||||||
|
self.update_smoothed(cwnd_bps);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Update from a peer's `TransportFeedback` REMB value.
|
||||||
|
pub fn update_from_peer(&self, fb_remb_bps: u32) {
|
||||||
|
let remb = fb_remb_bps as u64;
|
||||||
|
self.peer_remb_bps.store(remb, Relaxed);
|
||||||
|
self.update_smoothed(remb);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Target sending bitrate in bits per second.
|
||||||
|
///
|
||||||
|
/// Returns 90% of the minimum between the congestion-window estimate
|
||||||
|
/// and the peer REMB estimate.
|
||||||
|
pub fn target_send_bps(&self) -> u64 {
|
||||||
|
let cwnd = self.cwnd_bps.load(Relaxed);
|
||||||
|
let remb = self.peer_remb_bps.load(Relaxed);
|
||||||
|
let m = cwnd.min(remb);
|
||||||
|
(m as f64 * 0.9) as u64
|
||||||
|
}
|
||||||
|
|
||||||
|
/// EWMA-smoothed bandwidth estimate in bits per second.
|
||||||
|
pub fn smoothed_bps(&self) -> u64 {
|
||||||
|
self.smoothed_bps.load(Relaxed)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Apply EWMA smoothing with a 2-second half-life.
|
||||||
|
fn update_smoothed(&self, new_bps: u64) {
|
||||||
|
let now_ms = SystemTime::now()
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_millis() as u64;
|
||||||
|
let last_ms = self.last_smoothed_ms.load(Relaxed);
|
||||||
|
let dt_ms = now_ms.saturating_sub(last_ms);
|
||||||
|
|
||||||
|
let current = self.smoothed_bps.load(Relaxed);
|
||||||
|
let updated = if current == 0 || dt_ms == 0 {
|
||||||
|
new_bps
|
||||||
|
} else {
|
||||||
|
let alpha = 1.0 - 0.5_f64.powf(dt_ms as f64 / 2000.0);
|
||||||
|
let s = current as f64 * (1.0 - alpha) + new_bps as f64 * alpha;
|
||||||
|
s as u64
|
||||||
|
};
|
||||||
|
|
||||||
|
self.smoothed_bps.store(updated, Relaxed);
|
||||||
|
self.last_smoothed_ms.store(now_ms, Relaxed);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -396,10 +469,7 @@ mod tests {
|
|||||||
|
|
||||||
// Below 8 => CATASTROPHIC
|
// Below 8 => CATASTROPHIC
|
||||||
let bwe_cat = BandwidthEstimator::new(7.9, 2.0, 100.0);
|
let bwe_cat = BandwidthEstimator::new(7.9, 2.0, 100.0);
|
||||||
assert_eq!(
|
assert_eq!(bwe_cat.recommended_profile(), QualityProfile::CATASTROPHIC);
|
||||||
bwe_cat.recommended_profile(),
|
|
||||||
QualityProfile::CATASTROPHIC
|
|
||||||
);
|
|
||||||
|
|
||||||
// High bandwidth
|
// High bandwidth
|
||||||
let bwe_high = BandwidthEstimator::new(80.0, 2.0, 100.0);
|
let bwe_high = BandwidthEstimator::new(80.0, 2.0, 100.0);
|
||||||
@@ -451,4 +521,46 @@ mod tests {
|
|||||||
}
|
}
|
||||||
assert!(det.is_congested());
|
assert!(det.is_congested());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn target_send_bps_uses_min_of_cwnd_and_remb() {
|
||||||
|
let bwe = BandwidthEstimator::new(50.0, 2.0, 100.0);
|
||||||
|
// cwnd_bps = 100_000, remb = 200_000 → min = 100_000 → 90%
|
||||||
|
bwe.update_from_path(1250, 0, 100); // 1250*8 / 0.1 = 100_000
|
||||||
|
bwe.update_from_peer(200_000);
|
||||||
|
assert_eq!(bwe.target_send_bps(), 90_000);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn target_send_bps_with_zero_cwnd_uses_remb() {
|
||||||
|
let bwe = BandwidthEstimator::new(50.0, 2.0, 100.0);
|
||||||
|
// Default cwnd is 0, remb is u64::MAX (default).
|
||||||
|
// 0.min(u64::MAX) = 0 → 90% = 0
|
||||||
|
assert_eq!(bwe.target_send_bps(), 0);
|
||||||
|
|
||||||
|
bwe.update_from_peer(100_000);
|
||||||
|
// cwnd still 0
|
||||||
|
assert_eq!(bwe.target_send_bps(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn smoothed_bps_ewma_converges() {
|
||||||
|
let bwe = BandwidthEstimator::new(50.0, 2.0, 100.0);
|
||||||
|
bwe.update_from_path(1250, 0, 100); // 100_000 bps
|
||||||
|
let s1 = bwe.smoothed_bps();
|
||||||
|
assert_eq!(s1, 100_000);
|
||||||
|
|
||||||
|
// Immediately update with same value — dt ≈ 0, so should stay at 100_000
|
||||||
|
bwe.update_from_path(1250, 0, 100);
|
||||||
|
let s2 = bwe.smoothed_bps();
|
||||||
|
assert_eq!(s2, 100_000);
|
||||||
|
|
||||||
|
// Sleep a bit so dt is non-zero, then update with a much higher value.
|
||||||
|
std::thread::sleep(std::time::Duration::from_millis(100));
|
||||||
|
bwe.update_from_path(12500, 0, 100); // 1_000_000 bps
|
||||||
|
let s3 = bwe.smoothed_bps();
|
||||||
|
assert!(s3 > 100_000, "smoothed should increase toward 1M: {s3}");
|
||||||
|
// With 100ms dt, alpha ≈ 0.03, so smoothed should be ~100k * 0.97 + 1M * 0.03 ≈ 127k
|
||||||
|
assert!(s3 < 500_000, "smoothed should not jump too far: {s3}");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,8 @@ use serde::{Deserialize, Serialize};
|
|||||||
|
|
||||||
/// Identifies the audio codec and bitrate configuration.
|
/// Identifies the audio codec and bitrate configuration.
|
||||||
///
|
///
|
||||||
/// Encoded as 4 bits in the media packet header.
|
/// Encoded as 4 bits in the v1 media packet header, and as a full 8-bit
|
||||||
|
/// value in the v2 [`MediaHeaderV2`](crate::MediaHeaderV2).
|
||||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||||
#[repr(u8)]
|
#[repr(u8)]
|
||||||
pub enum CodecId {
|
pub enum CodecId {
|
||||||
@@ -24,6 +25,16 @@ pub enum CodecId {
|
|||||||
Opus48k = 7,
|
Opus48k = 7,
|
||||||
/// Opus at 64kbps (studio high)
|
/// Opus at 64kbps (studio high)
|
||||||
Opus64k = 8,
|
Opus64k = 8,
|
||||||
|
/// H.264 baseline profile (video).
|
||||||
|
H264Baseline = 9,
|
||||||
|
// Reserved for video codecs; implementations land in PRD-video-multicodec.
|
||||||
|
// 10 => H264 main
|
||||||
|
// 11 => H265 main
|
||||||
|
// 13 => VP9
|
||||||
|
/// AV1 main profile (video).
|
||||||
|
Av1Main = 12,
|
||||||
|
/// H.265 main profile (video).
|
||||||
|
H265Main = 11,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CodecId {
|
impl CodecId {
|
||||||
@@ -39,6 +50,7 @@ impl CodecId {
|
|||||||
Self::Codec2_3200 => 3_200,
|
Self::Codec2_3200 => 3_200,
|
||||||
Self::Codec2_1200 => 1_200,
|
Self::Codec2_1200 => 1_200,
|
||||||
Self::ComfortNoise => 0,
|
Self::ComfortNoise => 0,
|
||||||
|
Self::H264Baseline | Self::H265Main | Self::Av1Main => 2_000_000,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -50,16 +62,22 @@ impl CodecId {
|
|||||||
Self::Codec2_3200 => 20,
|
Self::Codec2_3200 => 20,
|
||||||
Self::Codec2_1200 => 40,
|
Self::Codec2_1200 => 40,
|
||||||
Self::ComfortNoise => 20,
|
Self::ComfortNoise => 20,
|
||||||
|
Self::H264Baseline | Self::H265Main | Self::Av1Main => 33,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Sample rate expected by this codec.
|
/// Sample rate expected by this codec.
|
||||||
pub const fn sample_rate_hz(self) -> u32 {
|
pub const fn sample_rate_hz(self) -> u32 {
|
||||||
match self {
|
match self {
|
||||||
Self::Opus24k | Self::Opus16k | Self::Opus6k
|
Self::Opus24k
|
||||||
| Self::Opus32k | Self::Opus48k | Self::Opus64k => 48_000,
|
| Self::Opus16k
|
||||||
|
| Self::Opus6k
|
||||||
|
| Self::Opus32k
|
||||||
|
| Self::Opus48k
|
||||||
|
| Self::Opus64k => 48_000,
|
||||||
Self::Codec2_3200 | Self::Codec2_1200 => 8_000,
|
Self::Codec2_3200 | Self::Codec2_1200 => 8_000,
|
||||||
Self::ComfortNoise => 48_000,
|
Self::ComfortNoise => 48_000,
|
||||||
|
Self::H264Baseline | Self::H265Main | Self::Av1Main => 48_000,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -75,6 +93,9 @@ impl CodecId {
|
|||||||
6 => Some(Self::Opus32k),
|
6 => Some(Self::Opus32k),
|
||||||
7 => Some(Self::Opus48k),
|
7 => Some(Self::Opus48k),
|
||||||
8 => Some(Self::Opus64k),
|
8 => Some(Self::Opus64k),
|
||||||
|
9 => Some(Self::H264Baseline),
|
||||||
|
11 => Some(Self::H265Main),
|
||||||
|
12 => Some(Self::Av1Main),
|
||||||
_ => None,
|
_ => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -84,10 +105,22 @@ impl CodecId {
|
|||||||
self as u8
|
self as u8
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns true if this is a video codec variant.
|
||||||
|
pub const fn is_video(self) -> bool {
|
||||||
|
matches!(self, Self::H264Baseline | Self::H265Main | Self::Av1Main)
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns true if this is an Opus variant.
|
/// Returns true if this is an Opus variant.
|
||||||
pub const fn is_opus(self) -> bool {
|
pub const fn is_opus(self) -> bool {
|
||||||
matches!(self, Self::Opus6k | Self::Opus16k | Self::Opus24k
|
matches!(
|
||||||
| Self::Opus32k | Self::Opus48k | Self::Opus64k)
|
self,
|
||||||
|
Self::Opus6k
|
||||||
|
| Self::Opus16k
|
||||||
|
| Self::Opus24k
|
||||||
|
| Self::Opus32k
|
||||||
|
| Self::Opus48k
|
||||||
|
| Self::Opus64k
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -102,6 +135,18 @@ pub struct QualityProfile {
|
|||||||
pub frame_duration_ms: u8,
|
pub frame_duration_ms: u8,
|
||||||
/// Number of source frames per FEC block.
|
/// Number of source frames per FEC block.
|
||||||
pub frames_per_block: u8,
|
pub frames_per_block: u8,
|
||||||
|
/// Bandwidth-allocation priority between audio and video.
|
||||||
|
#[serde(default)]
|
||||||
|
pub priority_mode: crate::PriorityMode,
|
||||||
|
/// Target video bitrate in kbps (set by quality controller, not handshake).
|
||||||
|
#[serde(default)]
|
||||||
|
pub video_bitrate_kbps: Option<u32>,
|
||||||
|
/// Target video resolution as (width, height).
|
||||||
|
#[serde(default)]
|
||||||
|
pub video_resolution: Option<(u16, u16)>,
|
||||||
|
/// Target video frame rate.
|
||||||
|
#[serde(default)]
|
||||||
|
pub video_fps: Option<u8>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl QualityProfile {
|
impl QualityProfile {
|
||||||
@@ -111,6 +156,10 @@ impl QualityProfile {
|
|||||||
fec_ratio: 0.2,
|
fec_ratio: 0.2,
|
||||||
frame_duration_ms: 20,
|
frame_duration_ms: 20,
|
||||||
frames_per_block: 5,
|
frames_per_block: 5,
|
||||||
|
priority_mode: crate::PriorityMode::AudioFirst,
|
||||||
|
video_bitrate_kbps: None,
|
||||||
|
video_resolution: None,
|
||||||
|
video_fps: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Degraded conditions: Opus 6kbps, moderate FEC.
|
/// Degraded conditions: Opus 6kbps, moderate FEC.
|
||||||
@@ -119,6 +168,10 @@ impl QualityProfile {
|
|||||||
fec_ratio: 0.5,
|
fec_ratio: 0.5,
|
||||||
frame_duration_ms: 40,
|
frame_duration_ms: 40,
|
||||||
frames_per_block: 10,
|
frames_per_block: 10,
|
||||||
|
priority_mode: crate::PriorityMode::AudioFirst,
|
||||||
|
video_bitrate_kbps: None,
|
||||||
|
video_resolution: None,
|
||||||
|
video_fps: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Catastrophic conditions: Codec2 1.2kbps, heavy FEC.
|
/// Catastrophic conditions: Codec2 1.2kbps, heavy FEC.
|
||||||
@@ -127,6 +180,10 @@ impl QualityProfile {
|
|||||||
fec_ratio: 1.0,
|
fec_ratio: 1.0,
|
||||||
frame_duration_ms: 40,
|
frame_duration_ms: 40,
|
||||||
frames_per_block: 8,
|
frames_per_block: 8,
|
||||||
|
priority_mode: crate::PriorityMode::AudioFirst,
|
||||||
|
video_bitrate_kbps: None,
|
||||||
|
video_resolution: None,
|
||||||
|
video_fps: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Studio low: Opus 32kbps, minimal FEC.
|
/// Studio low: Opus 32kbps, minimal FEC.
|
||||||
@@ -135,6 +192,10 @@ impl QualityProfile {
|
|||||||
fec_ratio: 0.1,
|
fec_ratio: 0.1,
|
||||||
frame_duration_ms: 20,
|
frame_duration_ms: 20,
|
||||||
frames_per_block: 5,
|
frames_per_block: 5,
|
||||||
|
priority_mode: crate::PriorityMode::AudioFirst,
|
||||||
|
video_bitrate_kbps: None,
|
||||||
|
video_resolution: None,
|
||||||
|
video_fps: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Studio: Opus 48kbps, minimal FEC.
|
/// Studio: Opus 48kbps, minimal FEC.
|
||||||
@@ -143,6 +204,10 @@ impl QualityProfile {
|
|||||||
fec_ratio: 0.1,
|
fec_ratio: 0.1,
|
||||||
frame_duration_ms: 20,
|
frame_duration_ms: 20,
|
||||||
frames_per_block: 5,
|
frames_per_block: 5,
|
||||||
|
priority_mode: crate::PriorityMode::AudioFirst,
|
||||||
|
video_bitrate_kbps: None,
|
||||||
|
video_resolution: None,
|
||||||
|
video_fps: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Studio high: Opus 64kbps, minimal FEC.
|
/// Studio high: Opus 64kbps, minimal FEC.
|
||||||
@@ -151,6 +216,10 @@ impl QualityProfile {
|
|||||||
fec_ratio: 0.1,
|
fec_ratio: 0.1,
|
||||||
frame_duration_ms: 20,
|
frame_duration_ms: 20,
|
||||||
frames_per_block: 5,
|
frames_per_block: 5,
|
||||||
|
priority_mode: crate::PriorityMode::AudioFirst,
|
||||||
|
video_bitrate_kbps: None,
|
||||||
|
video_resolution: None,
|
||||||
|
video_fps: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Estimated total bandwidth in kbps including FEC overhead.
|
/// Estimated total bandwidth in kbps including FEC overhead.
|
||||||
@@ -159,3 +228,46 @@ impl QualityProfile {
|
|||||||
base * (1.0 + self.fec_ratio)
|
base * (1.0 + self.fec_ratio)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::{CodecId, QualityProfile};
|
||||||
|
use crate::PriorityMode;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn codec_id_unknown_values_rejected() {
|
||||||
|
for v in [10u8, 13].iter().copied().chain(14u8..=255) {
|
||||||
|
assert!(CodecId::from_wire(v).is_none(), "v={v}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn h265_main_roundtrips() {
|
||||||
|
assert_eq!(CodecId::H265Main.to_wire(), 11);
|
||||||
|
assert_eq!(CodecId::from_wire(11), Some(CodecId::H265Main));
|
||||||
|
assert!(CodecId::H265Main.is_video());
|
||||||
|
assert_eq!(CodecId::H265Main.bitrate_bps(), 2_000_000);
|
||||||
|
assert_eq!(CodecId::H265Main.frame_duration_ms(), 33);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn av1_main_roundtrips() {
|
||||||
|
assert_eq!(CodecId::Av1Main.to_wire(), 12);
|
||||||
|
assert_eq!(CodecId::from_wire(12), Some(CodecId::Av1Main));
|
||||||
|
assert!(CodecId::Av1Main.is_video());
|
||||||
|
assert_eq!(CodecId::Av1Main.bitrate_bps(), 2_000_000);
|
||||||
|
assert_eq!(CodecId::Av1Main.frame_duration_ms(), 33);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn quality_profile_backward_compat_old_json() {
|
||||||
|
// Old JSON emitted before T5.1 has no priority_mode or video fields.
|
||||||
|
let old_json =
|
||||||
|
r#"{"codec":"Opus24k","fec_ratio":0.2,"frame_duration_ms":20,"frames_per_block":5}"#;
|
||||||
|
let parsed: QualityProfile = serde_json::from_str(old_json).unwrap();
|
||||||
|
assert_eq!(parsed.priority_mode, PriorityMode::AudioFirst);
|
||||||
|
assert_eq!(parsed.video_bitrate_kbps, None);
|
||||||
|
assert_eq!(parsed.video_resolution, None);
|
||||||
|
assert_eq!(parsed.video_fps, None);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
320
crates/wzp-proto/src/dred_tuner.rs
Normal file
320
crates/wzp-proto/src/dred_tuner.rs
Normal file
@@ -0,0 +1,320 @@
|
|||||||
|
//! Continuous DRED tuning from real-time network metrics.
|
||||||
|
//!
|
||||||
|
//! Instead of locking DRED duration to 3 discrete quality tiers (100/200/500 ms),
|
||||||
|
//! `DredTuner` maps live path quality metrics to a continuous DRED duration and
|
||||||
|
//! expected-loss hint, updated every N packets. This makes DRED reactive within
|
||||||
|
//! ~200 ms instead of waiting for 3+ consecutive bad quality reports to trigger
|
||||||
|
//! a full tier transition.
|
||||||
|
//!
|
||||||
|
//! The tuner also implements pre-emptive jitter-spike detection ("sawtooth"
|
||||||
|
//! prediction): when jitter variance spikes >30% over a 200 ms window — typical
|
||||||
|
//! of Starlink satellite handovers — it temporarily boosts DRED to the maximum
|
||||||
|
//! allowed for the current codec before packets actually start dropping.
|
||||||
|
//!
|
||||||
|
//! See also: [`crate::quality`] for discrete tier classification that drives
|
||||||
|
//! codec switching. DredTuner operates within a tier, adjusting DRED
|
||||||
|
//! parameters continuously based on live network metrics.
|
||||||
|
|
||||||
|
use crate::CodecId;
|
||||||
|
|
||||||
|
/// Output of a single tuning cycle.
|
||||||
|
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||||
|
pub struct DredTuning {
|
||||||
|
/// DRED duration in 10 ms frame units (0–104). Passed directly to
|
||||||
|
/// `OpusEncoder::set_dred_duration()`.
|
||||||
|
pub dred_frames: u8,
|
||||||
|
/// Expected packet loss percentage (0–100). Passed to
|
||||||
|
/// `OpusEncoder::set_expected_loss()`. Floored at 15% by the encoder
|
||||||
|
/// itself, but we pass the real value so the encoder can override upward.
|
||||||
|
pub expected_loss_pct: u8,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Minimum DRED frames for any Opus codec (matches DRED_LOSS_FLOOR_PCT logic:
|
||||||
|
/// at 15% loss, libopus 1.5 emits ~95 ms of DRED, which needs at least 10
|
||||||
|
/// frames configured to be useful).
|
||||||
|
const MIN_DRED_FRAMES: u8 = 5;
|
||||||
|
|
||||||
|
/// Maximum DRED frames libopus supports (104 × 10 ms = 1040 ms).
|
||||||
|
const MAX_DRED_FRAMES: u8 = 104;
|
||||||
|
|
||||||
|
/// Jitter variance spike ratio that triggers pre-emptive DRED boost.
|
||||||
|
const JITTER_SPIKE_RATIO: f32 = 1.3;
|
||||||
|
|
||||||
|
/// How many tuning cycles a jitter-spike boost persists (at 25 packets/cycle
|
||||||
|
/// and 20 ms/packet, 10 cycles ≈ 5 seconds).
|
||||||
|
const SPIKE_BOOST_COOLDOWN_CYCLES: u32 = 10;
|
||||||
|
|
||||||
|
/// Maps codec tier to its baseline DRED frames (used when network is healthy).
|
||||||
|
fn baseline_dred_frames(codec: CodecId) -> u8 {
|
||||||
|
match codec {
|
||||||
|
CodecId::Opus32k | CodecId::Opus48k | CodecId::Opus64k => 10, // 100 ms
|
||||||
|
CodecId::Opus16k | CodecId::Opus24k => 20, // 200 ms
|
||||||
|
CodecId::Opus6k => 50, // 500 ms
|
||||||
|
_ => 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Maps codec tier to its maximum allowed DRED frames under spike/bad conditions.
|
||||||
|
fn max_dred_frames_for(codec: CodecId) -> u8 {
|
||||||
|
match codec {
|
||||||
|
// Studio: cap at 300 ms (don't waste bitrate on good links)
|
||||||
|
CodecId::Opus32k | CodecId::Opus48k | CodecId::Opus64k => 30,
|
||||||
|
// Normal: cap at 500 ms
|
||||||
|
CodecId::Opus16k | CodecId::Opus24k => 50,
|
||||||
|
// Degraded: allow full 1040 ms
|
||||||
|
CodecId::Opus6k => MAX_DRED_FRAMES,
|
||||||
|
_ => 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Continuous DRED tuner driven by network path metrics.
|
||||||
|
pub struct DredTuner {
|
||||||
|
/// Current codec (determines baseline and ceiling).
|
||||||
|
codec: CodecId,
|
||||||
|
/// Last computed tuning output.
|
||||||
|
last_tuning: DredTuning,
|
||||||
|
/// EWMA-smoothed jitter for spike detection (in ms).
|
||||||
|
jitter_ewma: f32,
|
||||||
|
/// Remaining cooldown cycles for a jitter-spike boost.
|
||||||
|
spike_cooldown: u32,
|
||||||
|
/// Whether the tuner has received at least one observation.
|
||||||
|
initialized: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DredTuner {
|
||||||
|
/// Create a new tuner for the given codec.
|
||||||
|
pub fn new(codec: CodecId) -> Self {
|
||||||
|
let baseline = baseline_dred_frames(codec);
|
||||||
|
Self {
|
||||||
|
codec,
|
||||||
|
last_tuning: DredTuning {
|
||||||
|
dred_frames: baseline,
|
||||||
|
expected_loss_pct: 15, // match DRED_LOSS_FLOOR_PCT
|
||||||
|
},
|
||||||
|
jitter_ewma: 0.0,
|
||||||
|
spike_cooldown: 0,
|
||||||
|
initialized: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Update the active codec (e.g. on tier transition). Resets spike state.
|
||||||
|
pub fn set_codec(&mut self, codec: CodecId) {
|
||||||
|
self.codec = codec;
|
||||||
|
self.spike_cooldown = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Feed network metrics and compute new DRED parameters.
|
||||||
|
///
|
||||||
|
/// Call this every tuning cycle (e.g. every 25 packets ≈ 500 ms at 20 ms
|
||||||
|
/// frame duration).
|
||||||
|
///
|
||||||
|
/// - `loss_pct`: observed packet loss (0.0–100.0)
|
||||||
|
/// - `rtt_ms`: smoothed round-trip time
|
||||||
|
/// - `jitter_ms`: current jitter estimate (RTT variance)
|
||||||
|
///
|
||||||
|
/// Returns `Some(tuning)` if the output changed, `None` if unchanged.
|
||||||
|
pub fn update(&mut self, loss_pct: f32, rtt_ms: u32, jitter_ms: u32) -> Option<DredTuning> {
|
||||||
|
if !self.codec.is_opus() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let baseline = baseline_dred_frames(self.codec);
|
||||||
|
let ceiling = max_dred_frames_for(self.codec);
|
||||||
|
|
||||||
|
// --- Jitter spike detection ---
|
||||||
|
let jitter_f = jitter_ms as f32;
|
||||||
|
if !self.initialized {
|
||||||
|
self.jitter_ewma = jitter_f;
|
||||||
|
self.initialized = true;
|
||||||
|
} else {
|
||||||
|
// Fast-up (alpha=0.3), slow-down (alpha=0.05) asymmetric EWMA
|
||||||
|
let alpha = if jitter_f > self.jitter_ewma {
|
||||||
|
0.3
|
||||||
|
} else {
|
||||||
|
0.05
|
||||||
|
};
|
||||||
|
self.jitter_ewma = alpha * jitter_f + (1.0 - alpha) * self.jitter_ewma;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Detect spike: instantaneous jitter > EWMA × 1.3
|
||||||
|
if self.jitter_ewma > 1.0 && jitter_f > self.jitter_ewma * JITTER_SPIKE_RATIO {
|
||||||
|
self.spike_cooldown = SPIKE_BOOST_COOLDOWN_CYCLES;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decrement cooldown
|
||||||
|
if self.spike_cooldown > 0 {
|
||||||
|
self.spike_cooldown -= 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Compute DRED frames ---
|
||||||
|
let dred_frames = if self.spike_cooldown > 0 {
|
||||||
|
// During spike boost: jump to ceiling
|
||||||
|
ceiling
|
||||||
|
} else {
|
||||||
|
// Continuous mapping: scale linearly between baseline and ceiling
|
||||||
|
// based on loss percentage.
|
||||||
|
// 0% loss → baseline
|
||||||
|
// 40% loss → ceiling
|
||||||
|
let loss_clamped = loss_pct.clamp(0.0, 40.0);
|
||||||
|
let t = loss_clamped / 40.0;
|
||||||
|
let raw = baseline as f32 + t * (ceiling - baseline) as f32;
|
||||||
|
(raw as u8).clamp(MIN_DRED_FRAMES, ceiling)
|
||||||
|
};
|
||||||
|
|
||||||
|
// --- Compute expected loss hint ---
|
||||||
|
// Pass the real loss so the encoder can clamp at its own floor (15%).
|
||||||
|
// For RTT-driven boost: high RTT suggests impending loss, so add a
|
||||||
|
// phantom loss contribution to keep DRED emitting generously.
|
||||||
|
let rtt_loss_phantom = if rtt_ms > 200 {
|
||||||
|
((rtt_ms - 200) as f32 / 40.0).min(15.0)
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
};
|
||||||
|
let expected_loss = (loss_pct + rtt_loss_phantom).clamp(0.0, 100.0) as u8;
|
||||||
|
|
||||||
|
let tuning = DredTuning {
|
||||||
|
dred_frames,
|
||||||
|
expected_loss_pct: expected_loss,
|
||||||
|
};
|
||||||
|
|
||||||
|
if tuning != self.last_tuning {
|
||||||
|
self.last_tuning = tuning;
|
||||||
|
Some(tuning)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the last computed tuning without updating.
|
||||||
|
pub fn current(&self) -> DredTuning {
|
||||||
|
self.last_tuning
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Whether a jitter-spike boost is currently active.
|
||||||
|
pub fn spike_boost_active(&self) -> bool {
|
||||||
|
self.spike_cooldown > 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn baseline_for_opus24k() {
|
||||||
|
let tuner = DredTuner::new(CodecId::Opus24k);
|
||||||
|
assert_eq!(tuner.current().dred_frames, 20); // 200 ms
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn baseline_for_opus6k() {
|
||||||
|
let tuner = DredTuner::new(CodecId::Opus6k);
|
||||||
|
assert_eq!(tuner.current().dred_frames, 50); // 500 ms
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn codec2_returns_none() {
|
||||||
|
let mut tuner = DredTuner::new(CodecId::Codec2_1200);
|
||||||
|
assert!(tuner.update(10.0, 100, 20).is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn scales_with_loss() {
|
||||||
|
let mut tuner = DredTuner::new(CodecId::Opus24k);
|
||||||
|
|
||||||
|
// 0% loss → baseline (20 frames)
|
||||||
|
tuner.update(0.0, 50, 5);
|
||||||
|
assert_eq!(tuner.current().dred_frames, 20);
|
||||||
|
|
||||||
|
// 20% loss → midpoint between 20 and 50 = 35
|
||||||
|
tuner.update(20.0, 50, 5);
|
||||||
|
assert_eq!(tuner.current().dred_frames, 35);
|
||||||
|
|
||||||
|
// 40%+ loss → ceiling (50 frames)
|
||||||
|
tuner.update(40.0, 50, 5);
|
||||||
|
assert_eq!(tuner.current().dred_frames, 50);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn jitter_spike_triggers_boost() {
|
||||||
|
let mut tuner = DredTuner::new(CodecId::Opus24k);
|
||||||
|
|
||||||
|
// Establish baseline jitter
|
||||||
|
for _ in 0..20 {
|
||||||
|
tuner.update(0.0, 50, 10);
|
||||||
|
}
|
||||||
|
assert!(!tuner.spike_boost_active());
|
||||||
|
|
||||||
|
// Spike: jitter jumps to 50 ms (5x the EWMA of ~10)
|
||||||
|
tuner.update(0.0, 50, 50);
|
||||||
|
assert!(tuner.spike_boost_active());
|
||||||
|
// Should be at ceiling (50 frames = 500 ms for Opus24k)
|
||||||
|
assert_eq!(tuner.current().dred_frames, 50);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn spike_cooldown_decays() {
|
||||||
|
let mut tuner = DredTuner::new(CodecId::Opus24k);
|
||||||
|
|
||||||
|
// Establish baseline then spike
|
||||||
|
for _ in 0..20 {
|
||||||
|
tuner.update(0.0, 50, 10);
|
||||||
|
}
|
||||||
|
tuner.update(0.0, 50, 50);
|
||||||
|
assert!(tuner.spike_boost_active());
|
||||||
|
|
||||||
|
// Run through cooldown
|
||||||
|
for _ in 0..SPIKE_BOOST_COOLDOWN_CYCLES {
|
||||||
|
tuner.update(0.0, 50, 10);
|
||||||
|
}
|
||||||
|
assert!(!tuner.spike_boost_active());
|
||||||
|
// Should return to baseline
|
||||||
|
assert_eq!(tuner.current().dred_frames, 20);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn rtt_phantom_loss() {
|
||||||
|
let mut tuner = DredTuner::new(CodecId::Opus24k);
|
||||||
|
|
||||||
|
// High RTT (400ms) with 0% real loss
|
||||||
|
tuner.update(0.0, 400, 10);
|
||||||
|
// Phantom loss = (400-200)/40 = 5
|
||||||
|
assert_eq!(tuner.current().expected_loss_pct, 5);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn set_codec_resets_spike() {
|
||||||
|
let mut tuner = DredTuner::new(CodecId::Opus24k);
|
||||||
|
|
||||||
|
// Trigger spike
|
||||||
|
for _ in 0..20 {
|
||||||
|
tuner.update(0.0, 50, 10);
|
||||||
|
}
|
||||||
|
tuner.update(0.0, 50, 50);
|
||||||
|
assert!(tuner.spike_boost_active());
|
||||||
|
|
||||||
|
// Switch codec — spike should reset
|
||||||
|
tuner.set_codec(CodecId::Opus6k);
|
||||||
|
assert!(!tuner.spike_boost_active());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn opus6k_reaches_max_1040ms() {
|
||||||
|
let mut tuner = DredTuner::new(CodecId::Opus6k);
|
||||||
|
|
||||||
|
// High loss → should reach 104 frames (1040 ms)
|
||||||
|
tuner.update(40.0, 50, 5);
|
||||||
|
assert_eq!(tuner.current().dred_frames, MAX_DRED_FRAMES);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn returns_none_when_unchanged() {
|
||||||
|
let mut tuner = DredTuner::new(CodecId::Opus24k);
|
||||||
|
|
||||||
|
// First update always returns Some (initial → computed)
|
||||||
|
let first = tuner.update(0.0, 50, 5);
|
||||||
|
// Same inputs → None
|
||||||
|
let second = tuner.update(0.0, 50, 5);
|
||||||
|
assert!(first.is_some() || second.is_none());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -37,7 +37,7 @@ pub enum CryptoError {
|
|||||||
#[error("rekey failed: {0}")]
|
#[error("rekey failed: {0}")]
|
||||||
RekeyFailed(String),
|
RekeyFailed(String),
|
||||||
#[error("anti-replay: duplicate or old packet (seq={seq})")]
|
#[error("anti-replay: duplicate or old packet (seq={seq})")]
|
||||||
ReplayDetected { seq: u16 },
|
ReplayDetected { seq: u32 },
|
||||||
#[error("internal crypto error: {0}")]
|
#[error("internal crypto error: {0}")]
|
||||||
Internal(String),
|
Internal(String),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -81,9 +81,7 @@ impl AdaptivePlayoutDelay {
|
|||||||
let jitter = (actual_delta - expected_delta).abs();
|
let jitter = (actual_delta - expected_delta).abs();
|
||||||
|
|
||||||
// Spike detection: check before EMA update
|
// Spike detection: check before EMA update
|
||||||
if self.jitter_ema > 0.0
|
if self.jitter_ema > 0.0 && jitter > self.jitter_ema * self.spike_threshold_multiplier {
|
||||||
&& jitter > self.jitter_ema * self.spike_threshold_multiplier
|
|
||||||
{
|
|
||||||
self.spike_detected_at = Some(Instant::now());
|
self.spike_detected_at = Some(Instant::now());
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -107,10 +105,8 @@ impl AdaptivePlayoutDelay {
|
|||||||
self.target_delay = self.max_delay;
|
self.target_delay = self.max_delay;
|
||||||
} else {
|
} else {
|
||||||
// Convert jitter estimate to target delay in packets
|
// Convert jitter estimate to target delay in packets
|
||||||
let raw_target =
|
let raw_target = (self.jitter_ema / FRAME_DURATION_MS).ceil() + self.safety_margin;
|
||||||
(self.jitter_ema / FRAME_DURATION_MS).ceil() + self.safety_margin;
|
self.target_delay = (raw_target as usize).clamp(self.min_delay, self.max_delay);
|
||||||
self.target_delay =
|
|
||||||
(raw_target as usize).clamp(self.min_delay, self.max_delay);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -162,9 +158,9 @@ impl AdaptivePlayoutDelay {
|
|||||||
/// Manages packet reordering, gap detection, and signals when PLC is needed.
|
/// Manages packet reordering, gap detection, and signals when PLC is needed.
|
||||||
pub struct JitterBuffer {
|
pub struct JitterBuffer {
|
||||||
/// Packets waiting to be consumed, ordered by sequence number.
|
/// Packets waiting to be consumed, ordered by sequence number.
|
||||||
buffer: BTreeMap<u16, MediaPacket>,
|
buffer: BTreeMap<u32, MediaPacket>,
|
||||||
/// Next sequence number expected for playout.
|
/// Next sequence number expected for playout.
|
||||||
next_playout_seq: u16,
|
next_playout_seq: u32,
|
||||||
/// Maximum buffer depth in number of packets.
|
/// Maximum buffer depth in number of packets.
|
||||||
max_depth: usize,
|
max_depth: usize,
|
||||||
/// Target buffer depth (adaptive, based on jitter).
|
/// Target buffer depth (adaptive, based on jitter).
|
||||||
@@ -204,7 +200,7 @@ pub enum PlayoutResult {
|
|||||||
/// A packet is available for playout.
|
/// A packet is available for playout.
|
||||||
Packet(MediaPacket),
|
Packet(MediaPacket),
|
||||||
/// The expected packet is missing — decoder should generate PLC.
|
/// The expected packet is missing — decoder should generate PLC.
|
||||||
Missing { seq: u16 },
|
Missing { seq: u32 },
|
||||||
/// Buffer is empty or not yet filled to target depth.
|
/// Buffer is empty or not yet filled to target depth.
|
||||||
NotReady,
|
NotReady,
|
||||||
}
|
}
|
||||||
@@ -278,9 +274,18 @@ impl JitterBuffer {
|
|||||||
// federation room — reset instead of dropping.
|
// federation room — reset instead of dropping.
|
||||||
if self.stats.packets_played > 0 && seq_before(seq, self.next_playout_seq) {
|
if self.stats.packets_played > 0 && seq_before(seq, self.next_playout_seq) {
|
||||||
let backward_distance = self.next_playout_seq.wrapping_sub(seq);
|
let backward_distance = self.next_playout_seq.wrapping_sub(seq);
|
||||||
tracing::warn!(seq, next = self.next_playout_seq, backward_distance, "jitter: backward seq detected");
|
tracing::warn!(
|
||||||
|
seq,
|
||||||
|
next = self.next_playout_seq,
|
||||||
|
backward_distance,
|
||||||
|
"jitter: backward seq detected"
|
||||||
|
);
|
||||||
if backward_distance > 100 {
|
if backward_distance > 100 {
|
||||||
tracing::info!(seq, next = self.next_playout_seq, "jitter: RESET — new sender detected");
|
tracing::info!(
|
||||||
|
seq,
|
||||||
|
next = self.next_playout_seq,
|
||||||
|
"jitter: RESET — new sender detected"
|
||||||
|
);
|
||||||
self.buffer.clear();
|
self.buffer.clear();
|
||||||
self.next_playout_seq = seq;
|
self.next_playout_seq = seq;
|
||||||
self.stats.packets_late = 0;
|
self.stats.packets_late = 0;
|
||||||
@@ -428,9 +433,18 @@ impl JitterBuffer {
|
|||||||
// federation room — reset instead of dropping.
|
// federation room — reset instead of dropping.
|
||||||
if self.stats.packets_played > 0 && seq_before(seq, self.next_playout_seq) {
|
if self.stats.packets_played > 0 && seq_before(seq, self.next_playout_seq) {
|
||||||
let backward_distance = self.next_playout_seq.wrapping_sub(seq);
|
let backward_distance = self.next_playout_seq.wrapping_sub(seq);
|
||||||
tracing::warn!(seq, next = self.next_playout_seq, backward_distance, "jitter: backward seq detected");
|
tracing::warn!(
|
||||||
|
seq,
|
||||||
|
next = self.next_playout_seq,
|
||||||
|
backward_distance,
|
||||||
|
"jitter: backward seq detected"
|
||||||
|
);
|
||||||
if backward_distance > 100 {
|
if backward_distance > 100 {
|
||||||
tracing::info!(seq, next = self.next_playout_seq, "jitter: RESET — new sender detected");
|
tracing::info!(
|
||||||
|
seq,
|
||||||
|
next = self.next_playout_seq,
|
||||||
|
"jitter: RESET — new sender detected"
|
||||||
|
);
|
||||||
self.buffer.clear();
|
self.buffer.clear();
|
||||||
self.next_playout_seq = seq;
|
self.next_playout_seq = seq;
|
||||||
self.stats.packets_late = 0;
|
self.stats.packets_late = 0;
|
||||||
@@ -489,7 +503,7 @@ impl JitterBuffer {
|
|||||||
|
|
||||||
/// Sequence number comparison with wrapping (RFC 1982 serial number arithmetic).
|
/// Sequence number comparison with wrapping (RFC 1982 serial number arithmetic).
|
||||||
/// Returns true if `a` comes before `b` in sequence space.
|
/// Returns true if `a` comes before `b` in sequence space.
|
||||||
fn seq_before(a: u16, b: u16) -> bool {
|
fn seq_before(a: u32, b: u32) -> bool {
|
||||||
let diff = b.wrapping_sub(a);
|
let diff = b.wrapping_sub(a);
|
||||||
diff > 0 && diff < 0x8000
|
diff > 0 && diff < 0x8000
|
||||||
}
|
}
|
||||||
@@ -497,24 +511,23 @@ fn seq_before(a: u16, b: u16) -> bool {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use crate::CodecId;
|
||||||
|
use crate::MediaType;
|
||||||
use crate::packet::{MediaHeader, MediaPacket};
|
use crate::packet::{MediaHeader, MediaPacket};
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use crate::CodecId;
|
|
||||||
|
|
||||||
fn make_packet(seq: u16) -> MediaPacket {
|
fn make_packet(seq: u32) -> MediaPacket {
|
||||||
MediaPacket {
|
MediaPacket {
|
||||||
header: MediaHeader {
|
header: MediaHeader {
|
||||||
version: 0,
|
version: 2,
|
||||||
is_repair: false,
|
flags: 0,
|
||||||
|
media_type: MediaType::Audio,
|
||||||
codec_id: CodecId::Opus24k,
|
codec_id: CodecId::Opus24k,
|
||||||
has_quality_report: false,
|
stream_id: 0,
|
||||||
fec_ratio_encoded: 0,
|
fec_ratio: 0,
|
||||||
seq,
|
seq,
|
||||||
timestamp: seq as u32 * 20,
|
timestamp: seq * 20,
|
||||||
fec_block: 0,
|
fec_block: 0,
|
||||||
fec_symbol: 0,
|
|
||||||
reserved: 0,
|
|
||||||
csrc_count: 0,
|
|
||||||
},
|
},
|
||||||
payload: Bytes::from(vec![0u8; 60]),
|
payload: Bytes::from(vec![0u8; 60]),
|
||||||
quality_report: None,
|
quality_report: None,
|
||||||
@@ -598,7 +611,7 @@ mod tests {
|
|||||||
fn seq_before_wrapping() {
|
fn seq_before_wrapping() {
|
||||||
assert!(seq_before(0, 1));
|
assert!(seq_before(0, 1));
|
||||||
assert!(seq_before(65534, 65535));
|
assert!(seq_before(65534, 65535));
|
||||||
assert!(seq_before(65535, 0)); // wrap
|
assert!(seq_before(u32::MAX, 0)); // wrap
|
||||||
assert!(!seq_before(1, 0));
|
assert!(!seq_before(1, 0));
|
||||||
assert!(!seq_before(5, 5)); // equal
|
assert!(!seq_before(5, 5)); // equal
|
||||||
}
|
}
|
||||||
@@ -800,7 +813,7 @@ mod tests {
|
|||||||
let mut jb = JitterBuffer::new_adaptive(3, 50);
|
let mut jb = JitterBuffer::new_adaptive(3, 50);
|
||||||
|
|
||||||
// Push packets with consistent timing
|
// Push packets with consistent timing
|
||||||
for i in 0u16..20 {
|
for i in 0u32..20 {
|
||||||
let pkt = make_packet(i);
|
let pkt = make_packet(i);
|
||||||
let arrival_ms = i as u64 * 20;
|
let arrival_ms = i as u64 * 20;
|
||||||
jb.push_with_arrival(pkt, arrival_ms);
|
jb.push_with_arrival(pkt, arrival_ms);
|
||||||
|
|||||||
@@ -14,22 +14,28 @@
|
|||||||
|
|
||||||
pub mod bandwidth;
|
pub mod bandwidth;
|
||||||
pub mod codec_id;
|
pub mod codec_id;
|
||||||
|
pub mod dred_tuner;
|
||||||
pub mod error;
|
pub mod error;
|
||||||
pub mod jitter;
|
pub mod jitter;
|
||||||
|
pub mod media_type;
|
||||||
pub mod packet;
|
pub mod packet;
|
||||||
|
pub mod priority_mode;
|
||||||
pub mod quality;
|
pub mod quality;
|
||||||
pub mod session;
|
pub mod session;
|
||||||
pub mod traits;
|
pub mod traits;
|
||||||
|
|
||||||
// Re-export key types at crate root for convenience.
|
// Re-export key types at crate root for convenience.
|
||||||
pub use codec_id::{CodecId, QualityProfile};
|
|
||||||
pub use error::*;
|
|
||||||
pub use packet::{
|
|
||||||
CallAcceptMode, HangupReason, MediaHeader, MediaPacket, MiniFrameContext, MiniHeader,
|
|
||||||
QualityReport, RoomParticipant, SignalMessage, TrunkEntry, TrunkFrame, FRAME_TYPE_FULL,
|
|
||||||
FRAME_TYPE_MINI,
|
|
||||||
};
|
|
||||||
pub use bandwidth::{BandwidthEstimator, CongestionState};
|
pub use bandwidth::{BandwidthEstimator, CongestionState};
|
||||||
|
pub use codec_id::{CodecId, QualityProfile};
|
||||||
|
pub use dred_tuner::{DredTuner, DredTuning};
|
||||||
|
pub use error::*;
|
||||||
|
pub use media_type::MediaType;
|
||||||
|
pub use packet::{
|
||||||
|
CallAcceptMode, FRAME_TYPE_FULL, FRAME_TYPE_MINI, HangupReason, MediaHeader, MediaHeaderV2,
|
||||||
|
MediaPacket, MiniFrameContext, MiniFrameContextV2, MiniHeader, MiniHeaderV2, PresenceUser,
|
||||||
|
QualityReport, RoomParticipant, SignalMessage, TrunkEntry, TrunkFrame, default_signal_version,
|
||||||
|
};
|
||||||
|
pub use priority_mode::PriorityMode;
|
||||||
pub use quality::{AdaptiveQualityController, NetworkContext, Tier};
|
pub use quality::{AdaptiveQualityController, NetworkContext, Tier};
|
||||||
pub use session::{Session, SessionEvent, SessionState};
|
pub use session::{Session, SessionEvent, SessionState};
|
||||||
pub use traits::*;
|
pub use traits::*;
|
||||||
|
|||||||
57
crates/wzp-proto/src/media_type.rs
Normal file
57
crates/wzp-proto/src/media_type.rs
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
/// Media stream type carried in a v2 [`MediaHeaderV2`](crate::MediaHeaderV2).
|
||||||
|
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||||
|
#[repr(u8)]
|
||||||
|
pub enum MediaType {
|
||||||
|
/// Encoded speech / music (Opus, Codec2, ComfortNoise).
|
||||||
|
Audio = 0,
|
||||||
|
/// Encoded video access unit (H.264, H.265, AV1; PRD-video-multicodec).
|
||||||
|
Video = 1,
|
||||||
|
/// Opaque payload not interpreted by the relay (reserved).
|
||||||
|
Data = 2,
|
||||||
|
/// In-band control message carried on the media plane (reserved).
|
||||||
|
Control = 3,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MediaType {
|
||||||
|
/// Encode to the wire byte representation (`self as u8`).
|
||||||
|
pub const fn to_wire(self) -> u8 {
|
||||||
|
self as u8
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Decode from a wire byte. Returns `None` for values outside 0..=3.
|
||||||
|
pub const fn from_wire(v: u8) -> Option<Self> {
|
||||||
|
match v {
|
||||||
|
0 => Some(Self::Audio),
|
||||||
|
1 => Some(Self::Video),
|
||||||
|
2 => Some(Self::Data),
|
||||||
|
3 => Some(Self::Control),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn media_type_roundtrip() {
|
||||||
|
for mt in [
|
||||||
|
MediaType::Audio,
|
||||||
|
MediaType::Video,
|
||||||
|
MediaType::Data,
|
||||||
|
MediaType::Control,
|
||||||
|
] {
|
||||||
|
assert_eq!(MediaType::from_wire(mt.to_wire()), Some(mt));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn media_type_unknown_rejected() {
|
||||||
|
for v in 4u8..=255 {
|
||||||
|
assert!(MediaType::from_wire(v).is_none(), "v={v}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
34
crates/wzp-proto/src/priority_mode.rs
Normal file
34
crates/wzp-proto/src/priority_mode.rs
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
//! Priority mode for bandwidth allocation between audio and video.
|
||||||
|
//!
|
||||||
|
//! See `docs/PRD/PRD-video-quality-priority.md` for the full design.
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
/// Bandwidth-allocation policy between audio and video.
|
||||||
|
///
|
||||||
|
/// Carried on [`QualityProfile`](crate::QualityProfile) and mutable at
|
||||||
|
/// runtime via [`SignalMessage::SetPriorityMode`](crate::SignalMessage).
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
|
||||||
|
pub enum PriorityMode {
|
||||||
|
/// Audio gets its floor first; video gets the remainder.
|
||||||
|
/// Default for voice/video calls.
|
||||||
|
#[default]
|
||||||
|
AudioFirst,
|
||||||
|
/// Video gets its floor first; audio degrades to Opus 16k floor.
|
||||||
|
VideoFirst,
|
||||||
|
/// Audio clamped to 16 kbps (intelligible speech); video gets remainder.
|
||||||
|
/// Falls back to slide mode when bandwidth drops below SD floor.
|
||||||
|
ScreenShare,
|
||||||
|
/// Proportional split (~15 % audio, ~85 % video).
|
||||||
|
Balanced,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn priority_mode_default_is_audio_first() {
|
||||||
|
assert_eq!(PriorityMode::default(), PriorityMode::AudioFirst);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,24 +1,40 @@
|
|||||||
|
//! See also: [`crate::dred_tuner`] for continuous DRED tuning within a tier.
|
||||||
|
|
||||||
use std::collections::VecDeque;
|
use std::collections::VecDeque;
|
||||||
|
use std::sync::Arc;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
|
use crate::BandwidthEstimator;
|
||||||
|
use crate::QualityProfile;
|
||||||
use crate::packet::QualityReport;
|
use crate::packet::QualityReport;
|
||||||
use crate::traits::QualityController;
|
use crate::traits::QualityController;
|
||||||
use crate::QualityProfile;
|
|
||||||
|
|
||||||
/// Network quality tier — drives codec and FEC selection.
|
/// Network quality tier — drives codec and FEC selection.
|
||||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
///
|
||||||
|
/// 5-tier range from studio quality down to catastrophic:
|
||||||
|
/// Studio64k > Studio48k > Studio32k > Good > Degraded > Catastrophic
|
||||||
|
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
|
||||||
pub enum Tier {
|
pub enum Tier {
|
||||||
/// loss < 10%, RTT < 400ms
|
/// loss >= 15% OR RTT >= 200ms — Codec2 1.2k
|
||||||
Good,
|
Catastrophic = 0,
|
||||||
/// loss 10-40% OR RTT 400-600ms
|
/// loss < 15% AND RTT < 200ms — Opus 6k
|
||||||
Degraded,
|
Degraded = 1,
|
||||||
/// loss > 40% OR RTT > 600ms
|
/// loss < 5% AND RTT < 100ms — Opus 24k
|
||||||
Catastrophic,
|
Good = 2,
|
||||||
|
/// loss < 2% AND RTT < 80ms — Opus 32k
|
||||||
|
Studio32k = 3,
|
||||||
|
/// loss < 1% AND RTT < 50ms — Opus 48k
|
||||||
|
Studio48k = 4,
|
||||||
|
/// loss < 1% AND RTT < 30ms — Opus 64k
|
||||||
|
Studio64k = 5,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Tier {
|
impl Tier {
|
||||||
pub fn profile(self) -> QualityProfile {
|
pub fn profile(self) -> QualityProfile {
|
||||||
match self {
|
match self {
|
||||||
|
Self::Studio64k => QualityProfile::STUDIO_64K,
|
||||||
|
Self::Studio48k => QualityProfile::STUDIO_48K,
|
||||||
|
Self::Studio32k => QualityProfile::STUDIO_32K,
|
||||||
Self::Good => QualityProfile::GOOD,
|
Self::Good => QualityProfile::GOOD,
|
||||||
Self::Degraded => QualityProfile::DEGRADED,
|
Self::Degraded => QualityProfile::DEGRADED,
|
||||||
Self::Catastrophic => QualityProfile::CATASTROPHIC,
|
Self::Catastrophic => QualityProfile::CATASTROPHIC,
|
||||||
@@ -39,7 +55,7 @@ impl Tier {
|
|||||||
NetworkContext::CellularLte
|
NetworkContext::CellularLte
|
||||||
| NetworkContext::Cellular5g
|
| NetworkContext::Cellular5g
|
||||||
| NetworkContext::Cellular3g => {
|
| NetworkContext::Cellular3g => {
|
||||||
// Tighter thresholds for cellular networks
|
// Tighter thresholds for cellular — no studio tiers
|
||||||
if loss > 25.0 || rtt > 500 {
|
if loss > 25.0 || rtt > 500 {
|
||||||
Self::Catastrophic
|
Self::Catastrophic
|
||||||
} else if loss > 8.0 || rtt > 300 {
|
} else if loss > 8.0 || rtt > 300 {
|
||||||
@@ -49,13 +65,18 @@ impl Tier {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
NetworkContext::WiFi | NetworkContext::Unknown => {
|
NetworkContext::WiFi | NetworkContext::Unknown => {
|
||||||
// Original thresholds
|
if loss >= 15.0 || rtt >= 200 {
|
||||||
if loss > 40.0 || rtt > 600 {
|
|
||||||
Self::Catastrophic
|
Self::Catastrophic
|
||||||
} else if loss > 10.0 || rtt > 400 {
|
} else if loss >= 5.0 || rtt >= 100 {
|
||||||
Self::Degraded
|
Self::Degraded
|
||||||
} else {
|
} else if loss >= 2.0 || rtt >= 80 {
|
||||||
Self::Good
|
Self::Good
|
||||||
|
} else if loss >= 1.0 || rtt >= 50 {
|
||||||
|
Self::Studio32k
|
||||||
|
} else if rtt >= 30 {
|
||||||
|
Self::Studio48k
|
||||||
|
} else {
|
||||||
|
Self::Studio64k
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -64,29 +85,32 @@ impl Tier {
|
|||||||
/// Return the next lower (worse) tier, or None if already at the worst.
|
/// Return the next lower (worse) tier, or None if already at the worst.
|
||||||
pub fn downgrade(self) -> Option<Tier> {
|
pub fn downgrade(self) -> Option<Tier> {
|
||||||
match self {
|
match self {
|
||||||
|
Self::Studio64k => Some(Self::Studio48k),
|
||||||
|
Self::Studio48k => Some(Self::Studio32k),
|
||||||
|
Self::Studio32k => Some(Self::Good),
|
||||||
Self::Good => Some(Self::Degraded),
|
Self::Good => Some(Self::Degraded),
|
||||||
Self::Degraded => Some(Self::Catastrophic),
|
Self::Degraded => Some(Self::Catastrophic),
|
||||||
Self::Catastrophic => None,
|
Self::Catastrophic => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Whether this is a studio tier (above Good).
|
||||||
|
pub fn is_studio(self) -> bool {
|
||||||
|
matches!(self, Self::Studio64k | Self::Studio48k | Self::Studio32k)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Describes the network transport type for context-aware quality decisions.
|
/// Describes the network transport type for context-aware quality decisions.
|
||||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
|
||||||
pub enum NetworkContext {
|
pub enum NetworkContext {
|
||||||
WiFi,
|
WiFi,
|
||||||
CellularLte,
|
CellularLte,
|
||||||
Cellular5g,
|
Cellular5g,
|
||||||
Cellular3g,
|
Cellular3g,
|
||||||
|
#[default]
|
||||||
Unknown,
|
Unknown,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for NetworkContext {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self::Unknown
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Adaptive quality controller with hysteresis to prevent tier flapping.
|
/// Adaptive quality controller with hysteresis to prevent tier flapping.
|
||||||
///
|
///
|
||||||
/// - Downgrade: 3 consecutive reports in a worse tier (2 on cellular)
|
/// - Downgrade: 3 consecutive reports in a worse tier (2 on cellular)
|
||||||
@@ -108,20 +132,50 @@ pub struct AdaptiveQualityController {
|
|||||||
fec_boost_until: Option<Instant>,
|
fec_boost_until: Option<Instant>,
|
||||||
/// FEC boost amount to add during handoff recovery window.
|
/// FEC boost amount to add during handoff recovery window.
|
||||||
fec_boost_amount: f32,
|
fec_boost_amount: f32,
|
||||||
|
/// Probing state: when Some, we're actively testing a higher tier.
|
||||||
|
probe: Option<ProbeState>,
|
||||||
|
/// Time spent stable at the current tier (for probe trigger).
|
||||||
|
stable_since: Option<Instant>,
|
||||||
|
/// Optional bandwidth estimator for BWE-guarded upgrades.
|
||||||
|
bwe: Option<Arc<BandwidthEstimator>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Threshold for downgrading (fast reaction to degradation).
|
/// Threshold for downgrading (fast reaction to degradation).
|
||||||
const DOWNGRADE_THRESHOLD: u32 = 3;
|
const DOWNGRADE_THRESHOLD: u32 = 3;
|
||||||
/// Threshold for downgrading on cellular networks (even faster).
|
/// Threshold for downgrading on cellular networks (even faster).
|
||||||
const CELLULAR_DOWNGRADE_THRESHOLD: u32 = 2;
|
const CELLULAR_DOWNGRADE_THRESHOLD: u32 = 2;
|
||||||
/// Threshold for upgrading (slow, cautious improvement).
|
/// Threshold for upgrading from Catastrophic/Degraded to Good.
|
||||||
const UPGRADE_THRESHOLD: u32 = 10;
|
const UPGRADE_THRESHOLD: u32 = 5;
|
||||||
|
/// Threshold for upgrading into studio tiers (very conservative).
|
||||||
|
const STUDIO_UPGRADE_THRESHOLD: u32 = 10;
|
||||||
/// Maximum history window size.
|
/// Maximum history window size.
|
||||||
const HISTORY_SIZE: usize = 20;
|
const HISTORY_SIZE: usize = 20;
|
||||||
/// Default FEC boost amount during handoff recovery.
|
/// Default FEC boost amount during handoff recovery.
|
||||||
const DEFAULT_FEC_BOOST: f32 = 0.2;
|
const DEFAULT_FEC_BOOST: f32 = 0.2;
|
||||||
/// Duration of FEC boost after a network handoff.
|
/// Duration of FEC boost after a network handoff.
|
||||||
const FEC_BOOST_DURATION_SECS: u64 = 10;
|
const FEC_BOOST_DURATION_SECS: u64 = 10;
|
||||||
|
/// Minimum time stable at current tier before probing upward (30 seconds).
|
||||||
|
const PROBE_STABLE_SECS: u64 = 30;
|
||||||
|
/// Duration of a probe window (5 seconds — ~25 quality reports at 1/s).
|
||||||
|
const PROBE_DURATION_SECS: u64 = 5;
|
||||||
|
/// Maximum bad reports during probe before aborting (1 out of ~5 = 20%).
|
||||||
|
const PROBE_MAX_BAD: u32 = 1;
|
||||||
|
/// Cooldown after a failed probe before trying again (60 seconds).
|
||||||
|
const PROBE_COOLDOWN_SECS: u64 = 60;
|
||||||
|
|
||||||
|
/// Active bandwidth probe state.
|
||||||
|
struct ProbeState {
|
||||||
|
/// The tier we're probing (one step above current).
|
||||||
|
target_tier: Tier,
|
||||||
|
/// Profile to apply during probe.
|
||||||
|
target_profile: QualityProfile,
|
||||||
|
/// When the probe started.
|
||||||
|
started: Instant,
|
||||||
|
/// Reports observed during probe.
|
||||||
|
probe_reports: u32,
|
||||||
|
/// Bad reports during probe (loss/RTT exceeded target tier thresholds).
|
||||||
|
bad_reports: u32,
|
||||||
|
}
|
||||||
|
|
||||||
impl AdaptiveQualityController {
|
impl AdaptiveQualityController {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
@@ -135,6 +189,9 @@ impl AdaptiveQualityController {
|
|||||||
network_context: NetworkContext::default(),
|
network_context: NetworkContext::default(),
|
||||||
fec_boost_until: None,
|
fec_boost_until: None,
|
||||||
fec_boost_amount: DEFAULT_FEC_BOOST,
|
fec_boost_amount: DEFAULT_FEC_BOOST,
|
||||||
|
probe: None,
|
||||||
|
stable_since: None,
|
||||||
|
bwe: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -174,6 +231,10 @@ impl AdaptiveQualityController {
|
|||||||
self.forced = false;
|
self.forced = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Cancel any active probe
|
||||||
|
self.probe = None;
|
||||||
|
self.stable_since = None;
|
||||||
|
|
||||||
// Activate FEC boost for any network change
|
// Activate FEC boost for any network change
|
||||||
self.fec_boost_until = Some(Instant::now() + Duration::from_secs(FEC_BOOST_DURATION_SECS));
|
self.fec_boost_until = Some(Instant::now() + Duration::from_secs(FEC_BOOST_DURATION_SECS));
|
||||||
}
|
}
|
||||||
@@ -194,6 +255,19 @@ impl AdaptiveQualityController {
|
|||||||
pub fn reset_counters(&mut self) {
|
pub fn reset_counters(&mut self) {
|
||||||
self.consecutive_up = 0;
|
self.consecutive_up = 0;
|
||||||
self.consecutive_down = 0;
|
self.consecutive_down = 0;
|
||||||
|
self.probe = None;
|
||||||
|
self.stable_since = None;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Attach a bandwidth estimator for BWE-guarded tier transitions.
|
||||||
|
pub fn set_bandwidth_estimator(&mut self, bwe: Arc<BandwidthEstimator>) {
|
||||||
|
self.bwe = Some(bwe);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return the bitrate ceiling (in bps) for a given tier, including FEC overhead.
|
||||||
|
fn tier_ceiling_bps(tier: Tier) -> u64 {
|
||||||
|
let kbps = tier.profile().total_bitrate_kbps();
|
||||||
|
(kbps * 1000.0) as u64
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the effective downgrade threshold based on network context.
|
/// Get the effective downgrade threshold based on network context.
|
||||||
@@ -213,16 +287,13 @@ impl AdaptiveQualityController {
|
|||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
|
||||||
let is_worse = match (self.current_tier, observed_tier) {
|
let is_worse = observed_tier < self.current_tier;
|
||||||
(Tier::Good, Tier::Degraded | Tier::Catastrophic) => true,
|
|
||||||
(Tier::Degraded, Tier::Catastrophic) => true,
|
|
||||||
_ => false,
|
|
||||||
};
|
|
||||||
|
|
||||||
if is_worse {
|
if is_worse {
|
||||||
self.consecutive_up = 0;
|
self.consecutive_up = 0;
|
||||||
self.consecutive_down += 1;
|
self.consecutive_down += 1;
|
||||||
if self.consecutive_down >= self.downgrade_threshold() {
|
if self.consecutive_down >= self.downgrade_threshold() {
|
||||||
|
// Jump directly to the observed tier (don't step one-at-a-time on downgrade)
|
||||||
self.current_tier = observed_tier;
|
self.current_tier = observed_tier;
|
||||||
self.current_profile = observed_tier.profile();
|
self.current_profile = observed_tier.profile();
|
||||||
self.consecutive_down = 0;
|
self.consecutive_down = 0;
|
||||||
@@ -232,22 +303,123 @@ impl AdaptiveQualityController {
|
|||||||
// Better conditions
|
// Better conditions
|
||||||
self.consecutive_down = 0;
|
self.consecutive_down = 0;
|
||||||
self.consecutive_up += 1;
|
self.consecutive_up += 1;
|
||||||
if self.consecutive_up >= UPGRADE_THRESHOLD {
|
// Studio tiers require more consecutive good reports
|
||||||
// Only upgrade one step at a time
|
let threshold = if self.current_tier >= Tier::Good {
|
||||||
let next_tier = match self.current_tier {
|
STUDIO_UPGRADE_THRESHOLD
|
||||||
Tier::Catastrophic => Tier::Degraded,
|
} else {
|
||||||
Tier::Degraded => Tier::Good,
|
UPGRADE_THRESHOLD
|
||||||
Tier::Good => return None,
|
|
||||||
};
|
};
|
||||||
|
if self.consecutive_up >= threshold {
|
||||||
|
// Only upgrade one step at a time
|
||||||
|
if let Some(next_tier) = self.upgrade_one_step() {
|
||||||
|
// BWE guard: require 130% headroom over target tier bitrate
|
||||||
|
if let Some(ref bwe) = self.bwe {
|
||||||
|
let required = (Self::tier_ceiling_bps(next_tier) * 130) / 100;
|
||||||
|
if bwe.target_send_bps() < required {
|
||||||
|
// Insufficient bandwidth — reset counter to prevent flapping
|
||||||
|
self.consecutive_up = 0;
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
}
|
||||||
self.current_tier = next_tier;
|
self.current_tier = next_tier;
|
||||||
self.current_profile = next_tier.profile();
|
self.current_profile = next_tier.profile();
|
||||||
self.consecutive_up = 0;
|
self.consecutive_up = 0;
|
||||||
return Some(self.current_profile);
|
return Some(self.current_profile);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Check whether to start, continue, or conclude a bandwidth probe.
|
||||||
|
///
|
||||||
|
/// Called from `observe()` when no hysteresis transition fired.
|
||||||
|
fn check_probe(&mut self, observed_tier: Tier) -> Option<QualityProfile> {
|
||||||
|
// Don't probe if forced, or if already at highest tier, or on cellular
|
||||||
|
if self.forced || self.current_tier == Tier::Studio64k {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
if matches!(
|
||||||
|
self.network_context,
|
||||||
|
NetworkContext::CellularLte | NetworkContext::Cellular5g | NetworkContext::Cellular3g
|
||||||
|
) {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we have an active probe, evaluate it
|
||||||
|
if let Some(ref mut probe) = self.probe {
|
||||||
|
probe.probe_reports += 1;
|
||||||
|
|
||||||
|
// Check if the observed tier meets the probe target
|
||||||
|
if observed_tier < probe.target_tier {
|
||||||
|
probe.bad_reports += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Probe failed: too many bad reports
|
||||||
|
if probe.bad_reports > PROBE_MAX_BAD {
|
||||||
|
let _failed_probe = self.probe.take();
|
||||||
|
// Reset stable_since to trigger cooldown
|
||||||
|
self.stable_since = Some(Instant::now() + Duration::from_secs(PROBE_COOLDOWN_SECS));
|
||||||
|
return None; // stay at current tier
|
||||||
|
}
|
||||||
|
|
||||||
|
// Probe succeeded: enough good reports within the window
|
||||||
|
if probe.started.elapsed() >= Duration::from_secs(PROBE_DURATION_SECS) {
|
||||||
|
let target = probe.target_tier;
|
||||||
|
let profile = probe.target_profile;
|
||||||
|
self.probe.take();
|
||||||
|
self.current_tier = target;
|
||||||
|
self.current_profile = profile;
|
||||||
|
self.consecutive_up = 0;
|
||||||
|
self.stable_since = Some(Instant::now());
|
||||||
|
return Some(profile);
|
||||||
|
}
|
||||||
|
|
||||||
|
return None; // probe still running
|
||||||
|
}
|
||||||
|
|
||||||
|
// No active probe — check if we should start one
|
||||||
|
if observed_tier >= self.current_tier {
|
||||||
|
// Track stability
|
||||||
|
if self.stable_since.is_none() {
|
||||||
|
self.stable_since = Some(Instant::now());
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(stable_since) = self.stable_since {
|
||||||
|
if stable_since.elapsed() >= Duration::from_secs(PROBE_STABLE_SECS) {
|
||||||
|
// Stable long enough — start probing
|
||||||
|
if let Some(next) = self.upgrade_one_step() {
|
||||||
|
self.probe = Some(ProbeState {
|
||||||
|
target_tier: next,
|
||||||
|
target_profile: next.profile(),
|
||||||
|
started: Instant::now(),
|
||||||
|
probe_reports: 0,
|
||||||
|
bad_reports: 0,
|
||||||
|
});
|
||||||
|
// Return the probe profile so the encoder switches
|
||||||
|
return Some(next.profile());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Conditions degraded — reset stability timer
|
||||||
|
self.stable_since = None;
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
fn upgrade_one_step(&self) -> Option<Tier> {
|
||||||
|
match self.current_tier {
|
||||||
|
Tier::Catastrophic => Some(Tier::Degraded),
|
||||||
|
Tier::Degraded => Some(Tier::Good),
|
||||||
|
Tier::Good => Some(Tier::Studio32k),
|
||||||
|
Tier::Studio32k => Some(Tier::Studio48k),
|
||||||
|
Tier::Studio48k => Some(Tier::Studio64k),
|
||||||
|
Tier::Studio64k => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for AdaptiveQualityController {
|
impl Default for AdaptiveQualityController {
|
||||||
@@ -269,7 +441,17 @@ impl QualityController for AdaptiveQualityController {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let observed = Tier::classify_with_context(report, self.network_context);
|
let observed = Tier::classify_with_context(report, self.network_context);
|
||||||
self.try_transition(observed)
|
|
||||||
|
// First check for downgrades/upgrades via hysteresis
|
||||||
|
if let Some(profile) = self.try_transition(observed) {
|
||||||
|
// Cancel any active probe on tier change
|
||||||
|
self.probe.take();
|
||||||
|
self.stable_since = None;
|
||||||
|
return Some(profile);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then check probing
|
||||||
|
self.check_probe(observed)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn force_profile(&mut self, profile: QualityProfile) {
|
fn force_profile(&mut self, profile: QualityProfile) {
|
||||||
@@ -331,25 +513,33 @@ mod tests {
|
|||||||
}
|
}
|
||||||
assert_eq!(ctrl.tier(), Tier::Catastrophic);
|
assert_eq!(ctrl.tier(), Tier::Catastrophic);
|
||||||
|
|
||||||
// 9 good reports — not enough
|
// 4 good reports — not enough (threshold is 5)
|
||||||
let good = make_report(2.0, 100);
|
let good = make_report(0.5, 20); // studio-quality report
|
||||||
for _ in 0..9 {
|
for _ in 0..4 {
|
||||||
assert!(ctrl.observe(&good).is_none());
|
assert!(ctrl.observe(&good).is_none());
|
||||||
}
|
}
|
||||||
assert_eq!(ctrl.tier(), Tier::Catastrophic);
|
assert_eq!(ctrl.tier(), Tier::Catastrophic);
|
||||||
|
|
||||||
// 10th good report triggers upgrade (one step: Catastrophic → Degraded)
|
// 5th good report triggers upgrade (one step: Catastrophic → Degraded)
|
||||||
let result = ctrl.observe(&good);
|
let result = ctrl.observe(&good);
|
||||||
assert!(result.is_some());
|
assert!(result.is_some());
|
||||||
assert_eq!(ctrl.tier(), Tier::Degraded);
|
assert_eq!(ctrl.tier(), Tier::Degraded);
|
||||||
|
|
||||||
// Need another 10 to go from Degraded → Good
|
// Another 5 to go from Degraded → Good
|
||||||
for _ in 0..9 {
|
for _ in 0..4 {
|
||||||
assert!(ctrl.observe(&good).is_none());
|
assert!(ctrl.observe(&good).is_none());
|
||||||
}
|
}
|
||||||
let result = ctrl.observe(&good);
|
let result = ctrl.observe(&good);
|
||||||
assert!(result.is_some());
|
assert!(result.is_some());
|
||||||
assert_eq!(ctrl.tier(), Tier::Good);
|
assert_eq!(ctrl.tier(), Tier::Good);
|
||||||
|
|
||||||
|
// Studio upgrades need 10 consecutive — Good → Studio32k
|
||||||
|
for _ in 0..9 {
|
||||||
|
assert!(ctrl.observe(&good).is_none());
|
||||||
|
}
|
||||||
|
let result = ctrl.observe(&good);
|
||||||
|
assert!(result.is_some());
|
||||||
|
assert_eq!(ctrl.tier(), Tier::Studio32k);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -364,13 +554,78 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn bwe_guard_blocks_upgrade_when_bandwidth_insufficient() {
|
||||||
|
let mut ctrl = AdaptiveQualityController::new();
|
||||||
|
|
||||||
|
// Force to catastrophic
|
||||||
|
let bad = make_report(50.0, 300);
|
||||||
|
for _ in 0..3 {
|
||||||
|
ctrl.observe(&bad);
|
||||||
|
}
|
||||||
|
assert_eq!(ctrl.tier(), Tier::Catastrophic);
|
||||||
|
|
||||||
|
// Attach a BWE with very low headroom.
|
||||||
|
// Degraded tier needs 6kbps * 1.5 FEC = 9kbps → 130% = 11.7kbps.
|
||||||
|
// Set target_send_bps ≈ 9_000 (below 11_700 threshold).
|
||||||
|
let bwe = Arc::new(BandwidthEstimator::new(1000.0, 1.0, 100_000.0));
|
||||||
|
bwe.update_from_path(1_000_000, 0, 10); // high cwnd
|
||||||
|
bwe.update_from_peer(10_000); // low remb → target = 9_000
|
||||||
|
ctrl.set_bandwidth_estimator(bwe.clone());
|
||||||
|
|
||||||
|
let good = make_report(0.5, 20);
|
||||||
|
for _ in 0..5 {
|
||||||
|
assert!(
|
||||||
|
ctrl.observe(&good).is_none(),
|
||||||
|
"upgrade should be blocked by low BWE"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
assert_eq!(
|
||||||
|
ctrl.tier(),
|
||||||
|
Tier::Catastrophic,
|
||||||
|
"should remain at Catastrophic"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Raise BWE well above the 130% threshold
|
||||||
|
bwe.update_from_peer(100_000); // target ≈ 90_000 bps
|
||||||
|
|
||||||
|
// Counter was reset, need another 5 good reports
|
||||||
|
for _ in 0..4 {
|
||||||
|
assert!(ctrl.observe(&good).is_none());
|
||||||
|
}
|
||||||
|
let result = ctrl.observe(&good);
|
||||||
|
assert!(
|
||||||
|
result.is_some(),
|
||||||
|
"upgrade should proceed with sufficient BWE"
|
||||||
|
);
|
||||||
|
assert_eq!(ctrl.tier(), Tier::Degraded);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn tier_classification() {
|
fn tier_classification() {
|
||||||
assert_eq!(Tier::classify(&make_report(5.0, 200)), Tier::Good);
|
// Studio tiers
|
||||||
assert_eq!(Tier::classify(&make_report(15.0, 200)), Tier::Degraded);
|
assert_eq!(Tier::classify(&make_report(0.5, 20)), Tier::Studio64k);
|
||||||
assert_eq!(Tier::classify(&make_report(5.0, 500)), Tier::Degraded);
|
assert_eq!(Tier::classify(&make_report(0.5, 40)), Tier::Studio48k);
|
||||||
assert_eq!(Tier::classify(&make_report(50.0, 200)), Tier::Catastrophic);
|
assert_eq!(Tier::classify(&make_report(1.5, 60)), Tier::Studio32k);
|
||||||
assert_eq!(Tier::classify(&make_report(5.0, 700)), Tier::Catastrophic);
|
// Good/Degraded/Catastrophic
|
||||||
|
assert_eq!(Tier::classify(&make_report(3.0, 90)), Tier::Good);
|
||||||
|
assert_eq!(Tier::classify(&make_report(6.0, 120)), Tier::Degraded);
|
||||||
|
assert_eq!(Tier::classify(&make_report(16.0, 120)), Tier::Catastrophic);
|
||||||
|
assert_eq!(Tier::classify(&make_report(5.0, 200)), Tier::Catastrophic);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn studio_tier_boundaries() {
|
||||||
|
// loss < 1% AND RTT < 30ms → Studio64k
|
||||||
|
assert_eq!(Tier::classify(&make_report(0.9, 28)), Tier::Studio64k);
|
||||||
|
// loss < 1% AND RTT 30-49ms → Studio48k
|
||||||
|
assert_eq!(Tier::classify(&make_report(0.9, 32)), Tier::Studio48k);
|
||||||
|
// loss < 2% AND RTT < 80ms → Studio32k (but loss >= 1%)
|
||||||
|
assert_eq!(Tier::classify(&make_report(1.5, 40)), Tier::Studio32k);
|
||||||
|
// loss >= 2% → Good (use 2.5 to survive u8 quantization)
|
||||||
|
assert_eq!(Tier::classify(&make_report(2.5, 40)), Tier::Good);
|
||||||
|
// RTT 80ms → Good
|
||||||
|
assert_eq!(Tier::classify(&make_report(0.5, 80)), Tier::Good);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ---------------------------------------------------------------
|
// ---------------------------------------------------------------
|
||||||
@@ -379,8 +634,8 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn cellular_tighter_thresholds() {
|
fn cellular_tighter_thresholds() {
|
||||||
// 12% loss: Good on WiFi, Degraded on cellular
|
// 9% loss: Degraded on both WiFi (>=5%) and cellular (>=8%)
|
||||||
let report = make_report(12.0, 200);
|
let report = make_report(9.0, 80);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
Tier::classify_with_context(&report, NetworkContext::WiFi),
|
Tier::classify_with_context(&report, NetworkContext::WiFi),
|
||||||
Tier::Degraded
|
Tier::Degraded
|
||||||
@@ -390,22 +645,22 @@ mod tests {
|
|||||||
Tier::Degraded
|
Tier::Degraded
|
||||||
);
|
);
|
||||||
|
|
||||||
// 9% loss: Good on WiFi, Degraded on cellular
|
// 6% loss, low RTT: Degraded on WiFi (>=5%), Good on cellular (<8%)
|
||||||
let report = make_report(9.0, 200);
|
let report = make_report(6.0, 80);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
Tier::classify_with_context(&report, NetworkContext::WiFi),
|
Tier::classify_with_context(&report, NetworkContext::WiFi),
|
||||||
|
Tier::Degraded
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
Tier::classify_with_context(&report, NetworkContext::CellularLte),
|
||||||
Tier::Good
|
Tier::Good
|
||||||
);
|
);
|
||||||
assert_eq!(
|
|
||||||
Tier::classify_with_context(&report, NetworkContext::CellularLte),
|
|
||||||
Tier::Degraded
|
|
||||||
);
|
|
||||||
|
|
||||||
// 30% loss: Degraded on WiFi, Catastrophic on cellular
|
// 30% loss: Catastrophic on WiFi (>=15%), Catastrophic on cellular (>=25%)
|
||||||
let report = make_report(30.0, 200);
|
let report = make_report(30.0, 80);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
Tier::classify_with_context(&report, NetworkContext::WiFi),
|
Tier::classify_with_context(&report, NetworkContext::WiFi),
|
||||||
Tier::Degraded
|
Tier::Catastrophic
|
||||||
);
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
Tier::classify_with_context(&report, NetworkContext::Cellular3g),
|
Tier::classify_with_context(&report, NetworkContext::Cellular3g),
|
||||||
@@ -415,15 +670,29 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn cellular_rtt_thresholds() {
|
fn cellular_rtt_thresholds() {
|
||||||
// RTT 350ms: Good on WiFi, Degraded on cellular
|
// RTT 150ms: Degraded on WiFi (>=100ms), Good on cellular (<300ms and loss<8%)
|
||||||
let report = make_report(2.0, 348); // rtt_4ms rounds so use 348
|
let report = make_report(2.0, 148);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
Tier::classify_with_context(&report, NetworkContext::WiFi),
|
Tier::classify_with_context(&report, NetworkContext::WiFi),
|
||||||
Tier::Good
|
Tier::Degraded
|
||||||
);
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
Tier::classify_with_context(&report, NetworkContext::CellularLte),
|
Tier::classify_with_context(&report, NetworkContext::CellularLte),
|
||||||
Tier::Degraded
|
Tier::Good
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cellular_no_studio_tiers() {
|
||||||
|
// Even with perfect network, cellular stays at Good (no studio)
|
||||||
|
let report = make_report(0.0, 10);
|
||||||
|
assert_eq!(
|
||||||
|
Tier::classify_with_context(&report, NetworkContext::CellularLte),
|
||||||
|
Tier::Good
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
Tier::classify_with_context(&report, NetworkContext::WiFi),
|
||||||
|
Tier::Studio64k
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -469,6 +738,9 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn tier_downgrade() {
|
fn tier_downgrade() {
|
||||||
|
assert_eq!(Tier::Studio64k.downgrade(), Some(Tier::Studio48k));
|
||||||
|
assert_eq!(Tier::Studio48k.downgrade(), Some(Tier::Studio32k));
|
||||||
|
assert_eq!(Tier::Studio32k.downgrade(), Some(Tier::Good));
|
||||||
assert_eq!(Tier::Good.downgrade(), Some(Tier::Degraded));
|
assert_eq!(Tier::Good.downgrade(), Some(Tier::Degraded));
|
||||||
assert_eq!(Tier::Degraded.downgrade(), Some(Tier::Catastrophic));
|
assert_eq!(Tier::Degraded.downgrade(), Some(Tier::Catastrophic));
|
||||||
assert_eq!(Tier::Catastrophic.downgrade(), None);
|
assert_eq!(Tier::Catastrophic.downgrade(), None);
|
||||||
@@ -478,4 +750,103 @@ mod tests {
|
|||||||
fn network_context_default() {
|
fn network_context_default() {
|
||||||
assert_eq!(NetworkContext::default(), NetworkContext::Unknown);
|
assert_eq!(NetworkContext::default(), NetworkContext::Unknown);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------
|
||||||
|
// Bandwidth probing tests
|
||||||
|
// ---------------------------------------------------------------
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn probe_triggers_after_stable_period() {
|
||||||
|
let mut ctrl = AdaptiveQualityController::new();
|
||||||
|
let excellent = make_report(0.3, 20); // would classify as Studio64k
|
||||||
|
|
||||||
|
// Starts at Good. Fast-forward stability by setting stable_since directly.
|
||||||
|
ctrl.stable_since = Some(Instant::now() - Duration::from_secs(31));
|
||||||
|
|
||||||
|
// One excellent report should trigger a probe (Good → Studio32k)
|
||||||
|
let result = ctrl.observe(&excellent);
|
||||||
|
assert!(result.is_some(), "should start probe after 30s stable");
|
||||||
|
assert!(ctrl.probe.is_some(), "probe should be active");
|
||||||
|
assert_eq!(ctrl.probe.as_ref().unwrap().target_tier, Tier::Studio32k);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn probe_succeeds_after_window() {
|
||||||
|
let mut ctrl = AdaptiveQualityController::new();
|
||||||
|
ctrl.stable_since = Some(Instant::now() - Duration::from_secs(31));
|
||||||
|
|
||||||
|
let excellent = make_report(0.3, 20);
|
||||||
|
|
||||||
|
// Trigger probe start
|
||||||
|
let result = ctrl.observe(&excellent);
|
||||||
|
assert!(result.is_some());
|
||||||
|
|
||||||
|
// Simulate probe window elapsed by backdating started
|
||||||
|
ctrl.probe.as_mut().unwrap().started =
|
||||||
|
Instant::now() - Duration::from_secs(PROBE_DURATION_SECS);
|
||||||
|
|
||||||
|
// Next good report should finalize the probe
|
||||||
|
let result = ctrl.observe(&excellent);
|
||||||
|
assert!(result.is_some(), "probe should succeed");
|
||||||
|
assert_eq!(ctrl.current_tier, Tier::Studio32k);
|
||||||
|
assert!(ctrl.probe.is_none(), "probe should be cleared");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn probe_fails_on_bad_reports() {
|
||||||
|
let mut ctrl = AdaptiveQualityController::new();
|
||||||
|
// Put controller at Studio32k, pretend we've been stable
|
||||||
|
ctrl.current_tier = Tier::Studio32k;
|
||||||
|
ctrl.current_profile = Tier::Studio32k.profile();
|
||||||
|
ctrl.stable_since = Some(Instant::now() - Duration::from_secs(31));
|
||||||
|
|
||||||
|
// Start a probe to Studio48k
|
||||||
|
let excellent = make_report(0.3, 20);
|
||||||
|
let result = ctrl.observe(&excellent);
|
||||||
|
assert!(result.is_some()); // probe started
|
||||||
|
assert_eq!(ctrl.probe.as_ref().unwrap().target_tier, Tier::Studio48k);
|
||||||
|
|
||||||
|
// Feed bad reports (loss too high for Studio48k)
|
||||||
|
let degraded = make_report(3.0, 100);
|
||||||
|
ctrl.observe(°raded); // first bad
|
||||||
|
ctrl.observe(°raded); // second bad — exceeds PROBE_MAX_BAD (1)
|
||||||
|
|
||||||
|
// Probe should be cancelled
|
||||||
|
assert!(
|
||||||
|
ctrl.probe.is_none(),
|
||||||
|
"probe should be cancelled after bad reports"
|
||||||
|
);
|
||||||
|
// Should still be at Studio32k (not upgraded)
|
||||||
|
assert_eq!(ctrl.current_tier, Tier::Studio32k);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn no_probe_on_cellular() {
|
||||||
|
let mut ctrl = AdaptiveQualityController::new();
|
||||||
|
ctrl.signal_network_change(NetworkContext::CellularLte);
|
||||||
|
ctrl.current_tier = Tier::Good;
|
||||||
|
ctrl.current_profile = Tier::Good.profile();
|
||||||
|
ctrl.stable_since = Some(Instant::now() - Duration::from_secs(60));
|
||||||
|
|
||||||
|
let good = make_report(0.5, 40);
|
||||||
|
let result = ctrl.observe(&good);
|
||||||
|
// Should NOT probe on cellular
|
||||||
|
assert!(ctrl.probe.is_none(), "should not probe on cellular");
|
||||||
|
assert!(result.is_none() || ctrl.current_tier == Tier::Good);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn no_probe_at_highest_tier() {
|
||||||
|
let mut ctrl = AdaptiveQualityController::new();
|
||||||
|
ctrl.current_tier = Tier::Studio64k;
|
||||||
|
ctrl.current_profile = Tier::Studio64k.profile();
|
||||||
|
ctrl.stable_since = Some(Instant::now() - Duration::from_secs(60));
|
||||||
|
|
||||||
|
let excellent = make_report(0.1, 10);
|
||||||
|
let result = ctrl.observe(&excellent);
|
||||||
|
assert!(
|
||||||
|
result.is_none(),
|
||||||
|
"should not probe when already at Studio64k"
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,6 +28,13 @@ pub trait AudioEncoder: Send + Sync {
|
|||||||
|
|
||||||
/// Enable/disable DTX (discontinuous transmission). No-op for Codec2.
|
/// Enable/disable DTX (discontinuous transmission). No-op for Codec2.
|
||||||
fn set_dtx(&mut self, _enabled: bool) {}
|
fn set_dtx(&mut self, _enabled: bool) {}
|
||||||
|
|
||||||
|
/// Hint the encoder about expected packet loss (0–100). In DRED mode the
|
||||||
|
/// encoder floors this at 15% internally. No-op for Codec2.
|
||||||
|
fn set_expected_loss(&mut self, _loss_pct: u8) {}
|
||||||
|
|
||||||
|
/// Set DRED duration in 10 ms frame units (0–104). No-op for Codec2.
|
||||||
|
fn set_dred_duration(&mut self, _frames: u8) {}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Decodes compressed frames back to PCM audio.
|
/// Decodes compressed frames back to PCM audio.
|
||||||
@@ -54,18 +61,34 @@ pub trait FecEncoder: Send + Sync {
|
|||||||
/// Add a source symbol (one audio frame) to the current block.
|
/// Add a source symbol (one audio frame) to the current block.
|
||||||
fn add_source_symbol(&mut self, data: &[u8]) -> Result<(), FecError>;
|
fn add_source_symbol(&mut self, data: &[u8]) -> Result<(), FecError>;
|
||||||
|
|
||||||
|
/// Add a source symbol and mark whether it belongs to a keyframe.
|
||||||
|
///
|
||||||
|
/// When the block contains at least one keyframe source symbol,
|
||||||
|
/// [`generate_repair`] uses the configured keyframe ratio instead of the
|
||||||
|
/// nominal ratio.
|
||||||
|
///
|
||||||
|
/// Default implementation delegates to [`add_source_symbol`] and ignores
|
||||||
|
/// the keyframe flag.
|
||||||
|
fn add_source_symbol_with_keyframe(
|
||||||
|
&mut self,
|
||||||
|
data: &[u8],
|
||||||
|
_is_keyframe: bool,
|
||||||
|
) -> Result<(), FecError> {
|
||||||
|
self.add_source_symbol(data)
|
||||||
|
}
|
||||||
|
|
||||||
/// Generate repair symbols for the current block.
|
/// Generate repair symbols for the current block.
|
||||||
///
|
///
|
||||||
/// `ratio` is the repair overhead (e.g., 0.5 = 50% more symbols than source).
|
/// `ratio` is the repair overhead (e.g., 0.5 = 50% more symbols than source).
|
||||||
/// Returns `(fec_symbol_index, repair_data)` pairs.
|
/// Returns `(fec_symbol_index, repair_data)` pairs.
|
||||||
fn generate_repair(&mut self, ratio: f32) -> Result<Vec<(u8, Vec<u8>)>, FecError>;
|
fn generate_repair(&mut self, ratio: f32) -> Result<Vec<(u16, Vec<u8>)>, FecError>;
|
||||||
|
|
||||||
/// Finalize the current block and start a new one.
|
/// Finalize the current block and start a new one.
|
||||||
/// Returns the block ID of the finalized block.
|
/// Returns the block ID of the finalized block.
|
||||||
fn finalize_block(&mut self) -> Result<u8, FecError>;
|
fn finalize_block(&mut self) -> Result<u16, FecError>;
|
||||||
|
|
||||||
/// Current block ID being built.
|
/// Current block ID being built.
|
||||||
fn current_block_id(&self) -> u8;
|
fn current_block_id(&self) -> u16;
|
||||||
|
|
||||||
/// Number of source symbols in the current block.
|
/// Number of source symbols in the current block.
|
||||||
fn current_block_size(&self) -> usize;
|
fn current_block_size(&self) -> usize;
|
||||||
@@ -76,8 +99,8 @@ pub trait FecDecoder: Send + Sync {
|
|||||||
/// Feed a received symbol (source or repair) into the decoder.
|
/// Feed a received symbol (source or repair) into the decoder.
|
||||||
fn add_symbol(
|
fn add_symbol(
|
||||||
&mut self,
|
&mut self,
|
||||||
block_id: u8,
|
block_id: u16,
|
||||||
symbol_index: u8,
|
symbol_index: u16,
|
||||||
is_repair: bool,
|
is_repair: bool,
|
||||||
data: &[u8],
|
data: &[u8],
|
||||||
) -> Result<(), FecError>;
|
) -> Result<(), FecError>;
|
||||||
@@ -86,10 +109,10 @@ pub trait FecDecoder: Send + Sync {
|
|||||||
///
|
///
|
||||||
/// Returns `None` if not yet decodable (insufficient symbols).
|
/// Returns `None` if not yet decodable (insufficient symbols).
|
||||||
/// Returns `Some(Vec<source_frames>)` on success.
|
/// Returns `Some(Vec<source_frames>)` on success.
|
||||||
fn try_decode(&mut self, block_id: u8) -> Result<Option<Vec<Vec<u8>>>, FecError>;
|
fn try_decode(&mut self, block_id: u16) -> Result<Option<Vec<Vec<u8>>>, FecError>;
|
||||||
|
|
||||||
/// Drop state for blocks older than `block_id`.
|
/// Drop state for blocks older than `block_id`.
|
||||||
fn expire_before(&mut self, block_id: u8);
|
fn expire_before(&mut self, block_id: u16);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ─── Crypto Traits ───────────────────────────────────────────────────────────
|
// ─── Crypto Traits ───────────────────────────────────────────────────────────
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ bytes = { workspace = true }
|
|||||||
serde = { workspace = true }
|
serde = { workspace = true }
|
||||||
toml = "0.8"
|
toml = "0.8"
|
||||||
anyhow = "1"
|
anyhow = "1"
|
||||||
|
clap = { version = "4", features = ["derive"] }
|
||||||
reqwest = { version = "0.12", features = ["json"] }
|
reqwest = { version = "0.12", features = ["json"] }
|
||||||
serde_json = "1"
|
serde_json = "1"
|
||||||
rustls = { version = "0.23", default-features = false, features = ["ring", "std"] }
|
rustls = { version = "0.23", default-features = false, features = ["ring", "std"] }
|
||||||
@@ -28,6 +29,7 @@ prometheus = "0.13"
|
|||||||
axum = { version = "0.7", default-features = false, features = ["tokio", "http1", "ws"] }
|
axum = { version = "0.7", default-features = false, features = ["tokio", "http1", "ws"] }
|
||||||
tower-http = { version = "0.6", features = ["fs"] }
|
tower-http = { version = "0.6", features = ["fs"] }
|
||||||
futures-util = "0.3"
|
futures-util = "0.3"
|
||||||
|
dashmap = "6"
|
||||||
dirs = "6"
|
dirs = "6"
|
||||||
sha2 = { workspace = true }
|
sha2 = { workspace = true }
|
||||||
chrono = "0.4"
|
chrono = "0.4"
|
||||||
|
|||||||
@@ -7,9 +7,7 @@ fn main() {
|
|||||||
.output();
|
.output();
|
||||||
|
|
||||||
let hash = match output {
|
let hash = match output {
|
||||||
Ok(o) if o.status.success() => {
|
Ok(o) if o.status.success() => String::from_utf8_lossy(&o.stdout).trim().to_string(),
|
||||||
String::from_utf8_lossy(&o.stdout).trim().to_string()
|
|
||||||
}
|
|
||||||
_ => "unknown".to_string(),
|
_ => "unknown".to_string(),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
467
crates/wzp-relay/src/audio_scorer.rs
Normal file
467
crates/wzp-relay/src/audio_scorer.rs
Normal file
@@ -0,0 +1,467 @@
|
|||||||
|
//! Tier F audio scorer — behavioural entropy detection for abuse mitigation.
|
||||||
|
//!
|
||||||
|
//! Computes a `legitimacy ∈ [0, 1]` score over a 10–30 s observation window.
|
||||||
|
//! Features: IAT CoV, payload-size bimodality, silence fraction, bitrate
|
||||||
|
//! deviation, and Q-flag cadence.
|
||||||
|
|
||||||
|
use std::collections::VecDeque;
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
|
use wzp_proto::{CodecId, MediaHeader, MediaType};
|
||||||
|
|
||||||
|
use crate::verdict::Verdict;
|
||||||
|
|
||||||
|
/// Maximum samples kept in rolling windows.
|
||||||
|
const MAX_IAT_SAMPLES: usize = 200;
|
||||||
|
const MAX_SIZE_SAMPLES: usize = 200;
|
||||||
|
const MAX_Q_INTERVALS: usize = 32;
|
||||||
|
|
||||||
|
/// Silence threshold: payload below this many bytes is treated as silence / CN.
|
||||||
|
const SILENCE_SIZE_THRESHOLD: usize = 16;
|
||||||
|
|
||||||
|
/// Observation window for bitrate tracking.
|
||||||
|
const BITRATE_WINDOW_SECS: u64 = 30;
|
||||||
|
|
||||||
|
// Number of payload-size histogram bins.
|
||||||
|
// (SIZE_BINS reserved for future histogram-based bimodality)
|
||||||
|
|
||||||
|
/// Audio-specific behavioural scorer (Tier F).
|
||||||
|
pub struct AudioScorer {
|
||||||
|
/// Rolling inter-arrival times.
|
||||||
|
iat_samples: VecDeque<Duration>,
|
||||||
|
last_arrival: Option<Instant>,
|
||||||
|
|
||||||
|
/// Rolling payload sizes.
|
||||||
|
size_samples: VecDeque<usize>,
|
||||||
|
|
||||||
|
/// Count of packets below silence threshold.
|
||||||
|
silence_packets: u32,
|
||||||
|
/// Total packets observed in current window.
|
||||||
|
total_packets: u32,
|
||||||
|
|
||||||
|
/// Bitrate window.
|
||||||
|
window_start: Instant,
|
||||||
|
window_bytes: u64,
|
||||||
|
|
||||||
|
/// Q-flag arrival intervals.
|
||||||
|
q_intervals: VecDeque<Duration>,
|
||||||
|
last_q_flag: Option<Instant>,
|
||||||
|
|
||||||
|
/// Codec declared at first packet (used for nominal bitrate baseline).
|
||||||
|
declared_codec: Option<CodecId>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AudioScorer {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
iat_samples: VecDeque::with_capacity(MAX_IAT_SAMPLES),
|
||||||
|
last_arrival: None,
|
||||||
|
size_samples: VecDeque::with_capacity(MAX_SIZE_SAMPLES),
|
||||||
|
silence_packets: 0,
|
||||||
|
total_packets: 0,
|
||||||
|
window_start: Instant::now(),
|
||||||
|
window_bytes: 0,
|
||||||
|
q_intervals: VecDeque::with_capacity(MAX_Q_INTERVALS),
|
||||||
|
last_q_flag: None,
|
||||||
|
declared_codec: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Feed one packet into the scorer.
|
||||||
|
pub fn observe(&mut self, header: &MediaHeader, payload_len: usize, now: Instant) {
|
||||||
|
// Ignore non-audio traffic.
|
||||||
|
if header.media_type != MediaType::Audio {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.declared_codec.is_none() {
|
||||||
|
self.declared_codec = Some(header.codec_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
// IAT
|
||||||
|
if let Some(last) = self.last_arrival {
|
||||||
|
let iat = now.saturating_duration_since(last);
|
||||||
|
self.iat_samples.push_back(iat);
|
||||||
|
if self.iat_samples.len() > MAX_IAT_SAMPLES {
|
||||||
|
self.iat_samples.pop_front();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self.last_arrival = Some(now);
|
||||||
|
|
||||||
|
// Payload size
|
||||||
|
self.size_samples.push_back(payload_len);
|
||||||
|
if self.size_samples.len() > MAX_SIZE_SAMPLES {
|
||||||
|
self.size_samples.pop_front();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Silence fraction
|
||||||
|
self.total_packets += 1;
|
||||||
|
if payload_len <= SILENCE_SIZE_THRESHOLD {
|
||||||
|
self.silence_packets += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bitrate window
|
||||||
|
if now.duration_since(self.window_start) >= Duration::from_secs(BITRATE_WINDOW_SECS) {
|
||||||
|
self.window_start = now;
|
||||||
|
self.window_bytes = 0;
|
||||||
|
}
|
||||||
|
self.window_bytes += (MediaHeader::WIRE_SIZE + payload_len) as u64;
|
||||||
|
|
||||||
|
// Q-flag cadence
|
||||||
|
if header.has_quality() {
|
||||||
|
if let Some(last) = self.last_q_flag {
|
||||||
|
let interval = now.saturating_duration_since(last);
|
||||||
|
self.q_intervals.push_back(interval);
|
||||||
|
if self.q_intervals.len() > MAX_Q_INTERVALS {
|
||||||
|
self.q_intervals.pop_front();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self.last_q_flag = Some(now);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Compute legitimacy score ∈ [0, 1].
|
||||||
|
///
|
||||||
|
/// Higher = more legitimate. Returns `None` when insufficient samples
|
||||||
|
/// have been collected (< 20 packets).
|
||||||
|
pub fn legitimacy(&self) -> Option<f32> {
|
||||||
|
if self.total_packets < 20 {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut score = 1.0f32;
|
||||||
|
|
||||||
|
// 1. IAT CoV penalty
|
||||||
|
if let Some(cov) = self.iat_cov() {
|
||||||
|
if cov > 0.4 {
|
||||||
|
let penalty = ((cov - 0.4) / 0.6).min(1.0) * 0.25;
|
||||||
|
score -= penalty as f32;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Silence fraction penalty
|
||||||
|
let silence_fraction = self.silence_fraction();
|
||||||
|
if silence_fraction < 0.02 {
|
||||||
|
let penalty = ((0.02 - silence_fraction) / 0.02).min(1.0) * 0.25;
|
||||||
|
score -= penalty as f32;
|
||||||
|
} else if silence_fraction > 0.60 {
|
||||||
|
// Too much silence can also be suspicious (stuffed payloads)
|
||||||
|
let penalty = ((silence_fraction - 0.60) / 0.40).min(1.0) * 0.15;
|
||||||
|
score -= penalty as f32;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Bitrate deviation penalty
|
||||||
|
if let Some(ratio) = self.bitrate_ratio() {
|
||||||
|
if ratio > 1.20 {
|
||||||
|
let penalty = ((ratio - 1.20) / 0.80).min(1.0) * 0.25;
|
||||||
|
score -= penalty as f32;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Q-flag cadence penalty
|
||||||
|
if let Some(cv) = self.q_flag_cv() {
|
||||||
|
// High variability in Q-flag spacing = suspicious
|
||||||
|
if cv > 0.5 {
|
||||||
|
let penalty = ((cv - 0.5) / 0.5).min(1.0) * 0.15;
|
||||||
|
score -= penalty as f32;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// No Q flags seen at all — mildly suspicious after many packets
|
||||||
|
if self.total_packets > 100 {
|
||||||
|
score -= 0.10;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. Payload-size bimodality bonus/penalty
|
||||||
|
if let Some(bimodality) = self.size_bimodality() {
|
||||||
|
// Bimodality score: 0 = unimodal, 1 = strongly bimodal
|
||||||
|
// Legitimate audio is bimodal (speech + silence)
|
||||||
|
if bimodality < 0.2 {
|
||||||
|
score -= 0.10;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(score.clamp(0.0, 1.0))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Map legitimacy score to a [`Verdict`].
|
||||||
|
pub fn verdict(&self) -> Option<Verdict> {
|
||||||
|
self.legitimacy().map(|s| {
|
||||||
|
if s >= 0.7 {
|
||||||
|
Verdict::Legitimate
|
||||||
|
} else if s >= 0.3 {
|
||||||
|
Verdict::Suspect
|
||||||
|
} else {
|
||||||
|
Verdict::Abusive
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------------------------------------------
|
||||||
|
// Feature extractors
|
||||||
|
// ------------------------------------------------------------------
|
||||||
|
|
||||||
|
/// Coefficient of variation of inter-arrival times.
|
||||||
|
fn iat_cov(&self) -> Option<f64> {
|
||||||
|
if self.iat_samples.len() < 10 {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
let mean = self
|
||||||
|
.iat_samples
|
||||||
|
.iter()
|
||||||
|
.map(|d| d.as_secs_f64())
|
||||||
|
.sum::<f64>()
|
||||||
|
/ self.iat_samples.len() as f64;
|
||||||
|
if mean == 0.0 {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
let variance = self
|
||||||
|
.iat_samples
|
||||||
|
.iter()
|
||||||
|
.map(|d| {
|
||||||
|
let diff = d.as_secs_f64() - mean;
|
||||||
|
diff * diff
|
||||||
|
})
|
||||||
|
.sum::<f64>()
|
||||||
|
/ self.iat_samples.len() as f64;
|
||||||
|
let std = variance.sqrt();
|
||||||
|
Some(std / mean)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Fraction of packets that are silence / comfort-noise sized.
|
||||||
|
fn silence_fraction(&self) -> f64 {
|
||||||
|
if self.total_packets == 0 {
|
||||||
|
return 0.0;
|
||||||
|
}
|
||||||
|
self.silence_packets as f64 / self.total_packets as f64
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Ratio of observed bitrate to nominal bitrate over the 30 s window.
|
||||||
|
fn bitrate_ratio(&self) -> Option<f64> {
|
||||||
|
let codec = self.declared_codec?;
|
||||||
|
let nominal_bps = codec.bitrate_bps() as f64;
|
||||||
|
if nominal_bps == 0.0 {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
let observed_bps = self.window_bytes as f64 * 8.0 / BITRATE_WINDOW_SECS as f64;
|
||||||
|
Some(observed_bps / nominal_bps)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Coefficient of variation of Q-flag intervals.
|
||||||
|
fn q_flag_cv(&self) -> Option<f64> {
|
||||||
|
if self.q_intervals.len() < 3 {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
let mean = self
|
||||||
|
.q_intervals
|
||||||
|
.iter()
|
||||||
|
.map(|d| d.as_secs_f64())
|
||||||
|
.sum::<f64>()
|
||||||
|
/ self.q_intervals.len() as f64;
|
||||||
|
if mean == 0.0 {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
let variance = self
|
||||||
|
.q_intervals
|
||||||
|
.iter()
|
||||||
|
.map(|d| {
|
||||||
|
let diff = d.as_secs_f64() - mean;
|
||||||
|
diff * diff
|
||||||
|
})
|
||||||
|
.sum::<f64>()
|
||||||
|
/ self.q_intervals.len() as f64;
|
||||||
|
let std = variance.sqrt();
|
||||||
|
Some(std / mean)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Simple bimodality score based on a 2-bin histogram.
|
||||||
|
///
|
||||||
|
/// Splits payload sizes into "small" (≤ threshold) and "large" bins.
|
||||||
|
/// Returns a score in [0, 1] where 1 = strongly bimodal.
|
||||||
|
fn size_bimodality(&self) -> Option<f64> {
|
||||||
|
if self.size_samples.len() < 20 {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
let small = self
|
||||||
|
.size_samples
|
||||||
|
.iter()
|
||||||
|
.filter(|&&s| s <= SILENCE_SIZE_THRESHOLD)
|
||||||
|
.count();
|
||||||
|
let large = self.size_samples.len() - small;
|
||||||
|
let total = self.size_samples.len() as f64;
|
||||||
|
let p_small = small as f64 / total;
|
||||||
|
let _p_large = large as f64 / total;
|
||||||
|
// Max bimodality when both bins are equally populated (~0.5 each)
|
||||||
|
let bimodality = 1.0 - (p_small - 0.5).abs() * 2.0;
|
||||||
|
Some(bimodality)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for AudioScorer {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
fn audio_header(payload_len: usize, has_quality: bool) -> MediaHeader {
|
||||||
|
MediaHeader {
|
||||||
|
version: 2,
|
||||||
|
flags: if has_quality { 0x40 } else { 0 },
|
||||||
|
media_type: MediaType::Audio,
|
||||||
|
codec_id: CodecId::Opus24k,
|
||||||
|
stream_id: 0,
|
||||||
|
fec_ratio: 0,
|
||||||
|
seq: 0,
|
||||||
|
timestamp: 0,
|
||||||
|
fec_block: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn audio_scorer_ignores_video() {
|
||||||
|
let mut scorer = AudioScorer::new();
|
||||||
|
let mut h = audio_header(100, false);
|
||||||
|
h.media_type = MediaType::Video;
|
||||||
|
scorer.observe(&h, 100, Instant::now());
|
||||||
|
assert_eq!(scorer.total_packets, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn audio_scorer_counts_packets() {
|
||||||
|
let mut scorer = AudioScorer::new();
|
||||||
|
for i in 0..25 {
|
||||||
|
let h = audio_header(100, false);
|
||||||
|
scorer.observe(&h, 100, Instant::now() + Duration::from_millis(i * 20));
|
||||||
|
}
|
||||||
|
assert_eq!(scorer.total_packets, 25);
|
||||||
|
assert!(scorer.legitimacy().is_some());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn audio_scorer_legitimate_traffic() {
|
||||||
|
let mut scorer = AudioScorer::new();
|
||||||
|
let base = Instant::now();
|
||||||
|
// Simulate 200 packets of legitimate audio:
|
||||||
|
// ~20 ms IAT, mixed speech (100 B) and silence (8 B), periodic Q flags.
|
||||||
|
for i in 0..200 {
|
||||||
|
let payload = if i % 3 == 0 { 8 } else { 100 };
|
||||||
|
let has_q = i % 10 == 0;
|
||||||
|
let h = audio_header(payload, has_q);
|
||||||
|
scorer.observe(&h, payload, base + Duration::from_millis(i * 20));
|
||||||
|
}
|
||||||
|
let leg = scorer.legitimacy().unwrap();
|
||||||
|
assert!(
|
||||||
|
leg >= 0.7,
|
||||||
|
"legitimate traffic should score ≥ 0.7, got {leg}"
|
||||||
|
);
|
||||||
|
assert_eq!(scorer.verdict(), Some(Verdict::Legitimate));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn audio_scorer_abusive_uniform_iat() {
|
||||||
|
let mut scorer = AudioScorer::new();
|
||||||
|
let base = Instant::now();
|
||||||
|
// Uniform IAT (no jitter), all same size, no Q flags — tunnel-like
|
||||||
|
for i in 0..200 {
|
||||||
|
let h = audio_header(200, false);
|
||||||
|
scorer.observe(&h, 200, base + Duration::from_millis(i * 20));
|
||||||
|
}
|
||||||
|
let leg = scorer.legitimacy().unwrap();
|
||||||
|
assert!(
|
||||||
|
leg < 0.6,
|
||||||
|
"uniform tunnel-like traffic should score < 0.6, got {leg}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn audio_scorer_abusive_no_silence() {
|
||||||
|
let mut scorer = AudioScorer::new();
|
||||||
|
let base = Instant::now();
|
||||||
|
// No silence packets at all, very regular IAT
|
||||||
|
for i in 0..200 {
|
||||||
|
let h = audio_header(150, false);
|
||||||
|
scorer.observe(&h, 150, base + Duration::from_millis(i * 20));
|
||||||
|
}
|
||||||
|
let leg = scorer.legitimacy().unwrap();
|
||||||
|
assert!(
|
||||||
|
leg < 0.6,
|
||||||
|
"no-silence traffic should score < 0.6, got {leg}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn audio_scorer_insufficient_samples() {
|
||||||
|
let scorer = AudioScorer::new();
|
||||||
|
assert_eq!(scorer.legitimacy(), None);
|
||||||
|
assert_eq!(scorer.verdict(), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn silence_fraction_computed_correctly() {
|
||||||
|
let mut scorer = AudioScorer::new();
|
||||||
|
let base = Instant::now();
|
||||||
|
for i in 0..100 {
|
||||||
|
let payload = if i < 30 { 8 } else { 100 };
|
||||||
|
let h = audio_header(payload, false);
|
||||||
|
scorer.observe(&h, payload, base + Duration::from_millis(i * 20));
|
||||||
|
}
|
||||||
|
assert!((scorer.silence_fraction() - 0.30).abs() < 0.01);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn bitrate_ratio_saturates_when_no_codec() {
|
||||||
|
let scorer = AudioScorer::new();
|
||||||
|
assert_eq!(scorer.bitrate_ratio(), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn q_flag_cv_regular_spacing() {
|
||||||
|
let mut scorer = AudioScorer::new();
|
||||||
|
let base = Instant::now();
|
||||||
|
for i in 0..50 {
|
||||||
|
let has_q = i % 5 == 0;
|
||||||
|
let h = audio_header(100, has_q);
|
||||||
|
scorer.observe(&h, 100, base + Duration::from_millis(i * 20));
|
||||||
|
}
|
||||||
|
let cv = scorer.q_flag_cv().unwrap();
|
||||||
|
assert!(
|
||||||
|
cv < 0.1,
|
||||||
|
"regular Q-flag spacing should have CV < 0.1, got {cv}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn size_bimodality_for_mixed_traffic() {
|
||||||
|
let mut scorer = AudioScorer::new();
|
||||||
|
let base = Instant::now();
|
||||||
|
for i in 0..100 {
|
||||||
|
let payload = if i % 2 == 0 { 8 } else { 120 };
|
||||||
|
let h = audio_header(payload, false);
|
||||||
|
scorer.observe(&h, payload, base + Duration::from_millis(i * 20));
|
||||||
|
}
|
||||||
|
let bim = scorer.size_bimodality().unwrap();
|
||||||
|
assert!(
|
||||||
|
bim > 0.8,
|
||||||
|
"perfectly mixed small/large should be highly bimodal, got {bim}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn size_bimodality_for_uniform_traffic() {
|
||||||
|
let mut scorer = AudioScorer::new();
|
||||||
|
let base = Instant::now();
|
||||||
|
for i in 0..100 {
|
||||||
|
let h = audio_header(100, false);
|
||||||
|
scorer.observe(&h, 100, base + Duration::from_millis(i * 20));
|
||||||
|
}
|
||||||
|
let bim = scorer.size_bimodality().unwrap();
|
||||||
|
assert!(
|
||||||
|
bim < 0.3,
|
||||||
|
"uniform size traffic should be unimodal, got {bim}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -32,10 +32,7 @@ pub struct AuthenticatedClient {
|
|||||||
///
|
///
|
||||||
/// Calls `POST {auth_url}` with `{ "token": "..." }`.
|
/// Calls `POST {auth_url}` with `{ "token": "..." }`.
|
||||||
/// Returns the client identity if valid, or an error string.
|
/// Returns the client identity if valid, or an error string.
|
||||||
pub async fn validate_token(
|
pub async fn validate_token(auth_url: &str, token: &str) -> Result<AuthenticatedClient, String> {
|
||||||
auth_url: &str,
|
|
||||||
token: &str,
|
|
||||||
) -> Result<AuthenticatedClient, String> {
|
|
||||||
let client = reqwest::Client::builder()
|
let client = reqwest::Client::builder()
|
||||||
.timeout(std::time::Duration::from_secs(5))
|
.timeout(std::time::Duration::from_secs(5))
|
||||||
.build()
|
.build()
|
||||||
|
|||||||
@@ -61,6 +61,13 @@ pub struct DirectCall {
|
|||||||
/// interface addresses from the `DirectCallAnswer`. Cross-
|
/// interface addresses from the `DirectCallAnswer`. Cross-
|
||||||
/// wired into the caller's `CallSetup.peer_local_addrs`.
|
/// wired into the caller's `CallSetup.peer_local_addrs`.
|
||||||
pub callee_local_addrs: Vec<String>,
|
pub callee_local_addrs: Vec<String>,
|
||||||
|
/// Phase 8 (Tailscale-inspired): caller's port-mapped
|
||||||
|
/// external address from NAT-PMP/PCP/UPnP. Cross-wired
|
||||||
|
/// into callee's `CallSetup.peer_mapped_addr`.
|
||||||
|
pub caller_mapped_addr: Option<String>,
|
||||||
|
/// Phase 8: callee's port-mapped external address.
|
||||||
|
/// Cross-wired into caller's `CallSetup.peer_mapped_addr`.
|
||||||
|
pub callee_mapped_addr: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Registry of active direct calls.
|
/// Registry of active direct calls.
|
||||||
@@ -76,7 +83,12 @@ impl CallRegistry {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Create a new pending call. Returns the call_id.
|
/// Create a new pending call. Returns the call_id.
|
||||||
pub fn create_call(&mut self, call_id: String, caller_fp: String, callee_fp: String) -> &DirectCall {
|
pub fn create_call(
|
||||||
|
&mut self,
|
||||||
|
call_id: String,
|
||||||
|
caller_fp: String,
|
||||||
|
callee_fp: String,
|
||||||
|
) -> &DirectCall {
|
||||||
let call = DirectCall {
|
let call = DirectCall {
|
||||||
call_id: call_id.clone(),
|
call_id: call_id.clone(),
|
||||||
caller_fingerprint: caller_fp,
|
caller_fingerprint: caller_fp,
|
||||||
@@ -92,6 +104,8 @@ impl CallRegistry {
|
|||||||
peer_relay_fp: None,
|
peer_relay_fp: None,
|
||||||
caller_local_addrs: Vec::new(),
|
caller_local_addrs: Vec::new(),
|
||||||
callee_local_addrs: Vec::new(),
|
callee_local_addrs: Vec::new(),
|
||||||
|
caller_mapped_addr: None,
|
||||||
|
callee_mapped_addr: None,
|
||||||
};
|
};
|
||||||
self.calls.insert(call_id.clone(), call);
|
self.calls.insert(call_id.clone(), call);
|
||||||
self.calls.get(&call_id).unwrap()
|
self.calls.get(&call_id).unwrap()
|
||||||
@@ -142,6 +156,22 @@ impl CallRegistry {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Phase 8: stash the caller's port-mapped address from
|
||||||
|
/// the `DirectCallOffer`.
|
||||||
|
pub fn set_caller_mapped_addr(&mut self, call_id: &str, addr: Option<String>) {
|
||||||
|
if let Some(call) = self.calls.get_mut(call_id) {
|
||||||
|
call.caller_mapped_addr = addr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Phase 8: stash the callee's port-mapped address from
|
||||||
|
/// the `DirectCallAnswer`.
|
||||||
|
pub fn set_callee_mapped_addr(&mut self, call_id: &str, addr: Option<String>) {
|
||||||
|
if let Some(call) = self.calls.get_mut(call_id) {
|
||||||
|
call.callee_mapped_addr = addr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Get a call by ID.
|
/// Get a call by ID.
|
||||||
pub fn get(&self, call_id: &str) -> Option<&DirectCall> {
|
pub fn get(&self, call_id: &str) -> Option<&DirectCall> {
|
||||||
self.calls.get(call_id)
|
self.calls.get(call_id)
|
||||||
@@ -164,7 +194,12 @@ impl CallRegistry {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Transition to Active state.
|
/// Transition to Active state.
|
||||||
pub fn set_active(&mut self, call_id: &str, mode: wzp_proto::CallAcceptMode, room: String) -> bool {
|
pub fn set_active(
|
||||||
|
&mut self,
|
||||||
|
call_id: &str,
|
||||||
|
mode: wzp_proto::CallAcceptMode,
|
||||||
|
room: String,
|
||||||
|
) -> bool {
|
||||||
if let Some(call) = self.calls.get_mut(call_id) {
|
if let Some(call) = self.calls.get_mut(call_id) {
|
||||||
if call.state == DirectCallState::Pending || call.state == DirectCallState::Ringing {
|
if call.state == DirectCallState::Pending || call.state == DirectCallState::Ringing {
|
||||||
call.state = DirectCallState::Active;
|
call.state = DirectCallState::Active;
|
||||||
@@ -188,7 +223,8 @@ impl CallRegistry {
|
|||||||
|
|
||||||
/// Find active/pending calls involving a fingerprint.
|
/// Find active/pending calls involving a fingerprint.
|
||||||
pub fn calls_for_fingerprint(&self, fp: &str) -> Vec<&DirectCall> {
|
pub fn calls_for_fingerprint(&self, fp: &str) -> Vec<&DirectCall> {
|
||||||
self.calls.values()
|
self.calls
|
||||||
|
.values()
|
||||||
.filter(|c| {
|
.filter(|c| {
|
||||||
c.state != DirectCallState::Ended
|
c.state != DirectCallState::Ended
|
||||||
&& (c.caller_fingerprint == fp || c.callee_fingerprint == fp)
|
&& (c.caller_fingerprint == fp || c.callee_fingerprint == fp)
|
||||||
@@ -211,22 +247,25 @@ impl CallRegistry {
|
|||||||
/// Returns call IDs of expired calls.
|
/// Returns call IDs of expired calls.
|
||||||
pub fn expire_stale(&mut self, timeout: Duration) -> Vec<DirectCall> {
|
pub fn expire_stale(&mut self, timeout: Duration) -> Vec<DirectCall> {
|
||||||
let now = Instant::now();
|
let now = Instant::now();
|
||||||
let expired: Vec<String> = self.calls.iter()
|
let expired: Vec<String> = self
|
||||||
|
.calls
|
||||||
|
.iter()
|
||||||
.filter(|(_, c)| {
|
.filter(|(_, c)| {
|
||||||
c.state == DirectCallState::Pending
|
c.state == DirectCallState::Pending && now.duration_since(c.created_at) > timeout
|
||||||
&& now.duration_since(c.created_at) > timeout
|
|
||||||
})
|
})
|
||||||
.map(|(id, _)| id.clone())
|
.map(|(id, _)| id.clone())
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
expired.into_iter()
|
expired
|
||||||
|
.into_iter()
|
||||||
.filter_map(|id| self.calls.remove(&id))
|
.filter_map(|id| self.calls.remove(&id))
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Number of active (non-ended) calls.
|
/// Number of active (non-ended) calls.
|
||||||
pub fn active_count(&self) -> usize {
|
pub fn active_count(&self) -> usize {
|
||||||
self.calls.values()
|
self.calls
|
||||||
|
.values()
|
||||||
.filter(|c| c.state != DirectCallState::Ended)
|
.filter(|c| c.state != DirectCallState::Ended)
|
||||||
.count()
|
.count()
|
||||||
}
|
}
|
||||||
@@ -245,9 +284,16 @@ mod tests {
|
|||||||
assert!(reg.set_ringing("c1"));
|
assert!(reg.set_ringing("c1"));
|
||||||
assert_eq!(reg.get("c1").unwrap().state, DirectCallState::Ringing);
|
assert_eq!(reg.get("c1").unwrap().state, DirectCallState::Ringing);
|
||||||
|
|
||||||
assert!(reg.set_active("c1", wzp_proto::CallAcceptMode::AcceptGeneric, "_call:c1".into()));
|
assert!(reg.set_active(
|
||||||
|
"c1",
|
||||||
|
wzp_proto::CallAcceptMode::AcceptGeneric,
|
||||||
|
"_call:c1".into()
|
||||||
|
));
|
||||||
assert_eq!(reg.get("c1").unwrap().state, DirectCallState::Active);
|
assert_eq!(reg.get("c1").unwrap().state, DirectCallState::Active);
|
||||||
assert_eq!(reg.get("c1").unwrap().room_name.as_deref(), Some("_call:c1"));
|
assert_eq!(
|
||||||
|
reg.get("c1").unwrap().room_name.as_deref(),
|
||||||
|
Some("_call:c1")
|
||||||
|
);
|
||||||
|
|
||||||
let ended = reg.end_call("c1").unwrap();
|
let ended = reg.end_call("c1").unwrap();
|
||||||
assert_eq!(ended.state, DirectCallState::Ended);
|
assert_eq!(ended.state, DirectCallState::Ended);
|
||||||
@@ -304,10 +350,7 @@ mod tests {
|
|||||||
// Both addrs are independently readable — the relay uses
|
// Both addrs are independently readable — the relay uses
|
||||||
// them to cross-wire peer_direct_addr in CallSetup.
|
// them to cross-wire peer_direct_addr in CallSetup.
|
||||||
let c = reg.get("c1").unwrap();
|
let c = reg.get("c1").unwrap();
|
||||||
assert_eq!(
|
assert_eq!(c.caller_reflexive_addr.as_deref(), Some("192.0.2.1:4433"));
|
||||||
c.caller_reflexive_addr.as_deref(),
|
|
||||||
Some("192.0.2.1:4433")
|
|
||||||
);
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
c.callee_reflexive_addr.as_deref(),
|
c.callee_reflexive_addr.as_deref(),
|
||||||
Some("198.51.100.9:4433")
|
Some("198.51.100.9:4433")
|
||||||
@@ -340,6 +383,49 @@ mod tests {
|
|||||||
reg.set_peer_relay_fp("does-not-exist", Some("x".into()));
|
reg.set_peer_relay_fp("does-not-exist", Some("x".into()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn call_registry_stores_mapped_addrs() {
|
||||||
|
let mut reg = CallRegistry::new();
|
||||||
|
reg.create_call("c1".into(), "alice".into(), "bob".into());
|
||||||
|
|
||||||
|
// Default: both mapped addrs are None.
|
||||||
|
let c = reg.get("c1").unwrap();
|
||||||
|
assert!(c.caller_mapped_addr.is_none());
|
||||||
|
assert!(c.callee_mapped_addr.is_none());
|
||||||
|
|
||||||
|
// Caller advertises its port-mapped addr via DirectCallOffer.
|
||||||
|
reg.set_caller_mapped_addr("c1", Some("203.0.113.5:12345".into()));
|
||||||
|
assert_eq!(
|
||||||
|
reg.get("c1").unwrap().caller_mapped_addr.as_deref(),
|
||||||
|
Some("203.0.113.5:12345")
|
||||||
|
);
|
||||||
|
|
||||||
|
// Callee responds with its mapped addr.
|
||||||
|
reg.set_callee_mapped_addr("c1", Some("198.51.100.9:54321".into()));
|
||||||
|
assert_eq!(
|
||||||
|
reg.get("c1").unwrap().callee_mapped_addr.as_deref(),
|
||||||
|
Some("198.51.100.9:54321")
|
||||||
|
);
|
||||||
|
|
||||||
|
// Both addrs readable — relay uses them to cross-wire
|
||||||
|
// peer_mapped_addr in CallSetup.
|
||||||
|
let c = reg.get("c1").unwrap();
|
||||||
|
assert_eq!(c.caller_mapped_addr.as_deref(), Some("203.0.113.5:12345"));
|
||||||
|
assert_eq!(c.callee_mapped_addr.as_deref(), Some("198.51.100.9:54321"));
|
||||||
|
|
||||||
|
// Setter on unknown call is a no-op.
|
||||||
|
reg.set_caller_mapped_addr("nope", Some("x".into()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn call_registry_clearing_mapped_addr_works() {
|
||||||
|
let mut reg = CallRegistry::new();
|
||||||
|
reg.create_call("c1".into(), "alice".into(), "bob".into());
|
||||||
|
reg.set_caller_mapped_addr("c1", Some("1.2.3.4:5".into()));
|
||||||
|
reg.set_caller_mapped_addr("c1", None);
|
||||||
|
assert!(reg.get("c1").unwrap().caller_mapped_addr.is_none());
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn call_registry_clearing_reflex_addr_works() {
|
fn call_registry_clearing_reflex_addr_works() {
|
||||||
// Passing None to the setter must clear a previously-set value
|
// Passing None to the setter must clear a previously-set value
|
||||||
|
|||||||
@@ -87,6 +87,14 @@ pub struct RelayConfig {
|
|||||||
/// Unlike [[peers]], no url is needed — the peer connects to us.
|
/// Unlike [[peers]], no url is needed — the peer connects to us.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub trusted: Vec<TrustedConfig>,
|
pub trusted: Vec<TrustedConfig>,
|
||||||
|
/// Phase 8: geographic region identifier (e.g., "us-east", "eu-west").
|
||||||
|
/// Sent to clients in `RegisterPresenceAck.relay_region` so they can
|
||||||
|
/// build a relay map for automatic selection.
|
||||||
|
pub region: Option<String>,
|
||||||
|
/// Phase 8: externally-advertised address for this relay. Used to
|
||||||
|
/// populate `available_relays` in `RegisterPresenceAck`. If not set,
|
||||||
|
/// `listen_addr` is used.
|
||||||
|
pub advertised_addr: Option<SocketAddr>,
|
||||||
/// Debug tap: log packet headers for matching rooms ("*" = all rooms).
|
/// Debug tap: log packet headers for matching rooms ("*" = all rooms).
|
||||||
/// Activated via --debug-tap <room> or debug_tap = "room" in TOML.
|
/// Activated via --debug-tap <room> or debug_tap = "room" in TOML.
|
||||||
pub debug_tap: Option<String>,
|
pub debug_tap: Option<String>,
|
||||||
@@ -114,6 +122,8 @@ impl Default for RelayConfig {
|
|||||||
peers: Vec::new(),
|
peers: Vec::new(),
|
||||||
global_rooms: Vec::new(),
|
global_rooms: Vec::new(),
|
||||||
trusted: Vec::new(),
|
trusted: Vec::new(),
|
||||||
|
region: None,
|
||||||
|
advertised_addr: None,
|
||||||
debug_tap: None,
|
debug_tap: None,
|
||||||
event_log: None,
|
event_log: None,
|
||||||
}
|
}
|
||||||
@@ -135,7 +145,10 @@ pub struct RelayInfo {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Load config from path, or create a personalized example config if it doesn't exist.
|
/// Load config from path, or create a personalized example config if it doesn't exist.
|
||||||
pub fn load_or_create_config(path: &str, info: Option<&RelayInfo>) -> Result<RelayConfig, anyhow::Error> {
|
pub fn load_or_create_config(
|
||||||
|
path: &str,
|
||||||
|
info: Option<&RelayInfo>,
|
||||||
|
) -> Result<RelayConfig, anyhow::Error> {
|
||||||
let p = std::path::Path::new(path);
|
let p = std::path::Path::new(path);
|
||||||
if p.exists() {
|
if p.exists() {
|
||||||
return load_config(path);
|
return load_config(path);
|
||||||
@@ -154,7 +167,9 @@ pub fn load_or_create_config(path: &str, info: Option<&RelayInfo>) -> Result<Rel
|
|||||||
|
|
||||||
/// Generate an example TOML config, personalized with this relay's info if available.
|
/// Generate an example TOML config, personalized with this relay's info if available.
|
||||||
fn generate_example_config(info: Option<&RelayInfo>) -> String {
|
fn generate_example_config(info: Option<&RelayInfo>) -> String {
|
||||||
let listen = info.map(|i| i.listen_addr.as_str()).unwrap_or("0.0.0.0:4433");
|
let listen = info
|
||||||
|
.map(|i| i.listen_addr.as_str())
|
||||||
|
.unwrap_or("0.0.0.0:4433");
|
||||||
let peer_example = if let Some(i) = info {
|
let peer_example = if let Some(i) = info {
|
||||||
let ip = i.public_ip.as_deref().unwrap_or("this-relay-ip");
|
let ip = i.public_ip.as_deref().unwrap_or("this-relay-ip");
|
||||||
format!(
|
format!(
|
||||||
|
|||||||
544
crates/wzp-relay/src/conformance.rs
Normal file
544
crates/wzp-relay/src/conformance.rs
Normal file
@@ -0,0 +1,544 @@
|
|||||||
|
//! Relay conformance metering — Tier A/B/C/D/E enforcement.
|
||||||
|
//!
|
||||||
|
//! Each participant gets a [`ConformanceMeter`] that tracks per-second
|
||||||
|
//! traffic against the declared codec's nominal bitrate ceiling.
|
||||||
|
//! Violations are logged and counted but do **not** drop packets
|
||||||
|
//! (observe-only mode).
|
||||||
|
|
||||||
|
use std::collections::VecDeque;
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
use wzp_proto::{CodecId, MediaHeader};
|
||||||
|
|
||||||
|
/// Rolling window size for timestamp-drift detection (Tier C).
|
||||||
|
const DRIFT_WINDOW_SIZE: usize = 200;
|
||||||
|
|
||||||
|
/// Kinds of conformance violation detected by the relay.
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum Violation {
|
||||||
|
/// Cumulative bitrate in the current 1 s window exceeds the Tier A ceiling.
|
||||||
|
BitrateExceeded,
|
||||||
|
/// Packet rate exceeds the per-codec safety limit (Tier B).
|
||||||
|
PacketRateExceeded,
|
||||||
|
/// Timestamp jumped backwards or forwards suspiciously (Tier C).
|
||||||
|
TimestampDrift,
|
||||||
|
/// Sustained payload size exceeds 2× the typical bound for the declared codec (Tier D).
|
||||||
|
PayloadSizeExceeded,
|
||||||
|
/// Per-session token-bucket rate cap exceeded (Tier E).
|
||||||
|
RateCapExceeded,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Error type returned when a [`TokenBucket`] does not hold enough tokens.
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub struct TokenExhausted;
|
||||||
|
|
||||||
|
/// Simple token bucket for per-session rate capping (Tier E).
|
||||||
|
///
|
||||||
|
/// Tokens represent bytes. The bucket refills at `refill_per_sec` bytes per
|
||||||
|
/// second, up to `capacity`. A packet is allowed only if the bucket holds
|
||||||
|
/// enough tokens for its size.
|
||||||
|
pub struct TokenBucket {
|
||||||
|
capacity: u64,
|
||||||
|
tokens: f64,
|
||||||
|
refill_per_sec: u64,
|
||||||
|
last_refill: Instant,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TokenBucket {
|
||||||
|
/// Create a new bucket with the given byte capacity and refill rate.
|
||||||
|
pub fn new(capacity: u64, refill_per_sec: u64) -> Self {
|
||||||
|
Self {
|
||||||
|
capacity,
|
||||||
|
tokens: capacity as f64,
|
||||||
|
refill_per_sec,
|
||||||
|
last_refill: Instant::now(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Per-session audio cap: 256 kbps with 30 s @ 2× burst.
|
||||||
|
/// Capacity = 30 s × 64 KB/s = 1_920_000 bytes.
|
||||||
|
pub fn for_audio_session() -> Self {
|
||||||
|
let refill_per_sec = 256_000 / 8; // 32_000 bytes/sec
|
||||||
|
let capacity = refill_per_sec * 30 * 2; // 1_920_000 bytes
|
||||||
|
Self::new(capacity, refill_per_sec)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Attempt to consume `bytes` from the bucket.
|
||||||
|
///
|
||||||
|
/// Refills based on elapsed time since the last call, then deducts the
|
||||||
|
/// cost. Returns `Ok(())` if enough tokens were available,
|
||||||
|
/// `Err(TokenExhausted)` otherwise.
|
||||||
|
pub fn try_consume(&mut self, bytes: u64, now: Instant) -> Result<(), TokenExhausted> {
|
||||||
|
let elapsed = now.duration_since(self.last_refill);
|
||||||
|
self.last_refill = now;
|
||||||
|
self.tokens += elapsed.as_secs_f64() * self.refill_per_sec as f64;
|
||||||
|
if self.tokens > self.capacity as f64 {
|
||||||
|
self.tokens = self.capacity as f64;
|
||||||
|
}
|
||||||
|
if self.tokens >= bytes as f64 {
|
||||||
|
self.tokens -= bytes as f64;
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(TokenExhausted)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Per-participant traffic conformance meter.
|
||||||
|
pub struct ConformanceMeter {
|
||||||
|
window_start: Instant,
|
||||||
|
bytes_in_window: u64,
|
||||||
|
packets_in_window: u64,
|
||||||
|
/// Rolling (seq, timestamp) pairs for drift detection.
|
||||||
|
drift_window: VecDeque<(u32, u32)>,
|
||||||
|
/// EWMA of payload size for Tier D sanity checks.
|
||||||
|
ewma_payload_size: f64,
|
||||||
|
/// Optional token bucket for Tier E per-session rate cap.
|
||||||
|
token_bucket: Option<TokenBucket>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ConformanceMeter {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
window_start: Instant::now(),
|
||||||
|
bytes_in_window: 0,
|
||||||
|
packets_in_window: 0,
|
||||||
|
drift_window: VecDeque::with_capacity(DRIFT_WINDOW_SIZE),
|
||||||
|
ewma_payload_size: 0.0,
|
||||||
|
token_bucket: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a meter with a Tier E token bucket for per-session rate capping.
|
||||||
|
pub fn with_token_bucket(bucket: TokenBucket) -> Self {
|
||||||
|
let mut meter = Self::new();
|
||||||
|
meter.token_bucket = Some(bucket);
|
||||||
|
meter
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Inspect an incoming media packet and accumulate it against the
|
||||||
|
/// current 1-second window. Returns [`Err(Violation)`] when a limit
|
||||||
|
/// is crossed.
|
||||||
|
pub fn observe(
|
||||||
|
&mut self,
|
||||||
|
header: &MediaHeader,
|
||||||
|
payload_len: usize,
|
||||||
|
now: Instant,
|
||||||
|
) -> Result<(), Violation> {
|
||||||
|
// Roll the window forward if a second has elapsed.
|
||||||
|
if now.duration_since(self.window_start) >= Duration::from_secs(1) {
|
||||||
|
self.window_start = now;
|
||||||
|
self.bytes_in_window = 0;
|
||||||
|
self.packets_in_window = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
let packet_size = (MediaHeader::WIRE_SIZE + payload_len) as u64;
|
||||||
|
self.bytes_in_window += packet_size;
|
||||||
|
self.packets_in_window += 1;
|
||||||
|
|
||||||
|
// Tier A — bitrate ceiling.
|
||||||
|
let ceiling = ceiling_bps(header.codec_id);
|
||||||
|
let max_bytes_per_sec = ceiling / 8;
|
||||||
|
if self.bytes_in_window > max_bytes_per_sec {
|
||||||
|
return Err(Violation::BitrateExceeded);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tier B — packet-rate ceiling.
|
||||||
|
let max_pps = max_pps(header.codec_id);
|
||||||
|
let pps_threshold = (max_pps as f32 * 1.5) as u64;
|
||||||
|
if self.packets_in_window > pps_threshold {
|
||||||
|
return Err(Violation::PacketRateExceeded);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tier C — timestamp drift.
|
||||||
|
self.drift_window.push_back((header.seq, header.timestamp));
|
||||||
|
if self.drift_window.len() > DRIFT_WINDOW_SIZE {
|
||||||
|
self.drift_window.pop_front();
|
||||||
|
}
|
||||||
|
if self.drift_window.len() >= 2 {
|
||||||
|
let (first_seq, first_ts) = self.drift_window.front().copied().unwrap();
|
||||||
|
let (last_seq, last_ts) = self.drift_window.back().copied().unwrap();
|
||||||
|
|
||||||
|
let ds = last_seq.wrapping_sub(first_seq) as f64;
|
||||||
|
let dt = last_ts.wrapping_sub(first_ts) as f64;
|
||||||
|
|
||||||
|
if ds > 0.0 {
|
||||||
|
let avg_ms_per_packet = dt / ds;
|
||||||
|
let frame_ms = header.codec_id.frame_duration_ms() as f64;
|
||||||
|
let min_ratio = frame_ms * 0.5;
|
||||||
|
let max_ratio = frame_ms * 2.0;
|
||||||
|
if avg_ms_per_packet < min_ratio || avg_ms_per_packet > max_ratio {
|
||||||
|
return Err(Violation::TimestampDrift);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tier D — payload-size sanity (EWMA).
|
||||||
|
let alpha = 0.05; // ~20-packet smoothing
|
||||||
|
self.ewma_payload_size =
|
||||||
|
alpha * payload_len as f64 + (1.0 - alpha) * self.ewma_payload_size;
|
||||||
|
let bound = payload_size_bound(header.codec_id);
|
||||||
|
if self.ewma_payload_size > (bound * 2) as f64 {
|
||||||
|
return Err(Violation::PayloadSizeExceeded);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tier E — per-session token-bucket rate cap.
|
||||||
|
if let Some(ref mut bucket) = self.token_bucket {
|
||||||
|
let packet_size = (MediaHeader::WIRE_SIZE + payload_len) as u64;
|
||||||
|
if bucket.try_consume(packet_size, now).is_err() {
|
||||||
|
return Err(Violation::RateCapExceeded);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for ConformanceMeter {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Compute the Tier A bitrate ceiling for a given codec.
|
||||||
|
///
|
||||||
|
/// Formula:
|
||||||
|
/// nominal_bitrate * 3 (FEC 2.0 overhead) * 115 / 100 (15% safety margin)
|
||||||
|
/// with a floor of 2 kbps.
|
||||||
|
pub fn ceiling_bps(codec: CodecId) -> u64 {
|
||||||
|
let nominal = codec.bitrate_bps() as u64;
|
||||||
|
(nominal * 3 * 115 / 100).max(2_000)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Compute the Tier B packet-rate ceiling for a given codec.
|
||||||
|
///
|
||||||
|
/// Formula:
|
||||||
|
/// 1000 / frame_duration_ms * 3 (FEC overhead factor)
|
||||||
|
pub fn max_pps(codec: CodecId) -> u32 {
|
||||||
|
let fd = codec.frame_duration_ms() as u32;
|
||||||
|
if fd == 0 {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
(1000 / fd) * 3
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Typical per-codec payload size bound in bytes (Tier D).
|
||||||
|
///
|
||||||
|
/// These are empirical upper bounds for a single audio frame at the codec's
|
||||||
|
/// nominal configuration. The EWMA must not exceed 2× this value.
|
||||||
|
pub fn payload_size_bound(codec: CodecId) -> usize {
|
||||||
|
match codec {
|
||||||
|
CodecId::Opus64k => 320,
|
||||||
|
CodecId::Opus48k => 240,
|
||||||
|
CodecId::Opus32k => 200,
|
||||||
|
CodecId::Opus24k => 160,
|
||||||
|
CodecId::Opus16k => 100,
|
||||||
|
CodecId::Opus6k => 90,
|
||||||
|
CodecId::Codec2_3200 => 30,
|
||||||
|
CodecId::Codec2_1200 => 30,
|
||||||
|
CodecId::ComfortNoise => 16,
|
||||||
|
CodecId::H264Baseline | CodecId::H265Main | CodecId::Av1Main => 1400,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use wzp_proto::MediaType;
|
||||||
|
|
||||||
|
fn make_header(codec_id: CodecId) -> MediaHeader {
|
||||||
|
MediaHeader {
|
||||||
|
version: 2,
|
||||||
|
flags: 0,
|
||||||
|
media_type: MediaType::Audio,
|
||||||
|
codec_id,
|
||||||
|
seq: 0,
|
||||||
|
timestamp: 0,
|
||||||
|
fec_block: 0,
|
||||||
|
stream_id: 0,
|
||||||
|
fec_ratio: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_header_with_seq_ts(codec_id: CodecId, seq: u32, timestamp: u32) -> MediaHeader {
|
||||||
|
MediaHeader {
|
||||||
|
version: 2,
|
||||||
|
flags: 0,
|
||||||
|
media_type: MediaType::Audio,
|
||||||
|
codec_id,
|
||||||
|
seq,
|
||||||
|
timestamp,
|
||||||
|
fec_block: 0,
|
||||||
|
stream_id: 0,
|
||||||
|
fec_ratio: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn bitrate_exceeded_for_opus24k() {
|
||||||
|
let mut meter = ConformanceMeter::new();
|
||||||
|
let header = make_header(CodecId::Opus24k);
|
||||||
|
|
||||||
|
// Ceiling for Opus24k = 24_000 * 3 * 115 / 100 = 82_800 bps
|
||||||
|
// = 10_350 bytes/sec. 1 MB/s = 125_000 bytes/packet will blow past
|
||||||
|
// that in a single packet.
|
||||||
|
let now = Instant::now();
|
||||||
|
let result = meter.observe(&header, 1_000_000, now);
|
||||||
|
assert_eq!(result, Err(Violation::BitrateExceeded));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn small_packets_stay_within_ceiling() {
|
||||||
|
let mut meter = ConformanceMeter::new();
|
||||||
|
let header = make_header(CodecId::Opus24k);
|
||||||
|
|
||||||
|
// Ceiling = 82_800 bps = 10_350 bytes/sec.
|
||||||
|
// Each packet = 16-byte header + 80 bytes = 96 bytes.
|
||||||
|
// 100 packets = 9_600 bytes < 10_350.
|
||||||
|
let now = Instant::now();
|
||||||
|
for _ in 0..100 {
|
||||||
|
assert!(meter.observe(&header, 80, now).is_ok());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn window_resets_after_one_second() {
|
||||||
|
let mut meter = ConformanceMeter::new();
|
||||||
|
let header = make_header(CodecId::Opus24k);
|
||||||
|
|
||||||
|
// Fill the window to just under the limit.
|
||||||
|
// Use 300-byte payloads (under Tier D 2× bound of 320 for Opus24k).
|
||||||
|
let t0 = Instant::now();
|
||||||
|
for _ in 0..32 {
|
||||||
|
assert!(meter.observe(&header, 300, t0).is_ok());
|
||||||
|
}
|
||||||
|
// 32 * (header wire size + 300) ≈ 32 * 316 = 10_112 bytes < 10_350
|
||||||
|
|
||||||
|
// Same packets 1.1 seconds later should be fine because the window
|
||||||
|
// rolls over.
|
||||||
|
let t1 = t0 + Duration::from_millis(1_100);
|
||||||
|
for _ in 0..32 {
|
||||||
|
assert!(meter.observe(&header, 300, t1).is_ok());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn ceiling_bps_floor() {
|
||||||
|
// ComfortNoise has 0 nominal bitrate, so the floor kicks in.
|
||||||
|
assert_eq!(ceiling_bps(CodecId::ComfortNoise), 2_000);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------------------------------------------
|
||||||
|
// Tier B — packet rate
|
||||||
|
// ------------------------------------------------------------------
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn packet_rate_exceeded() {
|
||||||
|
let mut meter = ConformanceMeter::new();
|
||||||
|
// Opus24k: max_pps = 1000/20 * 3 = 150. Threshold = 150 * 1.5 = 225.
|
||||||
|
let header = make_header(CodecId::Opus24k);
|
||||||
|
let now = Instant::now();
|
||||||
|
for _ in 0..225 {
|
||||||
|
assert!(meter.observe(&header, 10, now).is_ok());
|
||||||
|
}
|
||||||
|
// 226th packet should trip the limit.
|
||||||
|
assert_eq!(
|
||||||
|
meter.observe(&header, 10, now),
|
||||||
|
Err(Violation::PacketRateExceeded)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn packet_rate_within_limit() {
|
||||||
|
let mut meter = ConformanceMeter::new();
|
||||||
|
// Opus6k: max_pps = 1000/40 * 3 = 75. Threshold = 75 * 1.5 = 112.
|
||||||
|
// Use 0-byte payload so bitrate ceiling (2_587 bytes/sec) is not the
|
||||||
|
// limiting factor. 112 packets × 16 bytes = 1_792 bytes < 2_587.
|
||||||
|
let header = make_header(CodecId::Opus6k);
|
||||||
|
let now = Instant::now();
|
||||||
|
for _ in 0..112 {
|
||||||
|
assert!(meter.observe(&header, 0, now).is_ok());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------------------------------------------
|
||||||
|
// Tier C — timestamp drift
|
||||||
|
// ------------------------------------------------------------------
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn timestamp_drift_detected_when_too_fast() {
|
||||||
|
let mut meter = ConformanceMeter::new();
|
||||||
|
// Opus24k frame_duration = 20 ms.
|
||||||
|
// Acceptable range: [10, 40] ms per packet.
|
||||||
|
// Send packets with timestamp advancing by 5 ms each (too fast).
|
||||||
|
let now = Instant::now();
|
||||||
|
let mut drift_seen = false;
|
||||||
|
for i in 0..200 {
|
||||||
|
let header = make_header_with_seq_ts(CodecId::Opus24k, i, i * 5);
|
||||||
|
match meter.observe(&header, 10, now) {
|
||||||
|
Ok(()) => {}
|
||||||
|
Err(Violation::TimestampDrift) => drift_seen = true,
|
||||||
|
Err(other) => panic!("unexpected violation: {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert!(drift_seen, "expected TimestampDrift to be detected");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn timestamp_drift_detected_when_too_slow() {
|
||||||
|
let mut meter = ConformanceMeter::new();
|
||||||
|
// Opus24k frame_duration = 20 ms.
|
||||||
|
// Acceptable range: [10, 40] ms per packet.
|
||||||
|
// Send packets with timestamp advancing by 50 ms each (too slow).
|
||||||
|
let now = Instant::now();
|
||||||
|
let mut drift_seen = false;
|
||||||
|
for i in 0..200 {
|
||||||
|
let header = make_header_with_seq_ts(CodecId::Opus24k, i, i * 50);
|
||||||
|
match meter.observe(&header, 10, now) {
|
||||||
|
Ok(()) => {}
|
||||||
|
Err(Violation::TimestampDrift) => drift_seen = true,
|
||||||
|
Err(other) => panic!("unexpected violation: {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert!(drift_seen, "expected TimestampDrift to be detected");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn timestamp_normal_no_drift() {
|
||||||
|
let mut meter = ConformanceMeter::new();
|
||||||
|
// Opus24k frame_duration = 20 ms.
|
||||||
|
// Send 200 packets with timestamp advancing by exactly 20 ms each.
|
||||||
|
let now = Instant::now();
|
||||||
|
for i in 0..200 {
|
||||||
|
let header = make_header_with_seq_ts(CodecId::Opus24k, i, i * 20);
|
||||||
|
assert!(meter.observe(&header, 10, now).is_ok());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn timestamp_drift_not_checked_before_two_packets() {
|
||||||
|
let mut meter = ConformanceMeter::new();
|
||||||
|
let now = Instant::now();
|
||||||
|
// Single packet with wild timestamp — should not trigger drift.
|
||||||
|
let header = make_header_with_seq_ts(CodecId::Opus24k, 0, 999_999);
|
||||||
|
assert!(meter.observe(&header, 10, now).is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------------------------------------------
|
||||||
|
// Tier D — payload-size sanity
|
||||||
|
// ------------------------------------------------------------------
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn conformance_tier_d() {
|
||||||
|
let mut meter = ConformanceMeter::new();
|
||||||
|
let header = make_header(CodecId::Codec2_1200);
|
||||||
|
let now = Instant::now();
|
||||||
|
|
||||||
|
// Codec2_1200 bound = 30 bytes. 2× bound = 60 bytes.
|
||||||
|
// Feed 1400-byte payloads — EWMA should cross 60 within a few packets.
|
||||||
|
let mut flagged = false;
|
||||||
|
for _ in 0..200 {
|
||||||
|
if meter.observe(&header, 1400, now).is_err() {
|
||||||
|
flagged = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert!(
|
||||||
|
flagged,
|
||||||
|
"expected PayloadSizeExceeded for 1400-byte Codec2_1200 payloads"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn payload_size_normal_stays_within_bound() {
|
||||||
|
let mut meter = ConformanceMeter::new();
|
||||||
|
let header = make_header(CodecId::Opus24k);
|
||||||
|
let now = Instant::now();
|
||||||
|
|
||||||
|
// Opus24k bound = 160 bytes. 2× bound = 320 bytes.
|
||||||
|
// Feed 150-byte payloads — well within the 2× limit.
|
||||||
|
// Limit to 10 packets so the 1-second bitrate window (10_350 bytes)
|
||||||
|
// is not exhausted: 10 * (16 + 150) = 1_660 < 10_350.
|
||||||
|
for _ in 0..10 {
|
||||||
|
assert!(
|
||||||
|
meter.observe(&header, 150, now).is_ok(),
|
||||||
|
"150-byte Opus24k payloads should stay within Tier D limit"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------------------------------------------
|
||||||
|
// Tier E — token-bucket rate cap
|
||||||
|
// ------------------------------------------------------------------
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn token_bucket_small_burst_ok() {
|
||||||
|
let mut bucket = TokenBucket::new(100_000, 32_000);
|
||||||
|
let now = Instant::now();
|
||||||
|
// 50 KB burst fits inside 100 KB capacity.
|
||||||
|
assert!(bucket.try_consume(50_000, now).is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn token_bucket_large_burst_fails() {
|
||||||
|
let mut bucket = TokenBucket::new(100_000, 32_000);
|
||||||
|
let now = Instant::now();
|
||||||
|
// 1 MB exceeds 100 KB capacity.
|
||||||
|
assert!(bucket.try_consume(1_000_000, now).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn token_bucket_refills_over_time() {
|
||||||
|
let mut bucket = TokenBucket::new(100_000, 32_000);
|
||||||
|
let t0 = Instant::now();
|
||||||
|
// Drain the bucket.
|
||||||
|
assert!(bucket.try_consume(100_000, t0).is_ok());
|
||||||
|
// Immediately try again — should fail.
|
||||||
|
assert!(bucket.try_consume(10_000, t0).is_err());
|
||||||
|
// Wait 1 second — bucket refills 32_000 bytes.
|
||||||
|
let t1 = t0 + Duration::from_secs(1);
|
||||||
|
assert!(bucket.try_consume(30_000, t1).is_ok());
|
||||||
|
// 40_000 is more than the 32_000 refilled.
|
||||||
|
assert!(bucket.try_consume(40_000, t1).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn token_bucket_sustained_rate_balanced() {
|
||||||
|
let mut bucket = TokenBucket::new(1_000_000, 32_000);
|
||||||
|
let t0 = Instant::now();
|
||||||
|
// Send 32 KB every second for 5 seconds — exactly at refill rate.
|
||||||
|
// The bucket should never empty because each second it refills
|
||||||
|
// exactly what was consumed.
|
||||||
|
for i in 0..5 {
|
||||||
|
let t = t0 + Duration::from_secs(i);
|
||||||
|
assert!(
|
||||||
|
bucket.try_consume(32_000, t).is_ok(),
|
||||||
|
"32 KB/s sustained should stay within bucket limit"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn conformance_tier_e_integration() {
|
||||||
|
// Use Opus64k (high bitrate ceiling + high payload bound) so Tiers
|
||||||
|
// A/B/D never fire on the small bursts used here. Only Tier E.
|
||||||
|
let mut meter = ConformanceMeter::with_token_bucket(TokenBucket::new(1_000, 500));
|
||||||
|
let header = make_header(CodecId::Opus64k);
|
||||||
|
let now = Instant::now();
|
||||||
|
|
||||||
|
// Two 500-byte (wire) packets = 1_000 bytes — exactly the bucket cap.
|
||||||
|
assert!(
|
||||||
|
meter
|
||||||
|
.observe(&header, 500 - MediaHeader::WIRE_SIZE, now)
|
||||||
|
.is_ok()
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
meter
|
||||||
|
.observe(&header, 500 - MediaHeader::WIRE_SIZE, now)
|
||||||
|
.is_ok()
|
||||||
|
);
|
||||||
|
|
||||||
|
// Third packet exceeds the 1_000-byte cap.
|
||||||
|
let result = meter.observe(&header, 10, now);
|
||||||
|
assert_eq!(result, Err(Violation::RateCapExceeded));
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -25,16 +25,13 @@ pub struct Event {
|
|||||||
pub src: Option<String>,
|
pub src: Option<String>,
|
||||||
/// Packet sequence number.
|
/// Packet sequence number.
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub seq: Option<u16>,
|
pub seq: Option<u32>,
|
||||||
/// Codec identifier.
|
/// Codec identifier.
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub codec: Option<String>,
|
pub codec: Option<String>,
|
||||||
/// FEC block ID.
|
/// FEC block ID (low byte) and symbol index (high byte).
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub fec_block: Option<u8>,
|
pub fec_block: Option<u16>,
|
||||||
/// FEC symbol index.
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub fec_sym: Option<u8>,
|
|
||||||
/// Is FEC repair packet.
|
/// Is FEC repair packet.
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub repair: Option<bool>,
|
pub repair: Option<bool>,
|
||||||
@@ -60,7 +57,9 @@ pub struct Event {
|
|||||||
|
|
||||||
impl Event {
|
impl Event {
|
||||||
fn now() -> String {
|
fn now() -> String {
|
||||||
chrono::Utc::now().format("%Y-%m-%dT%H:%M:%S%.6fZ").to_string()
|
chrono::Utc::now()
|
||||||
|
.format("%Y-%m-%dT%H:%M:%S%.6fZ")
|
||||||
|
.to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a minimal event with just type and timestamp.
|
/// Create a minimal event with just type and timestamp.
|
||||||
@@ -73,7 +72,6 @@ impl Event {
|
|||||||
seq: None,
|
seq: None,
|
||||||
codec: None,
|
codec: None,
|
||||||
fec_block: None,
|
fec_block: None,
|
||||||
fec_sym: None,
|
|
||||||
repair: None,
|
repair: None,
|
||||||
len: None,
|
len: None,
|
||||||
to_count: None,
|
to_count: None,
|
||||||
@@ -85,33 +83,59 @@ impl Event {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Set room.
|
/// Set room.
|
||||||
pub fn room(mut self, room: &str) -> Self { self.room = Some(room.to_string()); self }
|
pub fn room(mut self, room: &str) -> Self {
|
||||||
|
self.room = Some(room.to_string());
|
||||||
|
self
|
||||||
|
}
|
||||||
/// Set source.
|
/// Set source.
|
||||||
pub fn src(mut self, src: &str) -> Self { self.src = Some(src.to_string()); self }
|
pub fn src(mut self, src: &str) -> Self {
|
||||||
|
self.src = Some(src.to_string());
|
||||||
|
self
|
||||||
|
}
|
||||||
/// Set packet header fields from a MediaPacket.
|
/// Set packet header fields from a MediaPacket.
|
||||||
pub fn packet(mut self, pkt: &wzp_proto::MediaPacket) -> Self {
|
pub fn packet(mut self, pkt: &wzp_proto::MediaPacket) -> Self {
|
||||||
self.seq = Some(pkt.header.seq);
|
self.seq = Some(pkt.header.seq);
|
||||||
self.codec = Some(format!("{:?}", pkt.header.codec_id));
|
self.codec = Some(format!("{:?}", pkt.header.codec_id));
|
||||||
self.fec_block = Some(pkt.header.fec_block);
|
self.fec_block = Some(pkt.header.fec_block);
|
||||||
self.fec_sym = Some(pkt.header.fec_symbol);
|
self.repair = Some(pkt.header.is_repair());
|
||||||
self.repair = Some(pkt.header.is_repair);
|
|
||||||
self.len = Some(pkt.payload.len());
|
self.len = Some(pkt.payload.len());
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
/// Set seq only (when full packet not available).
|
/// Set seq only (when full packet not available).
|
||||||
pub fn seq(mut self, seq: u16) -> Self { self.seq = Some(seq); self }
|
pub fn seq(mut self, seq: u32) -> Self {
|
||||||
|
self.seq = Some(seq);
|
||||||
|
self
|
||||||
|
}
|
||||||
/// Set payload length.
|
/// Set payload length.
|
||||||
pub fn len(mut self, len: usize) -> Self { self.len = Some(len); self }
|
pub fn len(mut self, len: usize) -> Self {
|
||||||
|
self.len = Some(len);
|
||||||
|
self
|
||||||
|
}
|
||||||
/// Set recipient count.
|
/// Set recipient count.
|
||||||
pub fn to_count(mut self, n: usize) -> Self { self.to_count = Some(n); self }
|
pub fn to_count(mut self, n: usize) -> Self {
|
||||||
|
self.to_count = Some(n);
|
||||||
|
self
|
||||||
|
}
|
||||||
/// Set peer label.
|
/// Set peer label.
|
||||||
pub fn peer(mut self, peer: &str) -> Self { self.peer = Some(peer.to_string()); self }
|
pub fn peer(mut self, peer: &str) -> Self {
|
||||||
|
self.peer = Some(peer.to_string());
|
||||||
|
self
|
||||||
|
}
|
||||||
/// Set drop reason.
|
/// Set drop reason.
|
||||||
pub fn reason(mut self, reason: &str) -> Self { self.reason = Some(reason.to_string()); self }
|
pub fn reason(mut self, reason: &str) -> Self {
|
||||||
|
self.reason = Some(reason.to_string());
|
||||||
|
self
|
||||||
|
}
|
||||||
/// Set presence action.
|
/// Set presence action.
|
||||||
pub fn action(mut self, action: &str) -> Self { self.action = Some(action.to_string()); self }
|
pub fn action(mut self, action: &str) -> Self {
|
||||||
|
self.action = Some(action.to_string());
|
||||||
|
self
|
||||||
|
}
|
||||||
/// Set participant count.
|
/// Set participant count.
|
||||||
pub fn participants(mut self, n: usize) -> Self { self.participants = Some(n); self }
|
pub fn participants(mut self, n: usize) -> Self {
|
||||||
|
self.participants = Some(n);
|
||||||
|
self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Handle for emitting events. Cheap to clone.
|
/// Handle for emitting events. Cheap to clone.
|
||||||
@@ -181,8 +205,12 @@ async fn writer_task(path: PathBuf, mut rx: mpsc::UnboundedReceiver<Event>) {
|
|||||||
while let Some(event) = rx.recv().await {
|
while let Some(event) = rx.recv().await {
|
||||||
match serde_json::to_string(&event) {
|
match serde_json::to_string(&event) {
|
||||||
Ok(json) => {
|
Ok(json) => {
|
||||||
if writer.write_all(json.as_bytes()).await.is_err() { break; }
|
if writer.write_all(json.as_bytes()).await.is_err() {
|
||||||
if writer.write_all(b"\n").await.is_err() { break; }
|
break;
|
||||||
|
}
|
||||||
|
if writer.write_all(b"\n").await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
count += 1;
|
count += 1;
|
||||||
// Flush every 100 events
|
// Flush every 100 events
|
||||||
if count % 100 == 0 {
|
if count % 100 == 0 {
|
||||||
|
|||||||
@@ -11,11 +11,11 @@ use std::sync::Arc;
|
|||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use sha2::{Sha256, Digest};
|
use sha2::{Digest, Sha256};
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
use tracing::{error, info, warn};
|
use tracing::{error, info, warn};
|
||||||
|
|
||||||
use wzp_proto::{MediaTransport, SignalMessage};
|
use wzp_proto::{MediaTransport, SignalMessage, default_signal_version};
|
||||||
use wzp_transport::QuinnTransport;
|
use wzp_transport::QuinnTransport;
|
||||||
|
|
||||||
use crate::config::{PeerConfig, TrustedConfig};
|
use crate::config::{PeerConfig, TrustedConfig};
|
||||||
@@ -56,13 +56,14 @@ impl Deduplicator {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Returns true if this packet is a duplicate (already seen within TTL).
|
/// Returns true if this packet is a duplicate (already seen within TTL).
|
||||||
fn is_dup(&mut self, room_hash: &[u8; 8], seq: u16, extra: u64) -> bool {
|
fn is_dup(&mut self, room_hash: &[u8; 8], seq: u32, extra: u64) -> bool {
|
||||||
let key = u64::from_be_bytes(*room_hash) ^ (seq as u64) ^ extra;
|
let key = u64::from_be_bytes(*room_hash) ^ (seq as u64) ^ extra;
|
||||||
let now = Instant::now();
|
let now = Instant::now();
|
||||||
|
|
||||||
// Periodic cleanup (every ~256 packets)
|
// Periodic cleanup (every ~256 packets)
|
||||||
if self.entries.len() > 256 {
|
if self.entries.len() > 256 {
|
||||||
self.entries.retain(|_, ts| now.duration_since(*ts) < self.ttl);
|
self.entries
|
||||||
|
.retain(|_, ts| now.duration_since(*ts) < self.ttl);
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(ts) = self.entries.get(&key) {
|
if let Some(ts) = self.entries.get(&key) {
|
||||||
@@ -134,7 +135,7 @@ pub struct FederationManager {
|
|||||||
peers: Vec<PeerConfig>,
|
peers: Vec<PeerConfig>,
|
||||||
trusted: Vec<TrustedConfig>,
|
trusted: Vec<TrustedConfig>,
|
||||||
global_rooms: HashSet<String>,
|
global_rooms: HashSet<String>,
|
||||||
room_mgr: Arc<Mutex<RoomManager>>,
|
room_mgr: Arc<RoomManager>,
|
||||||
endpoint: quinn::Endpoint,
|
endpoint: quinn::Endpoint,
|
||||||
local_tls_fp: String,
|
local_tls_fp: String,
|
||||||
metrics: Arc<crate::metrics::RelayMetrics>,
|
metrics: Arc<crate::metrics::RelayMetrics>,
|
||||||
@@ -161,7 +162,7 @@ impl FederationManager {
|
|||||||
peers: Vec<PeerConfig>,
|
peers: Vec<PeerConfig>,
|
||||||
trusted: Vec<TrustedConfig>,
|
trusted: Vec<TrustedConfig>,
|
||||||
global_rooms: HashSet<String>,
|
global_rooms: HashSet<String>,
|
||||||
room_mgr: Arc<Mutex<RoomManager>>,
|
room_mgr: Arc<RoomManager>,
|
||||||
endpoint: quinn::Endpoint,
|
endpoint: quinn::Endpoint,
|
||||||
local_tls_fp: String,
|
local_tls_fp: String,
|
||||||
metrics: Arc<crate::metrics::RelayMetrics>,
|
metrics: Arc<crate::metrics::RelayMetrics>,
|
||||||
@@ -213,16 +214,22 @@ impl FederationManager {
|
|||||||
/// `origin_relay_fp` against its own fp and drops self-sourced
|
/// `origin_relay_fp` against its own fp and drops self-sourced
|
||||||
/// forwards.
|
/// forwards.
|
||||||
pub async fn broadcast_signal(&self, msg: &wzp_proto::SignalMessage) -> usize {
|
pub async fn broadcast_signal(&self, msg: &wzp_proto::SignalMessage) -> usize {
|
||||||
|
let peers: Vec<(String, String, Arc<QuinnTransport>)> = {
|
||||||
let links = self.peer_links.lock().await;
|
let links = self.peer_links.lock().await;
|
||||||
|
links
|
||||||
|
.iter()
|
||||||
|
.map(|(fp, l)| (fp.clone(), l.label.clone(), l.transport.clone()))
|
||||||
|
.collect()
|
||||||
|
}; // lock released
|
||||||
let mut count = 0;
|
let mut count = 0;
|
||||||
for (fp, link) in links.iter() {
|
for (fp, label, transport) in &peers {
|
||||||
match link.transport.send_signal(msg).await {
|
match transport.send_signal(msg).await {
|
||||||
Ok(()) => {
|
Ok(()) => {
|
||||||
count += 1;
|
count += 1;
|
||||||
tracing::debug!(peer = %link.label, %fp, "federation: broadcast signal ok");
|
tracing::debug!(peer = %label, %fp, "federation: broadcast signal ok");
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::warn!(peer = %link.label, %fp, error = %e, "federation: broadcast signal failed");
|
tracing::warn!(peer = %label, %fp, error = %e, "federation: broadcast signal failed");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -243,10 +250,12 @@ impl FederationManager {
|
|||||||
msg: &wzp_proto::SignalMessage,
|
msg: &wzp_proto::SignalMessage,
|
||||||
) -> Result<(), String> {
|
) -> Result<(), String> {
|
||||||
let normalized = normalize_fp(peer_relay_fp);
|
let normalized = normalize_fp(peer_relay_fp);
|
||||||
|
let transport = {
|
||||||
let links = self.peer_links.lock().await;
|
let links = self.peer_links.lock().await;
|
||||||
match links.get(&normalized) {
|
links.get(&normalized).map(|l| l.transport.clone())
|
||||||
Some(link) => link
|
}; // lock released
|
||||||
.transport
|
match transport {
|
||||||
|
Some(t) => t
|
||||||
.send_signal(msg)
|
.send_signal(msg)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| format!("send to peer {normalized}: {e}")),
|
.map_err(|e| format!("send to peer {normalized}: {e}")),
|
||||||
@@ -295,9 +304,10 @@ impl FederationManager {
|
|||||||
return Some(room.to_string());
|
return Some(room.to_string());
|
||||||
}
|
}
|
||||||
// Hashed match (desktop clients hash room names for SNI privacy)
|
// Hashed match (desktop clients hash room names for SNI privacy)
|
||||||
self.global_rooms.iter().find(|name| {
|
self.global_rooms
|
||||||
wzp_crypto::hash_room_name(name) == room
|
.iter()
|
||||||
}).map(|s| s.to_string())
|
.find(|name| wzp_crypto::hash_room_name(name) == room)
|
||||||
|
.map(|s| s.to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the canonical federation room hash for a room.
|
/// Get the canonical federation room hash for a room.
|
||||||
@@ -333,10 +343,7 @@ impl FederationManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Room event dispatcher
|
// Room event dispatcher
|
||||||
let room_events = {
|
let room_events = self.room_mgr.subscribe_events();
|
||||||
let mgr = self.room_mgr.lock().await;
|
|
||||||
mgr.subscribe_events()
|
|
||||||
};
|
|
||||||
let this = self.clone();
|
let this = self.clone();
|
||||||
handles.push(tokio::spawn(async move {
|
handles.push(tokio::spawn(async move {
|
||||||
run_room_event_dispatcher(this, room_events).await;
|
run_room_event_dispatcher(this, room_events).await;
|
||||||
@@ -369,7 +376,10 @@ impl FederationManager {
|
|||||||
|
|
||||||
/// Get all remote participants for a room from all peer links.
|
/// Get all remote participants for a room from all peer links.
|
||||||
/// Deduplicates by fingerprint (same participant may appear via multiple links).
|
/// Deduplicates by fingerprint (same participant may appear via multiple links).
|
||||||
pub async fn get_remote_participants(&self, room: &str) -> Vec<wzp_proto::packet::RoomParticipant> {
|
pub async fn get_remote_participants(
|
||||||
|
&self,
|
||||||
|
room: &str,
|
||||||
|
) -> Vec<wzp_proto::packet::RoomParticipant> {
|
||||||
let canonical = self.resolve_global_room(room);
|
let canonical = self.resolve_global_room(room);
|
||||||
let links = self.peer_links.lock().await;
|
let links = self.peer_links.lock().await;
|
||||||
let mut result = Vec::new();
|
let mut result = Vec::new();
|
||||||
@@ -405,21 +415,35 @@ impl FederationManager {
|
|||||||
/// the other room-tagged helpers and for future per-room-name logging
|
/// the other room-tagged helpers and for future per-room-name logging
|
||||||
/// or rate limiting; the body currently forwards on `room_hash` alone
|
/// or rate limiting; the body currently forwards on `room_hash` alone
|
||||||
/// because that's what the wire format carries.
|
/// because that's what the wire format carries.
|
||||||
pub async fn forward_to_peers(&self, _room_name: &str, room_hash: &[u8; 8], media_data: &Bytes) {
|
pub async fn forward_to_peers(
|
||||||
|
&self,
|
||||||
|
_room_name: &str,
|
||||||
|
room_hash: &[u8; 8],
|
||||||
|
media_data: &Bytes,
|
||||||
|
) {
|
||||||
|
let peers: Vec<(String, Arc<QuinnTransport>)> = {
|
||||||
let links = self.peer_links.lock().await;
|
let links = self.peer_links.lock().await;
|
||||||
if links.is_empty() {
|
if links.is_empty() {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
for (_fp, link) in links.iter() {
|
links
|
||||||
|
.values()
|
||||||
|
.map(|l| (l.label.clone(), l.transport.clone()))
|
||||||
|
.collect()
|
||||||
|
}; // lock released
|
||||||
|
|
||||||
|
for (label, transport) in &peers {
|
||||||
let mut tagged = Vec::with_capacity(8 + media_data.len());
|
let mut tagged = Vec::with_capacity(8 + media_data.len());
|
||||||
tagged.extend_from_slice(room_hash);
|
tagged.extend_from_slice(room_hash);
|
||||||
tagged.extend_from_slice(media_data);
|
tagged.extend_from_slice(media_data);
|
||||||
match link.transport.send_raw_datagram(&tagged) {
|
match transport.send_raw_datagram(&tagged) {
|
||||||
Ok(()) => {
|
Ok(()) => {
|
||||||
self.metrics.federation_packets_forwarded
|
self.metrics
|
||||||
.with_label_values(&[&link.label, "out"]).inc();
|
.federation_packets_forwarded
|
||||||
|
.with_label_values(&[label, "out"])
|
||||||
|
.inc();
|
||||||
}
|
}
|
||||||
Err(e) => warn!(peer = %link.label, "federation send error: {e}"),
|
Err(e) => warn!(peer = %label, "federation send error: {e}"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -427,20 +451,25 @@ impl FederationManager {
|
|||||||
// ── Trust verification (kept from previous implementation) ──
|
// ── Trust verification (kept from previous implementation) ──
|
||||||
|
|
||||||
pub fn find_peer_by_fingerprint(&self, fp: &str) -> Option<&PeerConfig> {
|
pub fn find_peer_by_fingerprint(&self, fp: &str) -> Option<&PeerConfig> {
|
||||||
self.peers.iter().find(|p| normalize_fp(&p.fingerprint) == normalize_fp(fp))
|
self.peers
|
||||||
|
.iter()
|
||||||
|
.find(|p| normalize_fp(&p.fingerprint) == normalize_fp(fp))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn find_peer_by_addr(&self, addr: SocketAddr) -> Option<&PeerConfig> {
|
pub fn find_peer_by_addr(&self, addr: SocketAddr) -> Option<&PeerConfig> {
|
||||||
let addr_ip = addr.ip();
|
let addr_ip = addr.ip();
|
||||||
self.peers.iter().find(|p| {
|
self.peers.iter().find(|p| {
|
||||||
p.url.parse::<SocketAddr>()
|
p.url
|
||||||
|
.parse::<SocketAddr>()
|
||||||
.map(|sa| sa.ip() == addr_ip)
|
.map(|sa| sa.ip() == addr_ip)
|
||||||
.unwrap_or(false)
|
.unwrap_or(false)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn find_trusted_by_fingerprint(&self, fp: &str) -> Option<&TrustedConfig> {
|
pub fn find_trusted_by_fingerprint(&self, fp: &str) -> Option<&TrustedConfig> {
|
||||||
self.trusted.iter().find(|t| normalize_fp(&t.fingerprint) == normalize_fp(fp))
|
self.trusted
|
||||||
|
.iter()
|
||||||
|
.find(|t| normalize_fp(&t.fingerprint) == normalize_fp(fp))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn check_inbound_trust(&self, addr: SocketAddr, hello_fp: &str) -> Option<String> {
|
pub fn check_inbound_trust(&self, addr: SocketAddr, hello_fp: &str) -> Option<String> {
|
||||||
@@ -448,7 +477,12 @@ impl FederationManager {
|
|||||||
return Some(peer.label.clone().unwrap_or_else(|| peer.url.clone()));
|
return Some(peer.label.clone().unwrap_or_else(|| peer.url.clone()));
|
||||||
}
|
}
|
||||||
if let Some(trusted) = self.find_trusted_by_fingerprint(hello_fp) {
|
if let Some(trusted) = self.find_trusted_by_fingerprint(hello_fp) {
|
||||||
return Some(trusted.label.clone().unwrap_or_else(|| hello_fp[..16].to_string()));
|
return Some(
|
||||||
|
trusted
|
||||||
|
.label
|
||||||
|
.clone()
|
||||||
|
.unwrap_or_else(|| hello_fp[..16].to_string()),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
@@ -467,7 +501,8 @@ pub async fn run_federation_media_egress(
|
|||||||
if count == 1 || count % 250 == 0 {
|
if count == 1 || count % 250 == 0 {
|
||||||
info!(room = %out.room_name, count, "federation egress: forwarding media");
|
info!(room = %out.room_name, count, "federation egress: forwarding media");
|
||||||
}
|
}
|
||||||
fm.forward_to_peers(&out.room_name, &out.room_hash, &out.data).await;
|
fm.forward_to_peers(&out.room_name, &out.room_hash, &out.data)
|
||||||
|
.await;
|
||||||
}
|
}
|
||||||
info!(total = count, "federation egress task ended");
|
info!(total = count, "federation egress task ended");
|
||||||
}
|
}
|
||||||
@@ -483,25 +518,35 @@ async fn run_room_event_dispatcher(
|
|||||||
match events.recv().await {
|
match events.recv().await {
|
||||||
Ok(RoomEvent::LocalJoin { room }) => {
|
Ok(RoomEvent::LocalJoin { room }) => {
|
||||||
if fm.is_global_room(&room) {
|
if fm.is_global_room(&room) {
|
||||||
let participants = {
|
let participants = fm.room_mgr.local_participant_list(&room);
|
||||||
let mgr = fm.room_mgr.lock().await;
|
|
||||||
mgr.local_participant_list(&room)
|
|
||||||
};
|
|
||||||
info!(room = %room, count = participants.len(), "global room now active, announcing to peers");
|
info!(room = %room, count = participants.len(), "global room now active, announcing to peers");
|
||||||
let msg = SignalMessage::GlobalRoomActive { room, participants };
|
let msg = SignalMessage::GlobalRoomActive {
|
||||||
|
version: default_signal_version(),
|
||||||
|
room,
|
||||||
|
participants,
|
||||||
|
};
|
||||||
|
let transports: Vec<Arc<QuinnTransport>> = {
|
||||||
let links = fm.peer_links.lock().await;
|
let links = fm.peer_links.lock().await;
|
||||||
for link in links.values() {
|
links.values().map(|l| l.transport.clone()).collect()
|
||||||
let _ = link.transport.send_signal(&msg).await;
|
};
|
||||||
|
for t in &transports {
|
||||||
|
let _ = t.send_signal(&msg).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(RoomEvent::LocalLeave { room }) => {
|
Ok(RoomEvent::LocalLeave { room }) => {
|
||||||
if fm.is_global_room(&room) {
|
if fm.is_global_room(&room) {
|
||||||
info!(room = %room, "global room now inactive, announcing to peers");
|
info!(room = %room, "global room now inactive, announcing to peers");
|
||||||
let msg = SignalMessage::GlobalRoomInactive { room };
|
let msg = SignalMessage::GlobalRoomInactive {
|
||||||
|
version: default_signal_version(),
|
||||||
|
room,
|
||||||
|
};
|
||||||
|
let transports: Vec<Arc<QuinnTransport>> = {
|
||||||
let links = fm.peer_links.lock().await;
|
let links = fm.peer_links.lock().await;
|
||||||
for link in links.values() {
|
links.values().map(|l| l.transport.clone()).collect()
|
||||||
let _ = link.transport.send_signal(&msg).await;
|
};
|
||||||
|
for t in &transports {
|
||||||
|
let _ = t.send_signal(&msg).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -529,7 +574,9 @@ async fn run_stale_presence_sweeper(fm: Arc<FederationManager>) {
|
|||||||
let links = fm.peer_links.lock().await;
|
let links = fm.peer_links.lock().await;
|
||||||
let mut stale = Vec::new();
|
let mut stale = Vec::new();
|
||||||
for (fp, link) in links.iter() {
|
for (fp, link) in links.iter() {
|
||||||
if link.last_seen.elapsed() > stale_threshold && !link.remote_participants.is_empty() {
|
if link.last_seen.elapsed() > stale_threshold
|
||||||
|
&& !link.remote_participants.is_empty()
|
||||||
|
{
|
||||||
for room in link.remote_participants.keys() {
|
for room in link.remote_participants.keys() {
|
||||||
stale.push((fp.clone(), room.clone()));
|
stale.push((fp.clone(), room.clone()));
|
||||||
}
|
}
|
||||||
@@ -560,20 +607,20 @@ async fn run_stale_presence_sweeper(fm: Arc<FederationManager>) {
|
|||||||
|
|
||||||
// Broadcast updated RoomUpdate for affected rooms
|
// Broadcast updated RoomUpdate for affected rooms
|
||||||
for room in &affected_rooms {
|
for room in &affected_rooms {
|
||||||
let mgr = fm.room_mgr.lock().await;
|
let active = fm.room_mgr.active_rooms();
|
||||||
for local_room in mgr.active_rooms() {
|
for local_room in &active {
|
||||||
if fm.resolve_global_room(&local_room) == fm.resolve_global_room(room) {
|
if fm.resolve_global_room(local_room) == fm.resolve_global_room(room) {
|
||||||
let mut all_participants = mgr.local_participant_list(&local_room);
|
let mut all_participants = fm.room_mgr.local_participant_list(local_room);
|
||||||
let remote = fm.get_remote_participants(&local_room).await;
|
let remote = fm.get_remote_participants(local_room).await;
|
||||||
all_participants.extend(remote);
|
all_participants.extend(remote);
|
||||||
let mut seen = HashSet::new();
|
let mut seen = HashSet::new();
|
||||||
all_participants.retain(|p| seen.insert(p.fingerprint.clone()));
|
all_participants.retain(|p| seen.insert(p.fingerprint.clone()));
|
||||||
let update = SignalMessage::RoomUpdate {
|
let update = SignalMessage::RoomUpdate {
|
||||||
|
version: default_signal_version(),
|
||||||
count: all_participants.len() as u32,
|
count: all_participants.len() as u32,
|
||||||
participants: all_participants,
|
participants: all_participants,
|
||||||
};
|
};
|
||||||
let senders = mgr.local_senders(&local_room);
|
let senders = fm.room_mgr.local_senders(local_room);
|
||||||
drop(mgr);
|
|
||||||
room::broadcast_signal(&senders, &update).await;
|
room::broadcast_signal(&senders, &update).await;
|
||||||
info!(room = %room, "swept stale presence — broadcast updated RoomUpdate");
|
info!(room = %room, "swept stale presence — broadcast updated RoomUpdate");
|
||||||
break;
|
break;
|
||||||
@@ -609,7 +656,10 @@ async fn run_peer_loop(fm: Arc<FederationManager>, peer: PeerConfig) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Connect to a peer relay and send hello.
|
/// Connect to a peer relay and send hello.
|
||||||
async fn connect_to_peer(fm: &FederationManager, peer: &PeerConfig) -> Result<Arc<QuinnTransport>, anyhow::Error> {
|
async fn connect_to_peer(
|
||||||
|
fm: &FederationManager,
|
||||||
|
peer: &PeerConfig,
|
||||||
|
) -> Result<Arc<QuinnTransport>, anyhow::Error> {
|
||||||
let addr: SocketAddr = peer.url.parse()?;
|
let addr: SocketAddr = peer.url.parse()?;
|
||||||
let client_cfg = wzp_transport::client_config();
|
let client_cfg = wzp_transport::client_config();
|
||||||
let conn = wzp_transport::connect(&fm.endpoint, addr, "_federation", client_cfg).await?;
|
let conn = wzp_transport::connect(&fm.endpoint, addr, "_federation", client_cfg).await?;
|
||||||
@@ -617,9 +667,12 @@ async fn connect_to_peer(fm: &FederationManager, peer: &PeerConfig) -> Result<Ar
|
|||||||
|
|
||||||
// Send hello with our TLS fingerprint
|
// Send hello with our TLS fingerprint
|
||||||
let hello = SignalMessage::FederationHello {
|
let hello = SignalMessage::FederationHello {
|
||||||
|
version: default_signal_version(),
|
||||||
tls_fingerprint: fm.local_tls_fp.clone(),
|
tls_fingerprint: fm.local_tls_fp.clone(),
|
||||||
};
|
};
|
||||||
transport.send_signal(&hello).await
|
transport
|
||||||
|
.send_signal(&hello)
|
||||||
|
.await
|
||||||
.map_err(|e| anyhow::anyhow!("federation hello send failed: {e}"))?;
|
.map_err(|e| anyhow::anyhow!("federation hello send failed: {e}"))?;
|
||||||
|
|
||||||
info!(peer_url = %peer.url, label = ?peer.label, "federation: connected (hello sent)");
|
info!(peer_url = %peer.url, label = ?peer.label, "federation: connected (hello sent)");
|
||||||
@@ -636,31 +689,40 @@ async fn run_federation_link(
|
|||||||
peer_label: String,
|
peer_label: String,
|
||||||
) -> Result<(), anyhow::Error> {
|
) -> Result<(), anyhow::Error> {
|
||||||
// Register peer link + metrics
|
// Register peer link + metrics
|
||||||
fm.metrics.federation_peer_status.with_label_values(&[&peer_label]).set(1);
|
fm.metrics
|
||||||
|
.federation_peer_status
|
||||||
|
.with_label_values(&[&peer_label])
|
||||||
|
.set(1);
|
||||||
{
|
{
|
||||||
let mut links = fm.peer_links.lock().await;
|
let mut links = fm.peer_links.lock().await;
|
||||||
links.insert(peer_fp.clone(), PeerLink {
|
links.insert(
|
||||||
|
peer_fp.clone(),
|
||||||
|
PeerLink {
|
||||||
transport: transport.clone(),
|
transport: transport.clone(),
|
||||||
label: peer_label.clone(),
|
label: peer_label.clone(),
|
||||||
active_rooms: HashSet::new(),
|
active_rooms: HashSet::new(),
|
||||||
remote_participants: HashMap::new(),
|
remote_participants: HashMap::new(),
|
||||||
last_seen: Instant::now(),
|
last_seen: Instant::now(),
|
||||||
});
|
},
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Announce our currently active global rooms to this new peer
|
// Announce our currently active global rooms to this new peer
|
||||||
// Collect all announcements first, then send (avoid holding locks across await)
|
// Collect all announcements first, then send (avoid holding locks across await)
|
||||||
let announcements = {
|
let announcements = {
|
||||||
let mgr = fm.room_mgr.lock().await;
|
let active = fm.room_mgr.active_rooms();
|
||||||
let active = mgr.active_rooms();
|
|
||||||
let mut msgs = Vec::new();
|
let mut msgs = Vec::new();
|
||||||
|
|
||||||
// Local rooms
|
// Local rooms
|
||||||
for room_name in &active {
|
for room_name in &active {
|
||||||
if fm.is_global_room(room_name) {
|
if fm.is_global_room(room_name) {
|
||||||
let participants = mgr.local_participant_list(room_name);
|
let participants = fm.room_mgr.local_participant_list(room_name);
|
||||||
info!(peer = %peer_label, room = %room_name, participants = participants.len(), "announcing local global room to new peer");
|
info!(peer = %peer_label, room = %room_name, participants = participants.len(), "announcing local global room to new peer");
|
||||||
msgs.push(SignalMessage::GlobalRoomActive { room: room_name.clone(), participants });
|
msgs.push(SignalMessage::GlobalRoomActive {
|
||||||
|
version: default_signal_version(),
|
||||||
|
room: room_name.clone(),
|
||||||
|
participants,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -672,6 +734,7 @@ async fn run_federation_link(
|
|||||||
if fm.is_global_room(room) {
|
if fm.is_global_room(room) {
|
||||||
info!(peer = %peer_label, room = %room, via = %link.label, "propagating remote room to new peer");
|
info!(peer = %peer_label, room = %room, via = %link.label, "propagating remote room to new peer");
|
||||||
msgs.push(SignalMessage::GlobalRoomActive {
|
msgs.push(SignalMessage::GlobalRoomActive {
|
||||||
|
version: default_signal_version(),
|
||||||
room: room.clone(),
|
room: room.clone(),
|
||||||
participants: participants.clone(),
|
participants: participants.clone(),
|
||||||
});
|
});
|
||||||
@@ -756,7 +819,10 @@ async fn run_federation_link(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Cleanup: remove peer link + metrics
|
// Cleanup: remove peer link + metrics
|
||||||
fm.metrics.federation_peer_status.with_label_values(&[&peer_label]).set(0);
|
fm.metrics
|
||||||
|
.federation_peer_status
|
||||||
|
.with_label_values(&[&peer_label])
|
||||||
|
.set(0);
|
||||||
{
|
{
|
||||||
let mut links = fm.peer_links.lock().await;
|
let mut links = fm.peer_links.lock().await;
|
||||||
links.remove(&peer_fp);
|
links.remove(&peer_fp);
|
||||||
@@ -782,7 +848,9 @@ async fn handle_signal(
|
|||||||
}
|
}
|
||||||
|
|
||||||
match msg {
|
match msg {
|
||||||
SignalMessage::GlobalRoomActive { room, participants } => {
|
SignalMessage::GlobalRoomActive {
|
||||||
|
room, participants, ..
|
||||||
|
} => {
|
||||||
if fm.is_global_room(&room) {
|
if fm.is_global_room(&room) {
|
||||||
info!(peer = %peer_label, room = %room, remote_participants = participants.len(), "peer has global room active");
|
info!(peer = %peer_label, room = %room, remote_participants = participants.len(), "peer has global room active");
|
||||||
let mut links = fm.peer_links.lock().await;
|
let mut links = fm.peer_links.lock().await;
|
||||||
@@ -794,76 +862,93 @@ async fn handle_signal(
|
|||||||
fm.metrics.federation_active_rooms.set(total as i64);
|
fm.metrics.federation_active_rooms.set(total as i64);
|
||||||
if let Some(link) = links.get_mut(peer_fp) {
|
if let Some(link) = links.get_mut(peer_fp) {
|
||||||
// Tag remote participants with their relay label
|
// Tag remote participants with their relay label
|
||||||
let tagged: Vec<_> = participants.iter().map(|p| {
|
let tagged: Vec<_> = participants
|
||||||
|
.iter()
|
||||||
|
.map(|p| {
|
||||||
let mut tagged = p.clone();
|
let mut tagged = p.clone();
|
||||||
if tagged.relay_label.is_none() {
|
if tagged.relay_label.is_none() {
|
||||||
tagged.relay_label = Some(link.label.clone());
|
tagged.relay_label = Some(link.label.clone());
|
||||||
}
|
}
|
||||||
tagged
|
tagged
|
||||||
}).collect();
|
})
|
||||||
|
.collect();
|
||||||
link.remote_participants.insert(room.clone(), tagged);
|
link.remote_participants.insert(room.clone(), tagged);
|
||||||
}
|
}
|
||||||
// Propagate to other peers (with relay labels preserved)
|
// Propagate to other peers (with relay labels preserved)
|
||||||
let tagged_for_propagation = if let Some(link) = links.get(peer_fp) {
|
let tagged_for_propagation = if let Some(link) = links.get(peer_fp) {
|
||||||
let label = link.label.clone();
|
let label = link.label.clone();
|
||||||
participants.iter().map(|p| {
|
participants
|
||||||
|
.iter()
|
||||||
|
.map(|p| {
|
||||||
let mut t = p.clone();
|
let mut t = p.clone();
|
||||||
if t.relay_label.is_none() {
|
if t.relay_label.is_none() {
|
||||||
t.relay_label = Some(label.clone());
|
t.relay_label = Some(label.clone());
|
||||||
}
|
}
|
||||||
t
|
t
|
||||||
}).collect::<Vec<_>>()
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
} else {
|
} else {
|
||||||
participants.clone()
|
participants.clone()
|
||||||
};
|
};
|
||||||
for (fp, link) in links.iter() {
|
for (fp, link) in links.iter() {
|
||||||
if fp != peer_fp {
|
if fp != peer_fp {
|
||||||
let _ = link.transport.send_signal(&SignalMessage::GlobalRoomActive {
|
let _ = link
|
||||||
|
.transport
|
||||||
|
.send_signal(&SignalMessage::GlobalRoomActive {
|
||||||
|
version: default_signal_version(),
|
||||||
room: room.clone(),
|
room: room.clone(),
|
||||||
participants: tagged_for_propagation.clone(),
|
participants: tagged_for_propagation.clone(),
|
||||||
}).await;
|
})
|
||||||
|
.await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
drop(links);
|
drop(links);
|
||||||
|
|
||||||
// Broadcast updated RoomUpdate to local clients in this room
|
// Broadcast updated RoomUpdate to local clients in this room
|
||||||
// Find the local room name (may be hashed or raw)
|
// Find the local room name (may be hashed or raw)
|
||||||
let mgr = fm.room_mgr.lock().await;
|
let active = fm.room_mgr.active_rooms();
|
||||||
for local_room in mgr.active_rooms() {
|
for local_room in &active {
|
||||||
if fm.is_global_room(&local_room) && fm.resolve_global_room(&local_room) == fm.resolve_global_room(&room) {
|
if fm.is_global_room(local_room)
|
||||||
|
&& fm.resolve_global_room(local_room) == fm.resolve_global_room(&room)
|
||||||
|
{
|
||||||
// Build merged participant list: local + all remote (deduped)
|
// Build merged participant list: local + all remote (deduped)
|
||||||
let mut all_participants = mgr.local_participant_list(&local_room);
|
let mut all_participants = fm.room_mgr.local_participant_list(local_room);
|
||||||
|
{
|
||||||
let links = fm.peer_links.lock().await;
|
let links = fm.peer_links.lock().await;
|
||||||
for link in links.values() {
|
for link in links.values() {
|
||||||
if let Some(ref canonical) = fm.resolve_global_room(&local_room) {
|
if let Some(ref canonical) = fm.resolve_global_room(local_room) {
|
||||||
if let Some(remote) = link.remote_participants.get(canonical.as_str()) {
|
if let Some(remote) =
|
||||||
|
link.remote_participants.get(canonical.as_str())
|
||||||
|
{
|
||||||
all_participants.extend(remote.iter().cloned());
|
all_participants.extend(remote.iter().cloned());
|
||||||
}
|
}
|
||||||
// Also check raw room name, but only if different from canonical
|
// Also check raw room name, but only if different from canonical
|
||||||
if canonical != &local_room {
|
if canonical != local_room {
|
||||||
if let Some(remote) = link.remote_participants.get(&local_room) {
|
if let Some(remote) =
|
||||||
|
link.remote_participants.get(local_room)
|
||||||
|
{
|
||||||
all_participants.extend(remote.iter().cloned());
|
all_participants.extend(remote.iter().cloned());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
// Deduplicate by fingerprint
|
// Deduplicate by fingerprint
|
||||||
let mut seen = HashSet::new();
|
let mut seen = HashSet::new();
|
||||||
all_participants.retain(|p| seen.insert(p.fingerprint.clone()));
|
all_participants.retain(|p| seen.insert(p.fingerprint.clone()));
|
||||||
let update = SignalMessage::RoomUpdate {
|
let update = SignalMessage::RoomUpdate {
|
||||||
|
version: default_signal_version(),
|
||||||
count: all_participants.len() as u32,
|
count: all_participants.len() as u32,
|
||||||
participants: all_participants,
|
participants: all_participants,
|
||||||
};
|
};
|
||||||
let senders = mgr.local_senders(&local_room);
|
let senders = fm.room_mgr.local_senders(local_room);
|
||||||
drop(links);
|
|
||||||
drop(mgr);
|
|
||||||
room::broadcast_signal(&senders, &update).await;
|
room::broadcast_signal(&senders, &update).await;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
SignalMessage::GlobalRoomInactive { room } => {
|
SignalMessage::GlobalRoomInactive { room, .. } => {
|
||||||
info!(peer = %peer_label, room = %room, "peer global room now inactive");
|
info!(peer = %peer_label, room = %room, "peer global room now inactive");
|
||||||
let mut links = fm.peer_links.lock().await;
|
let mut links = fm.peer_links.lock().await;
|
||||||
if let Some(link) = links.get_mut(peer_fp) {
|
if let Some(link) = links.get_mut(peer_fp) {
|
||||||
@@ -885,7 +970,9 @@ async fn handle_signal(
|
|||||||
let canonical = fm.resolve_global_room(&room);
|
let canonical = fm.resolve_global_room(&room);
|
||||||
let mut result = Vec::new();
|
let mut result = Vec::new();
|
||||||
for (fp, link) in links.iter() {
|
for (fp, link) in links.iter() {
|
||||||
if fp == peer_fp { continue; }
|
if fp == peer_fp {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
if let Some(ref c) = canonical {
|
if let Some(ref c) = canonical {
|
||||||
if let Some(remote) = link.remote_participants.get(c.as_str()) {
|
if let Some(remote) = link.remote_participants.get(c.as_str()) {
|
||||||
result.extend(remote.iter().cloned());
|
result.extend(remote.iter().cloned());
|
||||||
@@ -899,14 +986,16 @@ async fn handle_signal(
|
|||||||
|
|
||||||
// Propagate to other peers: send updated GlobalRoomActive with revised list,
|
// Propagate to other peers: send updated GlobalRoomActive with revised list,
|
||||||
// or GlobalRoomInactive if no participants remain anywhere
|
// or GlobalRoomInactive if no participants remain anywhere
|
||||||
let local_active = {
|
let local_active = fm
|
||||||
let mgr = fm.room_mgr.lock().await;
|
.room_mgr
|
||||||
mgr.active_rooms().iter().any(|r| fm.resolve_global_room(r) == fm.resolve_global_room(&room))
|
.active_rooms()
|
||||||
};
|
.iter()
|
||||||
|
.any(|r| fm.resolve_global_room(r) == fm.resolve_global_room(&room));
|
||||||
let has_remaining = !remaining_remote.is_empty() || local_active;
|
let has_remaining = !remaining_remote.is_empty() || local_active;
|
||||||
|
|
||||||
// Collect peer transports to send to (avoid holding lock across await)
|
// Collect peer transports to send to (avoid holding lock across await)
|
||||||
let peer_sends: Vec<_> = links.iter()
|
let peer_sends: Vec<_> = links
|
||||||
|
.iter()
|
||||||
.filter(|(fp, _)| *fp != peer_fp)
|
.filter(|(fp, _)| *fp != peer_fp)
|
||||||
.map(|(_, link)| link.transport.clone())
|
.map(|(_, link)| link.transport.clone())
|
||||||
.collect();
|
.collect();
|
||||||
@@ -916,15 +1005,16 @@ async fn handle_signal(
|
|||||||
// Send updated participant list to other peers
|
// Send updated participant list to other peers
|
||||||
let mut updated_participants = remaining_remote.clone();
|
let mut updated_participants = remaining_remote.clone();
|
||||||
if local_active {
|
if local_active {
|
||||||
let mgr = fm.room_mgr.lock().await;
|
for local_room in fm.room_mgr.active_rooms() {
|
||||||
for local_room in mgr.active_rooms() {
|
|
||||||
if fm.resolve_global_room(&local_room) == fm.resolve_global_room(&room) {
|
if fm.resolve_global_room(&local_room) == fm.resolve_global_room(&room) {
|
||||||
updated_participants.extend(mgr.local_participant_list(&local_room));
|
updated_participants
|
||||||
|
.extend(fm.room_mgr.local_participant_list(&local_room));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let msg = SignalMessage::GlobalRoomActive {
|
let msg = SignalMessage::GlobalRoomActive {
|
||||||
|
version: default_signal_version(),
|
||||||
room: room.clone(),
|
room: room.clone(),
|
||||||
participants: updated_participants,
|
participants: updated_participants,
|
||||||
};
|
};
|
||||||
@@ -933,27 +1023,32 @@ async fn handle_signal(
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// No participants left anywhere — propagate inactive
|
// No participants left anywhere — propagate inactive
|
||||||
let msg = SignalMessage::GlobalRoomInactive { room: room.clone() };
|
let msg = SignalMessage::GlobalRoomInactive {
|
||||||
|
version: default_signal_version(),
|
||||||
|
room: room.clone(),
|
||||||
|
};
|
||||||
for transport in &peer_sends {
|
for transport in &peer_sends {
|
||||||
let _ = transport.send_signal(&msg).await;
|
let _ = transport.send_signal(&msg).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Broadcast updated RoomUpdate to local clients (remote participant removed)
|
// Broadcast updated RoomUpdate to local clients (remote participant removed)
|
||||||
let mgr = fm.room_mgr.lock().await;
|
let active = fm.room_mgr.active_rooms();
|
||||||
for local_room in mgr.active_rooms() {
|
for local_room in &active {
|
||||||
if fm.is_global_room(&local_room) && fm.resolve_global_room(&local_room) == fm.resolve_global_room(&room) {
|
if fm.is_global_room(local_room)
|
||||||
let mut all_participants = mgr.local_participant_list(&local_room);
|
&& fm.resolve_global_room(local_room) == fm.resolve_global_room(&room)
|
||||||
|
{
|
||||||
|
let mut all_participants = fm.room_mgr.local_participant_list(local_room);
|
||||||
all_participants.extend(remaining_remote.iter().cloned());
|
all_participants.extend(remaining_remote.iter().cloned());
|
||||||
// Deduplicate by fingerprint
|
// Deduplicate by fingerprint
|
||||||
let mut seen = HashSet::new();
|
let mut seen = HashSet::new();
|
||||||
all_participants.retain(|p| seen.insert(p.fingerprint.clone()));
|
all_participants.retain(|p| seen.insert(p.fingerprint.clone()));
|
||||||
let update = SignalMessage::RoomUpdate {
|
let update = SignalMessage::RoomUpdate {
|
||||||
|
version: default_signal_version(),
|
||||||
count: all_participants.len() as u32,
|
count: all_participants.len() as u32,
|
||||||
participants: all_participants,
|
participants: all_participants,
|
||||||
};
|
};
|
||||||
let senders = mgr.local_senders(&local_room);
|
let senders = fm.room_mgr.local_senders(local_room);
|
||||||
drop(mgr);
|
|
||||||
room::broadcast_signal(&senders, &update).await;
|
room::broadcast_signal(&senders, &update).await;
|
||||||
info!(room = %room, "broadcast updated presence (remote participant removed)");
|
info!(room = %room, "broadcast updated presence (remote participant removed)");
|
||||||
break;
|
break;
|
||||||
@@ -972,7 +1067,11 @@ async fn handle_signal(
|
|||||||
// Loop prevention: drop any forward whose origin matches
|
// Loop prevention: drop any forward whose origin matches
|
||||||
// our own federation TLS fingerprint. With
|
// our own federation TLS fingerprint. With
|
||||||
// broadcast-to-all-peers this prevents A→B→A echo loops.
|
// broadcast-to-all-peers this prevents A→B→A echo loops.
|
||||||
SignalMessage::FederatedSignalForward { inner, origin_relay_fp } => {
|
SignalMessage::FederatedSignalForward {
|
||||||
|
inner,
|
||||||
|
origin_relay_fp,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
if origin_relay_fp == fm.local_tls_fp {
|
if origin_relay_fp == fm.local_tls_fp {
|
||||||
tracing::debug!(
|
tracing::debug!(
|
||||||
peer = %peer_label,
|
peer = %peer_label,
|
||||||
@@ -1016,12 +1115,10 @@ async fn handle_signal(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Handle an incoming federation datagram (room-hash-tagged media).
|
/// Handle an incoming federation datagram (room-hash-tagged media).
|
||||||
async fn handle_datagram(
|
async fn handle_datagram(fm: &Arc<FederationManager>, source_peer_fp: &str, data: Bytes) {
|
||||||
fm: &Arc<FederationManager>,
|
if data.len() < 12 {
|
||||||
source_peer_fp: &str,
|
return;
|
||||||
data: Bytes,
|
} // 8-byte hash + min packet
|
||||||
) {
|
|
||||||
if data.len() < 12 { return; } // 8-byte hash + min packet
|
|
||||||
|
|
||||||
let mut rh = [0u8; 8];
|
let mut rh = [0u8; 8];
|
||||||
rh.copy_from_slice(&data[..8]);
|
rh.copy_from_slice(&data[..8]);
|
||||||
@@ -1030,7 +1127,8 @@ async fn handle_datagram(
|
|||||||
let pkt = match wzp_proto::MediaPacket::from_bytes(media_bytes.clone()) {
|
let pkt = match wzp_proto::MediaPacket::from_bytes(media_bytes.clone()) {
|
||||||
Some(pkt) => pkt,
|
Some(pkt) => pkt,
|
||||||
None => {
|
None => {
|
||||||
fm.event_log.emit(Event::new("federation_ingress_malformed").len(data.len()));
|
fm.event_log
|
||||||
|
.emit(Event::new("federation_ingress_malformed").len(data.len()));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -1038,13 +1136,22 @@ async fn handle_datagram(
|
|||||||
// Event log: federation ingress
|
// Event log: federation ingress
|
||||||
let peer_label = {
|
let peer_label = {
|
||||||
let links = fm.peer_links.lock().await;
|
let links = fm.peer_links.lock().await;
|
||||||
links.get(source_peer_fp).map(|l| l.label.clone()).unwrap_or_default()
|
links
|
||||||
|
.get(source_peer_fp)
|
||||||
|
.map(|l| l.label.clone())
|
||||||
|
.unwrap_or_default()
|
||||||
};
|
};
|
||||||
fm.event_log.emit(Event::new("federation_ingress").packet(&pkt).peer(&peer_label));
|
fm.event_log.emit(
|
||||||
|
Event::new("federation_ingress")
|
||||||
|
.packet(&pkt)
|
||||||
|
.peer(&peer_label),
|
||||||
|
);
|
||||||
|
|
||||||
// Count inbound federation packet + update last_seen
|
// Count inbound federation packet + update last_seen
|
||||||
fm.metrics.federation_packets_forwarded
|
fm.metrics
|
||||||
.with_label_values(&[source_peer_fp, "in"]).inc();
|
.federation_packets_forwarded
|
||||||
|
.with_label_values(&[source_peer_fp, "in"])
|
||||||
|
.inc();
|
||||||
{
|
{
|
||||||
let mut links = fm.peer_links.lock().await;
|
let mut links = fm.peer_links.lock().await;
|
||||||
if let Some(link) = links.get_mut(source_peer_fp) {
|
if let Some(link) = links.get_mut(source_peer_fp) {
|
||||||
@@ -1065,38 +1172,53 @@ async fn handle_datagram(
|
|||||||
{
|
{
|
||||||
let mut dedup = fm.dedup.lock().await;
|
let mut dedup = fm.dedup.lock().await;
|
||||||
if dedup.is_dup(&rh, pkt.header.seq, payload_hash) {
|
if dedup.is_dup(&rh, pkt.header.seq, payload_hash) {
|
||||||
fm.event_log.emit(Event::new("dedup_drop").seq(pkt.header.seq).peer(&peer_label));
|
fm.event_log.emit(
|
||||||
|
Event::new("dedup_drop")
|
||||||
|
.seq(pkt.header.seq)
|
||||||
|
.peer(&peer_label),
|
||||||
|
);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find room by hash — check local rooms AND global room config
|
// Find room by hash -- check local rooms AND global room config
|
||||||
let room_name = {
|
let room_name = {
|
||||||
let mgr = fm.room_mgr.lock().await;
|
let active = fm.room_mgr.active_rooms();
|
||||||
let active = mgr.active_rooms();
|
|
||||||
// First: check local rooms (has participants)
|
// First: check local rooms (has participants)
|
||||||
active.iter().find(|r| room_hash(r) == rh).cloned()
|
active
|
||||||
.or_else(|| active.iter().find(|r| fm.global_room_hash(r) == rh).cloned())
|
.iter()
|
||||||
|
.find(|r| room_hash(r) == rh)
|
||||||
|
.cloned()
|
||||||
|
.or_else(|| {
|
||||||
|
active
|
||||||
|
.iter()
|
||||||
|
.find(|r| fm.global_room_hash(r) == rh)
|
||||||
|
.cloned()
|
||||||
|
})
|
||||||
// Second: check static global room config (hub relay may have no local participants)
|
// Second: check static global room config (hub relay may have no local participants)
|
||||||
.or_else(|| {
|
.or_else(|| {
|
||||||
fm.global_rooms.iter().find(|name| room_hash(name) == rh).cloned()
|
fm.global_rooms
|
||||||
|
.iter()
|
||||||
|
.find(|name| room_hash(name) == rh)
|
||||||
|
.cloned()
|
||||||
})
|
})
|
||||||
};
|
};
|
||||||
|
|
||||||
let room_name = match room_name {
|
let room_name = match room_name {
|
||||||
Some(r) => r,
|
Some(r) => r,
|
||||||
None => {
|
None => {
|
||||||
fm.event_log.emit(Event::new("room_not_found").seq(pkt.header.seq).peer(&peer_label));
|
fm.event_log.emit(
|
||||||
|
Event::new("room_not_found")
|
||||||
|
.seq(pkt.header.seq)
|
||||||
|
.peer(&peer_label),
|
||||||
|
);
|
||||||
// Phase 4.1 diagnostic: log the hash + active rooms
|
// Phase 4.1 diagnostic: log the hash + active rooms
|
||||||
// so we can diagnose cross-relay call-* media routing
|
// so we can diagnose cross-relay call-* media routing
|
||||||
// failures. This fires when a peer relay sends media
|
// failures. This fires when a peer relay sends media
|
||||||
// for a room we don't have locally — could be a
|
// for a room we don't have locally — could be a
|
||||||
// timing issue (peer joined before us) or a hash
|
// timing issue (peer joined before us) or a hash
|
||||||
// mismatch.
|
// mismatch.
|
||||||
let active = {
|
let active = fm.room_mgr.active_rooms();
|
||||||
let mgr = fm.room_mgr.lock().await;
|
|
||||||
mgr.active_rooms()
|
|
||||||
};
|
|
||||||
warn!(
|
warn!(
|
||||||
room_hash = ?rh,
|
room_hash = ?rh,
|
||||||
active_rooms = ?active,
|
active_rooms = ?active,
|
||||||
@@ -1111,32 +1233,46 @@ async fn handle_datagram(
|
|||||||
// Rate limit per room
|
// Rate limit per room
|
||||||
if FEDERATION_RATE_LIMIT_PPS > 0 {
|
if FEDERATION_RATE_LIMIT_PPS > 0 {
|
||||||
let mut limiters = fm.rate_limiters.lock().await;
|
let mut limiters = fm.rate_limiters.lock().await;
|
||||||
let limiter = limiters.entry(room_name.clone())
|
let limiter = limiters
|
||||||
|
.entry(room_name.clone())
|
||||||
.or_insert_with(|| RateLimiter::new(FEDERATION_RATE_LIMIT_PPS));
|
.or_insert_with(|| RateLimiter::new(FEDERATION_RATE_LIMIT_PPS));
|
||||||
if !limiter.allow() {
|
if !limiter.allow() {
|
||||||
fm.event_log.emit(Event::new("rate_limit_drop").room(&room_name).seq(pkt.header.seq));
|
fm.event_log.emit(
|
||||||
|
Event::new("rate_limit_drop")
|
||||||
|
.room(&room_name)
|
||||||
|
.seq(pkt.header.seq),
|
||||||
|
);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deliver to all local participants — forward the raw bytes as-is.
|
// Deliver to all local participants — forward the raw bytes as-is.
|
||||||
// The original sender's MediaPacket is preserved exactly (no re-serialization).
|
// The original sender's MediaPacket is preserved exactly (no re-serialization).
|
||||||
let locals = {
|
let locals = fm.room_mgr.local_senders(&room_name);
|
||||||
let mgr = fm.room_mgr.lock().await;
|
|
||||||
mgr.local_senders(&room_name)
|
|
||||||
};
|
|
||||||
for sender in &locals {
|
for sender in &locals {
|
||||||
match sender {
|
match sender {
|
||||||
room::ParticipantSender::Quic(t) => {
|
room::ParticipantSender::Quic(t) => {
|
||||||
if let Err(e) = t.send_raw_datagram(&media_bytes) {
|
if let Err(e) = t.send_raw_datagram(&media_bytes) {
|
||||||
fm.event_log.emit(Event::new("local_deliver_error").room(&room_name).seq(pkt.header.seq).reason(&e.to_string()));
|
fm.event_log.emit(
|
||||||
|
Event::new("local_deliver_error")
|
||||||
|
.room(&room_name)
|
||||||
|
.seq(pkt.header.seq)
|
||||||
|
.reason(&e.to_string()),
|
||||||
|
);
|
||||||
warn!("federation local delivery error: {e}");
|
warn!("federation local delivery error: {e}");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
room::ParticipantSender::WebSocket(_) => { let _ = sender.send_raw(&pkt.payload).await; }
|
room::ParticipantSender::WebSocket(_) => {
|
||||||
|
let _ = sender.send_raw(&pkt.payload).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
fm.event_log.emit(Event::new("local_deliver").room(&room_name).seq(pkt.header.seq).to_count(locals.len()));
|
}
|
||||||
|
fm.event_log.emit(
|
||||||
|
Event::new("local_deliver")
|
||||||
|
.room(&room_name)
|
||||||
|
.seq(pkt.header.seq)
|
||||||
|
.to_count(locals.len()),
|
||||||
|
);
|
||||||
|
|
||||||
// Multi-hop: forward to ALL other connected peers (not the source)
|
// Multi-hop: forward to ALL other connected peers (not the source)
|
||||||
// Don't filter by active_rooms — the receiving peer decides whether to deliver
|
// Don't filter by active_rooms — the receiving peer decides whether to deliver
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
//! recv `CallOffer` → verify → generate ephemeral → derive session → send `CallAnswer`.
|
//! recv `CallOffer` → verify → generate ephemeral → derive session → send `CallAnswer`.
|
||||||
|
|
||||||
use wzp_crypto::{CryptoSession, KeyExchange, WarzoneKeyExchange};
|
use wzp_crypto::{CryptoSession, KeyExchange, WarzoneKeyExchange};
|
||||||
use wzp_proto::{MediaTransport, QualityProfile, SignalMessage};
|
use wzp_proto::{MediaTransport, QualityProfile, SignalMessage, default_signal_version};
|
||||||
|
|
||||||
/// Accept the relay (callee) side of the cryptographic handshake.
|
/// Accept the relay (callee) side of the cryptographic handshake.
|
||||||
///
|
///
|
||||||
@@ -20,30 +20,72 @@ use wzp_proto::{MediaTransport, QualityProfile, SignalMessage};
|
|||||||
pub async fn accept_handshake(
|
pub async fn accept_handshake(
|
||||||
transport: &dyn MediaTransport,
|
transport: &dyn MediaTransport,
|
||||||
seed: &[u8; 32],
|
seed: &[u8; 32],
|
||||||
) -> Result<(Box<dyn CryptoSession>, QualityProfile, String, Option<String>), anyhow::Error> {
|
) -> Result<
|
||||||
|
(
|
||||||
|
Box<dyn CryptoSession>,
|
||||||
|
QualityProfile,
|
||||||
|
String,
|
||||||
|
Option<String>,
|
||||||
|
),
|
||||||
|
anyhow::Error,
|
||||||
|
> {
|
||||||
// 1. Receive CallOffer
|
// 1. Receive CallOffer
|
||||||
let offer = transport
|
let offer = transport
|
||||||
.recv_signal()
|
.recv_signal()
|
||||||
.await?
|
.await?
|
||||||
.ok_or_else(|| anyhow::anyhow!("connection closed before receiving CallOffer"))?;
|
.ok_or_else(|| anyhow::anyhow!("connection closed before receiving CallOffer"))?;
|
||||||
|
|
||||||
let (caller_identity_pub, caller_ephemeral_pub, caller_signature, supported_profiles, caller_alias) =
|
let (
|
||||||
match offer {
|
caller_identity_pub,
|
||||||
|
caller_ephemeral_pub,
|
||||||
|
caller_signature,
|
||||||
|
supported_profiles,
|
||||||
|
caller_alias,
|
||||||
|
protocol_version,
|
||||||
|
caller_video_codecs,
|
||||||
|
) = match offer {
|
||||||
SignalMessage::CallOffer {
|
SignalMessage::CallOffer {
|
||||||
identity_pub,
|
identity_pub,
|
||||||
ephemeral_pub,
|
ephemeral_pub,
|
||||||
signature,
|
signature,
|
||||||
supported_profiles,
|
supported_profiles,
|
||||||
alias,
|
alias,
|
||||||
} => (identity_pub, ephemeral_pub, signature, supported_profiles, alias),
|
protocol_version,
|
||||||
|
supported_versions: _,
|
||||||
|
video_codecs,
|
||||||
|
..
|
||||||
|
} => (
|
||||||
|
identity_pub,
|
||||||
|
ephemeral_pub,
|
||||||
|
signature,
|
||||||
|
supported_profiles,
|
||||||
|
alias,
|
||||||
|
protocol_version,
|
||||||
|
video_codecs,
|
||||||
|
),
|
||||||
other => {
|
other => {
|
||||||
return Err(anyhow::anyhow!(
|
return Err(anyhow::anyhow!(
|
||||||
"expected CallOffer, got {:?}",
|
"expected CallOffer, got {:?}",
|
||||||
std::mem::discriminant(&other)
|
std::mem::discriminant(&other)
|
||||||
))
|
));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// 1a. Protocol version check — we only speak v2.
|
||||||
|
if protocol_version != 2 {
|
||||||
|
let mismatch = SignalMessage::Hangup {
|
||||||
|
version: default_signal_version(),
|
||||||
|
reason: wzp_proto::HangupReason::ProtocolVersionMismatch {
|
||||||
|
server_supported: vec![2],
|
||||||
|
},
|
||||||
|
call_id: None,
|
||||||
|
};
|
||||||
|
let _ = transport.send_signal(&mismatch).await;
|
||||||
|
return Err(anyhow::anyhow!(
|
||||||
|
"protocol version mismatch: client requested {protocol_version}, server supports [2]"
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
// 2. Verify caller's signature over (ephemeral_pub || "call-offer")
|
// 2. Verify caller's signature over (ephemeral_pub || "call-offer")
|
||||||
let mut verify_data = Vec::with_capacity(32 + 10);
|
let mut verify_data = Vec::with_capacity(32 + 10);
|
||||||
verify_data.extend_from_slice(&caller_ephemeral_pub);
|
verify_data.extend_from_slice(&caller_ephemeral_pub);
|
||||||
@@ -69,23 +111,28 @@ pub async fn accept_handshake(
|
|||||||
// Choose the best supported profile (prefer GOOD > DEGRADED > CATASTROPHIC)
|
// Choose the best supported profile (prefer GOOD > DEGRADED > CATASTROPHIC)
|
||||||
let chosen_profile = choose_profile(&supported_profiles);
|
let chosen_profile = choose_profile(&supported_profiles);
|
||||||
|
|
||||||
|
// Pick the first video codec the caller supports (relay forwards all video).
|
||||||
|
let video_codec = caller_video_codecs.into_iter().next();
|
||||||
|
|
||||||
// 6. Send CallAnswer
|
// 6. Send CallAnswer
|
||||||
let answer = SignalMessage::CallAnswer {
|
let answer = SignalMessage::CallAnswer {
|
||||||
|
version: default_signal_version(),
|
||||||
identity_pub,
|
identity_pub,
|
||||||
ephemeral_pub,
|
ephemeral_pub,
|
||||||
signature,
|
signature,
|
||||||
chosen_profile,
|
chosen_profile,
|
||||||
|
video_codec,
|
||||||
};
|
};
|
||||||
transport.send_signal(&answer).await?;
|
transport.send_signal(&answer).await?;
|
||||||
|
|
||||||
// Derive caller fingerprint: SHA-256(Ed25519 pub)[:16], formatted as xxxx:xxxx:...
|
// Derive caller fingerprint: SHA-256(Ed25519 pub)[:16], formatted as xxxx:xxxx:...
|
||||||
// Must match the format used in signal registration and presence.
|
// Must match the format used in signal registration and presence.
|
||||||
let caller_fp = {
|
let caller_fp = {
|
||||||
use sha2::{Sha256, Digest};
|
use sha2::{Digest, Sha256};
|
||||||
let hash = Sha256::digest(&caller_identity_pub);
|
let hash = Sha256::digest(&caller_identity_pub);
|
||||||
let fp = wzp_crypto::Fingerprint([
|
let fp = wzp_crypto::Fingerprint([
|
||||||
hash[0], hash[1], hash[2], hash[3], hash[4], hash[5], hash[6], hash[7],
|
hash[0], hash[1], hash[2], hash[3], hash[4], hash[5], hash[6], hash[7], hash[8],
|
||||||
hash[8], hash[9], hash[10], hash[11], hash[12], hash[13], hash[14], hash[15],
|
hash[9], hash[10], hash[11], hash[12], hash[13], hash[14], hash[15],
|
||||||
]);
|
]);
|
||||||
fp.to_string()
|
fp.to_string()
|
||||||
};
|
};
|
||||||
@@ -107,6 +154,7 @@ fn choose_profile(_supported: &[QualityProfile]) -> QualityProfile {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use wzp_proto::CodecId;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn choose_profile_picks_highest_bitrate() {
|
fn choose_profile_picks_highest_bitrate() {
|
||||||
@@ -124,4 +172,35 @@ mod tests {
|
|||||||
let chosen = choose_profile(&[]);
|
let chosen = choose_profile(&[]);
|
||||||
assert_eq!(chosen, QualityProfile::GOOD);
|
assert_eq!(chosen, QualityProfile::GOOD);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── Video codec negotiation ───────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn video_codec_picks_first_offered() {
|
||||||
|
let codecs = vec![CodecId::H264Baseline];
|
||||||
|
let chosen: Option<CodecId> = codecs.into_iter().next();
|
||||||
|
assert_eq!(chosen, Some(CodecId::H264Baseline));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn video_codec_none_when_no_codecs_offered() {
|
||||||
|
let codecs: Vec<CodecId> = vec![];
|
||||||
|
let chosen: Option<CodecId> = codecs.into_iter().next();
|
||||||
|
assert_eq!(chosen, None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn video_codec_single_codec_is_selected() {
|
||||||
|
let codecs = vec![CodecId::H265Main];
|
||||||
|
let chosen: Option<CodecId> = codecs.into_iter().next();
|
||||||
|
assert_eq!(chosen, Some(CodecId::H265Main));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn video_codec_order_is_preserved() {
|
||||||
|
// The relay must pick the FIRST codec as-offered, not sort or re-rank.
|
||||||
|
let codecs = vec![CodecId::H264Baseline, CodecId::Av1Main];
|
||||||
|
let chosen: Option<CodecId> = codecs.into_iter().next();
|
||||||
|
assert_eq!(chosen, Some(CodecId::H264Baseline));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,22 +7,27 @@
|
|||||||
//! It operates on FEC-protected packets, managing loss recovery and adaptive
|
//! It operates on FEC-protected packets, managing loss recovery and adaptive
|
||||||
//! quality transitions.
|
//! quality transitions.
|
||||||
|
|
||||||
|
pub mod audio_scorer;
|
||||||
pub mod auth;
|
pub mod auth;
|
||||||
pub mod call_registry;
|
pub mod call_registry;
|
||||||
pub mod config;
|
pub mod config;
|
||||||
|
pub mod conformance;
|
||||||
pub mod event_log;
|
pub mod event_log;
|
||||||
pub mod federation;
|
pub mod federation;
|
||||||
pub mod signal_hub;
|
|
||||||
pub mod handshake;
|
pub mod handshake;
|
||||||
pub mod metrics;
|
pub mod metrics;
|
||||||
pub mod pipeline;
|
pub mod pipeline;
|
||||||
pub mod presence;
|
pub mod presence;
|
||||||
pub mod probe;
|
pub mod probe;
|
||||||
pub mod relay_link;
|
pub mod relay_link;
|
||||||
|
pub mod response_policy;
|
||||||
pub mod room;
|
pub mod room;
|
||||||
pub mod route;
|
pub mod route;
|
||||||
pub mod session_mgr;
|
pub mod session_mgr;
|
||||||
|
pub mod signal_hub;
|
||||||
pub mod trunk;
|
pub mod trunk;
|
||||||
|
pub mod verdict;
|
||||||
|
pub mod video_scorer;
|
||||||
pub mod ws;
|
pub mod ws;
|
||||||
|
|
||||||
pub use config::RelayConfig;
|
pub use config::RelayConfig;
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,11 +1,14 @@
|
|||||||
//! Prometheus metrics for the WZP relay daemon.
|
//! Prometheus metrics for the WZP relay daemon.
|
||||||
|
|
||||||
use prometheus::{
|
use prometheus::{
|
||||||
Encoder, GaugeVec, Histogram, HistogramOpts, IntCounter, IntCounterVec, IntGauge, IntGaugeVec,
|
Encoder, GaugeVec, Histogram, HistogramOpts, HistogramVec, IntCounter, IntCounterVec, IntGauge,
|
||||||
Opts, Registry, TextEncoder,
|
IntGaugeVec, Opts, Registry, TextEncoder,
|
||||||
};
|
};
|
||||||
use wzp_proto::packet::QualityReport;
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use wzp_proto::MediaHeader;
|
||||||
|
use wzp_proto::packet::QualityReport;
|
||||||
|
|
||||||
|
use crate::conformance::Violation;
|
||||||
|
|
||||||
/// All relay-level Prometheus metrics.
|
/// All relay-level Prometheus metrics.
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
@@ -32,6 +35,9 @@ pub struct RelayMetrics {
|
|||||||
// Phase 4: loss-recovery breakdown per session.
|
// Phase 4: loss-recovery breakdown per session.
|
||||||
pub session_dred_reconstructions: IntCounterVec,
|
pub session_dred_reconstructions: IntCounterVec,
|
||||||
pub session_classical_plc: IntCounterVec,
|
pub session_classical_plc: IntCounterVec,
|
||||||
|
pub conformance_violations: IntCounterVec,
|
||||||
|
pub conformance_bytes: HistogramVec,
|
||||||
|
pub conformance_iat_ms: HistogramVec,
|
||||||
registry: Registry,
|
registry: Registry,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -40,21 +46,23 @@ impl RelayMetrics {
|
|||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
let registry = Registry::new();
|
let registry = Registry::new();
|
||||||
|
|
||||||
let active_sessions = IntGauge::with_opts(
|
let active_sessions = IntGauge::with_opts(Opts::new(
|
||||||
Opts::new("wzp_relay_active_sessions", "Current active sessions"),
|
"wzp_relay_active_sessions",
|
||||||
)
|
"Current active sessions",
|
||||||
|
))
|
||||||
.expect("metric");
|
.expect("metric");
|
||||||
let active_rooms = IntGauge::with_opts(
|
let active_rooms =
|
||||||
Opts::new("wzp_relay_active_rooms", "Current active rooms"),
|
IntGauge::with_opts(Opts::new("wzp_relay_active_rooms", "Current active rooms"))
|
||||||
)
|
|
||||||
.expect("metric");
|
.expect("metric");
|
||||||
let packets_forwarded = IntCounter::with_opts(
|
let packets_forwarded = IntCounter::with_opts(Opts::new(
|
||||||
Opts::new("wzp_relay_packets_forwarded_total", "Total packets forwarded"),
|
"wzp_relay_packets_forwarded_total",
|
||||||
)
|
"Total packets forwarded",
|
||||||
|
))
|
||||||
.expect("metric");
|
.expect("metric");
|
||||||
let bytes_forwarded = IntCounter::with_opts(
|
let bytes_forwarded = IntCounter::with_opts(Opts::new(
|
||||||
Opts::new("wzp_relay_bytes_forwarded_total", "Total bytes forwarded"),
|
"wzp_relay_bytes_forwarded_total",
|
||||||
)
|
"Total bytes forwarded",
|
||||||
|
))
|
||||||
.expect("metric");
|
.expect("metric");
|
||||||
let auth_attempts = IntCounterVec::new(
|
let auth_attempts = IntCounterVec::new(
|
||||||
Opts::new("wzp_relay_auth_attempts_total", "Auth validation attempts"),
|
Opts::new("wzp_relay_auth_attempts_total", "Auth validation attempts"),
|
||||||
@@ -66,31 +74,51 @@ impl RelayMetrics {
|
|||||||
"wzp_relay_handshake_duration_seconds",
|
"wzp_relay_handshake_duration_seconds",
|
||||||
"Crypto handshake time",
|
"Crypto handshake time",
|
||||||
)
|
)
|
||||||
.buckets(vec![0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5]),
|
.buckets(vec![
|
||||||
|
0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5,
|
||||||
|
]),
|
||||||
)
|
)
|
||||||
.expect("metric");
|
.expect("metric");
|
||||||
|
|
||||||
let federation_peer_status = IntGaugeVec::new(
|
let federation_peer_status = IntGaugeVec::new(
|
||||||
Opts::new("wzp_federation_peer_status", "Peer connection status (0=disconnected, 1=connected)"),
|
Opts::new(
|
||||||
|
"wzp_federation_peer_status",
|
||||||
|
"Peer connection status (0=disconnected, 1=connected)",
|
||||||
|
),
|
||||||
&["peer"],
|
&["peer"],
|
||||||
).expect("metric");
|
)
|
||||||
|
.expect("metric");
|
||||||
let federation_peer_rtt_ms = GaugeVec::new(
|
let federation_peer_rtt_ms = GaugeVec::new(
|
||||||
Opts::new("wzp_federation_peer_rtt_ms", "QUIC RTT to federated peer in milliseconds"),
|
Opts::new(
|
||||||
|
"wzp_federation_peer_rtt_ms",
|
||||||
|
"QUIC RTT to federated peer in milliseconds",
|
||||||
|
),
|
||||||
&["peer"],
|
&["peer"],
|
||||||
).expect("metric");
|
)
|
||||||
|
.expect("metric");
|
||||||
let federation_packets_forwarded = IntCounterVec::new(
|
let federation_packets_forwarded = IntCounterVec::new(
|
||||||
Opts::new("wzp_federation_packets_forwarded_total", "Packets forwarded to/from federated peers"),
|
Opts::new(
|
||||||
|
"wzp_federation_packets_forwarded_total",
|
||||||
|
"Packets forwarded to/from federated peers",
|
||||||
|
),
|
||||||
&["peer", "direction"],
|
&["peer", "direction"],
|
||||||
).expect("metric");
|
)
|
||||||
let federation_packets_deduped = IntCounter::with_opts(
|
.expect("metric");
|
||||||
Opts::new("wzp_federation_packets_deduped_total", "Duplicate federation packets dropped"),
|
let federation_packets_deduped = IntCounter::with_opts(Opts::new(
|
||||||
).expect("metric");
|
"wzp_federation_packets_deduped_total",
|
||||||
let federation_packets_rate_limited = IntCounter::with_opts(
|
"Duplicate federation packets dropped",
|
||||||
Opts::new("wzp_federation_packets_rate_limited_total", "Federation packets dropped by rate limiter"),
|
))
|
||||||
).expect("metric");
|
.expect("metric");
|
||||||
let federation_active_rooms = IntGauge::with_opts(
|
let federation_packets_rate_limited = IntCounter::with_opts(Opts::new(
|
||||||
Opts::new("wzp_federation_active_rooms", "Number of federated rooms currently active"),
|
"wzp_federation_packets_rate_limited_total",
|
||||||
).expect("metric");
|
"Federation packets dropped by rate limiter",
|
||||||
|
))
|
||||||
|
.expect("metric");
|
||||||
|
let federation_active_rooms = IntGauge::with_opts(Opts::new(
|
||||||
|
"wzp_federation_active_rooms",
|
||||||
|
"Number of federated rooms currently active",
|
||||||
|
))
|
||||||
|
.expect("metric");
|
||||||
|
|
||||||
let session_buffer_depth = IntGaugeVec::new(
|
let session_buffer_depth = IntGaugeVec::new(
|
||||||
Opts::new(
|
Opts::new(
|
||||||
@@ -109,10 +137,7 @@ impl RelayMetrics {
|
|||||||
)
|
)
|
||||||
.expect("metric");
|
.expect("metric");
|
||||||
let session_rtt_ms = GaugeVec::new(
|
let session_rtt_ms = GaugeVec::new(
|
||||||
Opts::new(
|
Opts::new("wzp_relay_session_rtt_ms", "Round-trip time per session"),
|
||||||
"wzp_relay_session_rtt_ms",
|
|
||||||
"Round-trip time per session",
|
|
||||||
),
|
|
||||||
&["session_id"],
|
&["session_id"],
|
||||||
)
|
)
|
||||||
.expect("metric");
|
.expect("metric");
|
||||||
@@ -149,26 +174,104 @@ impl RelayMetrics {
|
|||||||
&["session_id"],
|
&["session_id"],
|
||||||
)
|
)
|
||||||
.expect("metric");
|
.expect("metric");
|
||||||
|
let conformance_violations = IntCounterVec::new(
|
||||||
|
Opts::new(
|
||||||
|
"wzp_relay_conformance_violations_total",
|
||||||
|
"Conformance violations by tier, codec, media type and verdict",
|
||||||
|
),
|
||||||
|
&["tier", "codec_id", "media_type", "verdict"],
|
||||||
|
)
|
||||||
|
.expect("metric");
|
||||||
|
let conformance_bytes = HistogramVec::new(
|
||||||
|
HistogramOpts::new(
|
||||||
|
"wzp_relay_conformance_bytes_per_session",
|
||||||
|
"Packet size distribution observed by the conformance meter",
|
||||||
|
)
|
||||||
|
.buckets(vec![
|
||||||
|
16.0, 32.0, 64.0, 128.0, 256.0, 512.0, 1024.0, 2048.0, 4096.0, 8192.0, 16384.0,
|
||||||
|
32768.0, 65536.0,
|
||||||
|
]),
|
||||||
|
&["media_type"],
|
||||||
|
)
|
||||||
|
.expect("metric");
|
||||||
|
let conformance_iat_ms = HistogramVec::new(
|
||||||
|
HistogramOpts::new(
|
||||||
|
"wzp_relay_conformance_iat_ms",
|
||||||
|
"Inter-arrival time distribution in milliseconds",
|
||||||
|
)
|
||||||
|
.buckets(vec![
|
||||||
|
1.0, 5.0, 10.0, 20.0, 30.0, 40.0, 60.0, 80.0, 100.0, 150.0, 200.0, 300.0, 500.0,
|
||||||
|
]),
|
||||||
|
&["media_type"],
|
||||||
|
)
|
||||||
|
.expect("metric");
|
||||||
|
|
||||||
registry.register(Box::new(active_sessions.clone())).expect("register");
|
registry
|
||||||
registry.register(Box::new(active_rooms.clone())).expect("register");
|
.register(Box::new(active_sessions.clone()))
|
||||||
registry.register(Box::new(packets_forwarded.clone())).expect("register");
|
.expect("register");
|
||||||
registry.register(Box::new(bytes_forwarded.clone())).expect("register");
|
registry
|
||||||
registry.register(Box::new(auth_attempts.clone())).expect("register");
|
.register(Box::new(active_rooms.clone()))
|
||||||
registry.register(Box::new(handshake_duration.clone())).expect("register");
|
.expect("register");
|
||||||
registry.register(Box::new(federation_peer_status.clone())).expect("register");
|
registry
|
||||||
registry.register(Box::new(federation_peer_rtt_ms.clone())).expect("register");
|
.register(Box::new(packets_forwarded.clone()))
|
||||||
registry.register(Box::new(federation_packets_forwarded.clone())).expect("register");
|
.expect("register");
|
||||||
registry.register(Box::new(federation_packets_deduped.clone())).expect("register");
|
registry
|
||||||
registry.register(Box::new(federation_packets_rate_limited.clone())).expect("register");
|
.register(Box::new(bytes_forwarded.clone()))
|
||||||
registry.register(Box::new(federation_active_rooms.clone())).expect("register");
|
.expect("register");
|
||||||
registry.register(Box::new(session_buffer_depth.clone())).expect("register");
|
registry
|
||||||
registry.register(Box::new(session_loss_pct.clone())).expect("register");
|
.register(Box::new(auth_attempts.clone()))
|
||||||
registry.register(Box::new(session_rtt_ms.clone())).expect("register");
|
.expect("register");
|
||||||
registry.register(Box::new(session_underruns.clone())).expect("register");
|
registry
|
||||||
registry.register(Box::new(session_overruns.clone())).expect("register");
|
.register(Box::new(handshake_duration.clone()))
|
||||||
registry.register(Box::new(session_dred_reconstructions.clone())).expect("register");
|
.expect("register");
|
||||||
registry.register(Box::new(session_classical_plc.clone())).expect("register");
|
registry
|
||||||
|
.register(Box::new(federation_peer_status.clone()))
|
||||||
|
.expect("register");
|
||||||
|
registry
|
||||||
|
.register(Box::new(federation_peer_rtt_ms.clone()))
|
||||||
|
.expect("register");
|
||||||
|
registry
|
||||||
|
.register(Box::new(federation_packets_forwarded.clone()))
|
||||||
|
.expect("register");
|
||||||
|
registry
|
||||||
|
.register(Box::new(federation_packets_deduped.clone()))
|
||||||
|
.expect("register");
|
||||||
|
registry
|
||||||
|
.register(Box::new(federation_packets_rate_limited.clone()))
|
||||||
|
.expect("register");
|
||||||
|
registry
|
||||||
|
.register(Box::new(federation_active_rooms.clone()))
|
||||||
|
.expect("register");
|
||||||
|
registry
|
||||||
|
.register(Box::new(session_buffer_depth.clone()))
|
||||||
|
.expect("register");
|
||||||
|
registry
|
||||||
|
.register(Box::new(session_loss_pct.clone()))
|
||||||
|
.expect("register");
|
||||||
|
registry
|
||||||
|
.register(Box::new(session_rtt_ms.clone()))
|
||||||
|
.expect("register");
|
||||||
|
registry
|
||||||
|
.register(Box::new(session_underruns.clone()))
|
||||||
|
.expect("register");
|
||||||
|
registry
|
||||||
|
.register(Box::new(session_overruns.clone()))
|
||||||
|
.expect("register");
|
||||||
|
registry
|
||||||
|
.register(Box::new(session_dred_reconstructions.clone()))
|
||||||
|
.expect("register");
|
||||||
|
registry
|
||||||
|
.register(Box::new(session_classical_plc.clone()))
|
||||||
|
.expect("register");
|
||||||
|
registry
|
||||||
|
.register(Box::new(conformance_violations.clone()))
|
||||||
|
.expect("register");
|
||||||
|
registry
|
||||||
|
.register(Box::new(conformance_bytes.clone()))
|
||||||
|
.expect("register");
|
||||||
|
registry
|
||||||
|
.register(Box::new(conformance_iat_ms.clone()))
|
||||||
|
.expect("register");
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
active_sessions,
|
active_sessions,
|
||||||
@@ -190,6 +293,9 @@ impl RelayMetrics {
|
|||||||
session_overruns,
|
session_overruns,
|
||||||
session_dred_reconstructions,
|
session_dred_reconstructions,
|
||||||
session_classical_plc,
|
session_classical_plc,
|
||||||
|
conformance_violations,
|
||||||
|
conformance_bytes,
|
||||||
|
conformance_iat_ms,
|
||||||
registry,
|
registry,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -230,10 +336,7 @@ impl RelayMetrics {
|
|||||||
.with_label_values(&[session_id])
|
.with_label_values(&[session_id])
|
||||||
.inc_by(underruns - cur_underruns as u64);
|
.inc_by(underruns - cur_underruns as u64);
|
||||||
}
|
}
|
||||||
let cur_overruns = self
|
let cur_overruns = self.session_overruns.with_label_values(&[session_id]).get();
|
||||||
.session_overruns
|
|
||||||
.with_label_values(&[session_id])
|
|
||||||
.get();
|
|
||||||
if overruns > cur_overruns as u64 {
|
if overruns > cur_overruns as u64 {
|
||||||
self.session_overruns
|
self.session_overruns
|
||||||
.with_label_values(&[session_id])
|
.with_label_values(&[session_id])
|
||||||
@@ -274,6 +377,45 @@ impl RelayMetrics {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Record conformance-related metrics for a single received packet.
|
||||||
|
///
|
||||||
|
/// * `header` — the media header (provides codec_id and media_type).
|
||||||
|
/// * `payload_len` — payload length in bytes.
|
||||||
|
/// * `iat_ms` — inter-arrival time since the previous packet.
|
||||||
|
/// * `violation` — `Some(Violation)` if the packet triggered a conformance
|
||||||
|
/// limit; `None` for clean packets.
|
||||||
|
pub fn record_conformance(
|
||||||
|
&self,
|
||||||
|
header: &MediaHeader,
|
||||||
|
payload_len: usize,
|
||||||
|
iat_ms: u64,
|
||||||
|
violation: Option<Violation>,
|
||||||
|
) {
|
||||||
|
let media_type = format!("{:?}", header.media_type);
|
||||||
|
let bytes = (MediaHeader::WIRE_SIZE + payload_len) as f64;
|
||||||
|
self.conformance_bytes
|
||||||
|
.with_label_values(&[&media_type])
|
||||||
|
.observe(bytes);
|
||||||
|
self.conformance_iat_ms
|
||||||
|
.with_label_values(&[&media_type])
|
||||||
|
.observe(iat_ms as f64);
|
||||||
|
|
||||||
|
if let Some(v) = violation {
|
||||||
|
let tier = match v {
|
||||||
|
Violation::BitrateExceeded => "A",
|
||||||
|
Violation::PacketRateExceeded => "B",
|
||||||
|
Violation::TimestampDrift => "C",
|
||||||
|
Violation::PayloadSizeExceeded => "D",
|
||||||
|
Violation::RateCapExceeded => "E",
|
||||||
|
};
|
||||||
|
let codec_id = format!("{:?}", header.codec_id);
|
||||||
|
let verdict = format!("{:?}", v);
|
||||||
|
self.conformance_violations
|
||||||
|
.with_label_values(&[tier, &codec_id, &media_type, &verdict])
|
||||||
|
.inc();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Remove all per-session label values for a disconnected session.
|
/// Remove all per-session label values for a disconnected session.
|
||||||
pub fn remove_session_metrics(&self, session_id: &str) {
|
pub fn remove_session_metrics(&self, session_id: &str) {
|
||||||
let _ = self.session_buffer_depth.remove_label_values(&[session_id]);
|
let _ = self.session_buffer_depth.remove_label_values(&[session_id]);
|
||||||
@@ -284,7 +426,9 @@ impl RelayMetrics {
|
|||||||
let _ = self
|
let _ = self
|
||||||
.session_dred_reconstructions
|
.session_dred_reconstructions
|
||||||
.remove_label_values(&[session_id]);
|
.remove_label_values(&[session_id]);
|
||||||
let _ = self.session_classical_plc.remove_label_values(&[session_id]);
|
let _ = self
|
||||||
|
.session_classical_plc
|
||||||
|
.remove_label_values(&[session_id]);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get a reference to the underlying Prometheus registry.
|
/// Get a reference to the underlying Prometheus registry.
|
||||||
@@ -298,7 +442,9 @@ impl RelayMetrics {
|
|||||||
let encoder = TextEncoder::new();
|
let encoder = TextEncoder::new();
|
||||||
let metric_families = self.registry.gather();
|
let metric_families = self.registry.gather();
|
||||||
let mut buffer = Vec::new();
|
let mut buffer = Vec::new();
|
||||||
encoder.encode(&metric_families, &mut buffer).expect("encode");
|
encoder
|
||||||
|
.encode(&metric_families, &mut buffer)
|
||||||
|
.expect("encode");
|
||||||
String::from_utf8(buffer).expect("utf8")
|
String::from_utf8(buffer).expect("utf8")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -310,7 +456,7 @@ pub async fn serve_metrics(
|
|||||||
presence: Option<Arc<tokio::sync::Mutex<crate::presence::PresenceRegistry>>>,
|
presence: Option<Arc<tokio::sync::Mutex<crate::presence::PresenceRegistry>>>,
|
||||||
route_resolver: Option<Arc<crate::route::RouteResolver>>,
|
route_resolver: Option<Arc<crate::route::RouteResolver>>,
|
||||||
) {
|
) {
|
||||||
use axum::{extract::Path, routing::get, Router};
|
use axum::{Router, extract::Path, routing::get};
|
||||||
|
|
||||||
let metrics_clone = metrics.clone();
|
let metrics_clone = metrics.clone();
|
||||||
let presence_all = presence.clone();
|
let presence_all = presence.clone();
|
||||||
|
|||||||
@@ -11,11 +11,11 @@
|
|||||||
use tracing::{debug, info};
|
use tracing::{debug, info};
|
||||||
|
|
||||||
use wzp_fec::{RaptorQFecDecoder, RaptorQFecEncoder};
|
use wzp_fec::{RaptorQFecDecoder, RaptorQFecEncoder};
|
||||||
|
use wzp_proto::QualityProfile;
|
||||||
use wzp_proto::jitter::{JitterBuffer, PlayoutResult};
|
use wzp_proto::jitter::{JitterBuffer, PlayoutResult};
|
||||||
use wzp_proto::packet::{MediaHeader, MediaPacket};
|
use wzp_proto::packet::{MediaHeader, MediaPacket};
|
||||||
use wzp_proto::quality::AdaptiveQualityController;
|
use wzp_proto::quality::AdaptiveQualityController;
|
||||||
use wzp_proto::traits::{FecDecoder, FecEncoder, QualityController};
|
use wzp_proto::traits::{FecDecoder, FecEncoder, QualityController};
|
||||||
use wzp_proto::QualityProfile;
|
|
||||||
|
|
||||||
/// Configuration for a relay pipeline instance.
|
/// Configuration for a relay pipeline instance.
|
||||||
pub struct PipelineConfig {
|
pub struct PipelineConfig {
|
||||||
@@ -51,7 +51,7 @@ pub struct RelayPipeline {
|
|||||||
/// Current quality profile.
|
/// Current quality profile.
|
||||||
profile: QualityProfile,
|
profile: QualityProfile,
|
||||||
/// Outbound sequence counter.
|
/// Outbound sequence counter.
|
||||||
out_seq: u16,
|
out_seq: u32,
|
||||||
/// Packets processed count.
|
/// Packets processed count.
|
||||||
stats: PipelineStats,
|
stats: PipelineStats,
|
||||||
}
|
}
|
||||||
@@ -111,8 +111,8 @@ impl RelayPipeline {
|
|||||||
let header = &packet.header;
|
let header = &packet.header;
|
||||||
let _ = self.fec_decoder.add_symbol(
|
let _ = self.fec_decoder.add_symbol(
|
||||||
header.fec_block,
|
header.fec_block,
|
||||||
header.fec_symbol,
|
header.fec_block >> 8,
|
||||||
header.is_repair,
|
header.is_repair(),
|
||||||
&packet.payload,
|
&packet.payload,
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -128,22 +128,21 @@ impl RelayPipeline {
|
|||||||
for (i, frame) in frames.into_iter().enumerate() {
|
for (i, frame) in frames.into_iter().enumerate() {
|
||||||
let reconstructed = MediaPacket {
|
let reconstructed = MediaPacket {
|
||||||
header: MediaHeader {
|
header: MediaHeader {
|
||||||
version: 0,
|
version: 2,
|
||||||
is_repair: false,
|
flags: 0,
|
||||||
|
media_type: wzp_proto::MediaType::Audio,
|
||||||
codec_id: header.codec_id,
|
codec_id: header.codec_id,
|
||||||
has_quality_report: false,
|
stream_id: 0,
|
||||||
fec_ratio_encoded: header.fec_ratio_encoded,
|
fec_ratio: header.fec_ratio,
|
||||||
// Reconstruct seq from block + symbol index
|
// Reconstruct seq from block + symbol index
|
||||||
seq: (header.fec_block as u16)
|
seq: (header.fec_block as u32)
|
||||||
.wrapping_mul(self.profile.frames_per_block as u16)
|
.wrapping_mul(self.profile.frames_per_block as u32)
|
||||||
.wrapping_add(i as u16),
|
.wrapping_add(i as u32),
|
||||||
timestamp: header
|
timestamp: header.timestamp.wrapping_add(
|
||||||
.timestamp
|
(i as u32) * (header.codec_id.frame_duration_ms() as u32),
|
||||||
.wrapping_add((i as u32) * (header.codec_id.frame_duration_ms() as u32)),
|
),
|
||||||
fec_block: header.fec_block,
|
fec_block: u16::from((header.fec_block & 0xFF) as u8)
|
||||||
fec_symbol: i as u8,
|
| (u16::from(i as u8) << 8),
|
||||||
reserved: 0,
|
|
||||||
csrc_count: 0,
|
|
||||||
},
|
},
|
||||||
payload: bytes::Bytes::from(frame),
|
payload: bytes::Bytes::from(frame),
|
||||||
quality_report: None,
|
quality_report: None,
|
||||||
@@ -191,19 +190,16 @@ impl RelayPipeline {
|
|||||||
for (sym_idx, repair_data) in repairs {
|
for (sym_idx, repair_data) in repairs {
|
||||||
let repair_packet = MediaPacket {
|
let repair_packet = MediaPacket {
|
||||||
header: MediaHeader {
|
header: MediaHeader {
|
||||||
version: 0,
|
version: 2,
|
||||||
is_repair: true,
|
flags: MediaHeader::FLAG_REPAIR,
|
||||||
|
media_type: wzp_proto::MediaType::Audio,
|
||||||
codec_id: packet.header.codec_id,
|
codec_id: packet.header.codec_id,
|
||||||
has_quality_report: false,
|
stream_id: 0,
|
||||||
fec_ratio_encoded: MediaHeader::encode_fec_ratio(
|
fec_ratio: MediaHeader::encode_fec_ratio(self.profile.fec_ratio),
|
||||||
self.profile.fec_ratio,
|
|
||||||
),
|
|
||||||
seq: self.out_seq,
|
seq: self.out_seq,
|
||||||
timestamp: packet.header.timestamp,
|
timestamp: packet.header.timestamp,
|
||||||
fec_block: self.fec_encoder.current_block_id(),
|
fec_block: u16::from(self.fec_encoder.current_block_id())
|
||||||
fec_symbol: sym_idx,
|
| (u16::from(sym_idx) << 8),
|
||||||
reserved: 0,
|
|
||||||
csrc_count: 0,
|
|
||||||
},
|
},
|
||||||
payload: bytes::Bytes::from(repair_data),
|
payload: bytes::Bytes::from(repair_data),
|
||||||
quality_report: None,
|
quality_report: None,
|
||||||
@@ -232,23 +228,21 @@ impl RelayPipeline {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use wzp_proto::CodecId;
|
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
|
use wzp_proto::CodecId;
|
||||||
|
|
||||||
fn make_media_packet(seq: u16, block: u8, symbol: u8) -> MediaPacket {
|
fn make_media_packet(seq: u32, block: u8, symbol: u8) -> MediaPacket {
|
||||||
MediaPacket {
|
MediaPacket {
|
||||||
header: MediaHeader {
|
header: MediaHeader {
|
||||||
version: 0,
|
version: 2,
|
||||||
is_repair: false,
|
flags: 0,
|
||||||
|
media_type: wzp_proto::MediaType::Audio,
|
||||||
codec_id: CodecId::Opus24k,
|
codec_id: CodecId::Opus24k,
|
||||||
has_quality_report: false,
|
stream_id: 0,
|
||||||
fec_ratio_encoded: 0,
|
fec_ratio: 0,
|
||||||
seq,
|
seq,
|
||||||
timestamp: seq as u32 * 20,
|
timestamp: seq * 20,
|
||||||
fec_block: block,
|
fec_block: u16::from(block) | (u16::from(symbol) << 8),
|
||||||
fec_symbol: symbol,
|
|
||||||
reserved: 0,
|
|
||||||
csrc_count: 0,
|
|
||||||
},
|
},
|
||||||
payload: Bytes::from(vec![seq as u8; 60]),
|
payload: Bytes::from(vec![seq as u8; 60]),
|
||||||
quality_report: None,
|
quality_report: None,
|
||||||
@@ -283,7 +277,7 @@ mod tests {
|
|||||||
|
|
||||||
// Feed 5 packets (one full block)
|
// Feed 5 packets (one full block)
|
||||||
let mut total_out = 0;
|
let mut total_out = 0;
|
||||||
for i in 0..5u16 {
|
for i in 0..5u32 {
|
||||||
let pkt = make_media_packet(i, 0, i as u8);
|
let pkt = make_media_packet(i, 0, i as u8);
|
||||||
let out = pipeline.prepare_outbound(pkt);
|
let out = pipeline.prepare_outbound(pkt);
|
||||||
total_out += out.len();
|
total_out += out.len();
|
||||||
|
|||||||
@@ -63,6 +63,12 @@ pub struct PresenceRegistry {
|
|||||||
peers: HashMap<SocketAddr, PeerRelay>,
|
peers: HashMap<SocketAddr, PeerRelay>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Default for PresenceRegistry {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl PresenceRegistry {
|
impl PresenceRegistry {
|
||||||
/// Create an empty registry.
|
/// Create an empty registry.
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
@@ -74,13 +80,21 @@ impl PresenceRegistry {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Register a fingerprint as locally connected (called after auth + handshake).
|
/// Register a fingerprint as locally connected (called after auth + handshake).
|
||||||
pub fn register_local(&mut self, fingerprint: &str, alias: Option<String>, room: Option<String>) {
|
pub fn register_local(
|
||||||
self.local.insert(fingerprint.to_string(), LocalPresence {
|
&mut self,
|
||||||
|
fingerprint: &str,
|
||||||
|
alias: Option<String>,
|
||||||
|
room: Option<String>,
|
||||||
|
) {
|
||||||
|
self.local.insert(
|
||||||
|
fingerprint.to_string(),
|
||||||
|
LocalPresence {
|
||||||
fingerprint: fingerprint.to_string(),
|
fingerprint: fingerprint.to_string(),
|
||||||
alias,
|
alias,
|
||||||
connected_at: Instant::now(),
|
connected_at: Instant::now(),
|
||||||
room,
|
room,
|
||||||
});
|
},
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Unregister a locally connected fingerprint (called on disconnect).
|
/// Unregister a locally connected fingerprint (called on disconnect).
|
||||||
@@ -98,11 +112,14 @@ impl PresenceRegistry {
|
|||||||
|
|
||||||
// Insert new remote entries
|
// Insert new remote entries
|
||||||
for fp in &fingerprints {
|
for fp in &fingerprints {
|
||||||
self.remote.insert(fp.clone(), RemotePresence {
|
self.remote.insert(
|
||||||
|
fp.clone(),
|
||||||
|
RemotePresence {
|
||||||
fingerprint: fp.clone(),
|
fingerprint: fp.clone(),
|
||||||
relay_addr: addr,
|
relay_addr: addr,
|
||||||
last_seen: now,
|
last_seen: now,
|
||||||
});
|
},
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update the peer record
|
// Update the peer record
|
||||||
@@ -156,7 +173,8 @@ impl PresenceRegistry {
|
|||||||
self.remote.retain(|_, rp| rp.last_seen > cutoff);
|
self.remote.retain(|_, rp| rp.last_seen > cutoff);
|
||||||
|
|
||||||
// Expire peer relay records and their fingerprint sets
|
// Expire peer relay records and their fingerprint sets
|
||||||
let stale_peers: Vec<SocketAddr> = self.peers
|
let stale_peers: Vec<SocketAddr> = self
|
||||||
|
.peers
|
||||||
.iter()
|
.iter()
|
||||||
.filter(|(_, p)| p.last_update <= cutoff)
|
.filter(|(_, p)| p.last_update <= cutoff)
|
||||||
.map(|(addr, _)| *addr)
|
.map(|(addr, _)| *addr)
|
||||||
@@ -280,13 +298,15 @@ mod tests {
|
|||||||
let all = reg.all_known();
|
let all = reg.all_known();
|
||||||
assert_eq!(all.len(), 2);
|
assert_eq!(all.len(), 2);
|
||||||
|
|
||||||
let local_entries: Vec<_> = all.iter()
|
let local_entries: Vec<_> = all
|
||||||
|
.iter()
|
||||||
.filter(|(_, loc)| *loc == PresenceLocation::Local)
|
.filter(|(_, loc)| *loc == PresenceLocation::Local)
|
||||||
.collect();
|
.collect();
|
||||||
assert_eq!(local_entries.len(), 1);
|
assert_eq!(local_entries.len(), 1);
|
||||||
assert_eq!(local_entries[0].0, "local1");
|
assert_eq!(local_entries[0].0, "local1");
|
||||||
|
|
||||||
let remote_entries: Vec<_> = all.iter()
|
let remote_entries: Vec<_> = all
|
||||||
|
.iter()
|
||||||
.filter(|(_, loc)| matches!(loc, PresenceLocation::Remote(_)))
|
.filter(|(_, loc)| matches!(loc, PresenceLocation::Remote(_)))
|
||||||
.collect();
|
.collect();
|
||||||
assert_eq!(remote_entries.len(), 1);
|
assert_eq!(remote_entries.len(), 1);
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ use prometheus::{Gauge, IntGauge, Opts, Registry};
|
|||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
use tracing::{error, info, warn};
|
use tracing::{error, info, warn};
|
||||||
|
|
||||||
use wzp_proto::{MediaTransport, SignalMessage};
|
use wzp_proto::{MediaTransport, SignalMessage, default_signal_version};
|
||||||
|
|
||||||
/// Configuration for a single probe target.
|
/// Configuration for a single probe target.
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
@@ -43,8 +43,7 @@ impl ProbeMetrics {
|
|||||||
/// Register probe metrics with the given `target` label value.
|
/// Register probe metrics with the given `target` label value.
|
||||||
pub fn register(target: &str, registry: &Registry) -> Self {
|
pub fn register(target: &str, registry: &Registry) -> Self {
|
||||||
let rtt_ms = Gauge::with_opts(
|
let rtt_ms = Gauge::with_opts(
|
||||||
Opts::new("wzp_probe_rtt_ms", "RTT to peer relay in ms")
|
Opts::new("wzp_probe_rtt_ms", "RTT to peer relay in ms").const_label("target", target),
|
||||||
.const_label("target", target),
|
|
||||||
)
|
)
|
||||||
.expect("probe metric");
|
.expect("probe metric");
|
||||||
|
|
||||||
@@ -66,9 +65,15 @@ impl ProbeMetrics {
|
|||||||
)
|
)
|
||||||
.expect("probe metric");
|
.expect("probe metric");
|
||||||
|
|
||||||
registry.register(Box::new(rtt_ms.clone())).expect("register");
|
registry
|
||||||
registry.register(Box::new(loss_pct.clone())).expect("register");
|
.register(Box::new(rtt_ms.clone()))
|
||||||
registry.register(Box::new(jitter_ms.clone())).expect("register");
|
.expect("register");
|
||||||
|
registry
|
||||||
|
.register(Box::new(loss_pct.clone()))
|
||||||
|
.expect("register");
|
||||||
|
registry
|
||||||
|
.register(Box::new(jitter_ms.clone()))
|
||||||
|
.expect("register");
|
||||||
registry.register(Box::new(up.clone())).expect("register");
|
registry.register(Box::new(up.clone())).expect("register");
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
@@ -168,7 +173,11 @@ impl ProbeRunner {
|
|||||||
) -> Self {
|
) -> Self {
|
||||||
let target_str = config.target.to_string();
|
let target_str = config.target.to_string();
|
||||||
let metrics = ProbeMetrics::register(&target_str, registry);
|
let metrics = ProbeMetrics::register(&target_str, registry);
|
||||||
Self { config, metrics, presence }
|
Self {
|
||||||
|
config,
|
||||||
|
metrics,
|
||||||
|
presence,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Run the probe forever. This function never returns under normal operation.
|
/// Run the probe forever. This function never returns under normal operation.
|
||||||
@@ -198,13 +207,8 @@ impl ProbeRunner {
|
|||||||
let bind_addr: SocketAddr = "0.0.0.0:0".parse().unwrap();
|
let bind_addr: SocketAddr = "0.0.0.0:0".parse().unwrap();
|
||||||
let endpoint = wzp_transport::create_endpoint(bind_addr, None)?;
|
let endpoint = wzp_transport::create_endpoint(bind_addr, None)?;
|
||||||
let client_cfg = wzp_transport::client_config();
|
let client_cfg = wzp_transport::client_config();
|
||||||
let conn = wzp_transport::connect(
|
let conn =
|
||||||
&endpoint,
|
wzp_transport::connect(&endpoint, self.config.target, "_probe", client_cfg).await?;
|
||||||
self.config.target,
|
|
||||||
"_probe",
|
|
||||||
client_cfg,
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let transport = Arc::new(wzp_transport::QuinnTransport::new(conn));
|
let transport = Arc::new(wzp_transport::QuinnTransport::new(conn));
|
||||||
self.metrics.up.set(1);
|
self.metrics.up.set(1);
|
||||||
@@ -225,7 +229,7 @@ impl ProbeRunner {
|
|||||||
let recv_handle = tokio::spawn(async move {
|
let recv_handle = tokio::spawn(async move {
|
||||||
loop {
|
loop {
|
||||||
match recv_transport.recv_signal().await {
|
match recv_transport.recv_signal().await {
|
||||||
Ok(Some(SignalMessage::Pong { timestamp_ms })) => {
|
Ok(Some(SignalMessage::Pong { timestamp_ms, .. })) => {
|
||||||
let now_ms = SystemTime::now()
|
let now_ms = SystemTime::now()
|
||||||
.duration_since(UNIX_EPOCH)
|
.duration_since(UNIX_EPOCH)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
@@ -237,11 +241,16 @@ impl ProbeRunner {
|
|||||||
loss_gauge.set(w.loss_pct());
|
loss_gauge.set(w.loss_pct());
|
||||||
jitter_gauge.set(w.jitter_ms());
|
jitter_gauge.set(w.jitter_ms());
|
||||||
}
|
}
|
||||||
Ok(Some(SignalMessage::PresenceUpdate { fingerprints, relay_addr })) => {
|
Ok(Some(SignalMessage::PresenceUpdate {
|
||||||
|
fingerprints,
|
||||||
|
relay_addr,
|
||||||
|
..
|
||||||
|
})) => {
|
||||||
if let Some(ref reg) = recv_presence {
|
if let Some(ref reg) = recv_presence {
|
||||||
// Parse the relay_addr; fall back to the connection target
|
// Parse the relay_addr; fall back to the connection target
|
||||||
let addr = relay_addr.parse().unwrap_or(recv_target);
|
let addr = relay_addr.parse().unwrap_or(recv_target);
|
||||||
let fps: std::collections::HashSet<String> = fingerprints.into_iter().collect();
|
let fps: std::collections::HashSet<String> =
|
||||||
|
fingerprints.into_iter().collect();
|
||||||
let mut r = reg.lock().await;
|
let mut r = reg.lock().await;
|
||||||
r.update_peer(addr, fps);
|
r.update_peer(addr, fps);
|
||||||
}
|
}
|
||||||
@@ -285,7 +294,10 @@ impl ProbeRunner {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if let Err(e) = transport
|
if let Err(e) = transport
|
||||||
.send_signal(&SignalMessage::Ping { timestamp_ms })
|
.send_signal(&SignalMessage::Ping {
|
||||||
|
version: default_signal_version(),
|
||||||
|
timestamp_ms,
|
||||||
|
})
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
error!(target = %self.config.target, "probe ping send error: {e}");
|
error!(target = %self.config.target, "probe ping send error: {e}");
|
||||||
@@ -302,6 +314,7 @@ impl ProbeRunner {
|
|||||||
r.local_fingerprints().into_iter().collect()
|
r.local_fingerprints().into_iter().collect()
|
||||||
};
|
};
|
||||||
let msg = SignalMessage::PresenceUpdate {
|
let msg = SignalMessage::PresenceUpdate {
|
||||||
|
version: default_signal_version(),
|
||||||
fingerprints: fps,
|
fingerprints: fps,
|
||||||
relay_addr: self.config.target.to_string(),
|
relay_addr: self.config.target.to_string(),
|
||||||
};
|
};
|
||||||
@@ -374,10 +387,7 @@ pub fn mesh_summary(registry: &Registry) -> String {
|
|||||||
let name = family.get_name();
|
let name = family.get_name();
|
||||||
for metric in family.get_metric() {
|
for metric in family.get_metric() {
|
||||||
// Find the "target" label
|
// Find the "target" label
|
||||||
let target_label = metric
|
let target_label = metric.get_label().iter().find(|l| l.get_name() == "target");
|
||||||
.get_label()
|
|
||||||
.iter()
|
|
||||||
.find(|l| l.get_name() == "target");
|
|
||||||
let target = match target_label {
|
let target = match target_label {
|
||||||
Some(l) => l.get_value().to_string(),
|
Some(l) => l.get_value().to_string(),
|
||||||
None => continue,
|
None => continue,
|
||||||
@@ -420,13 +430,11 @@ pub fn mesh_summary(registry: &Registry) -> String {
|
|||||||
|
|
||||||
/// Handle an incoming Ping signal by replying with a Pong carrying the same timestamp.
|
/// Handle an incoming Ping signal by replying with a Pong carrying the same timestamp.
|
||||||
/// Returns true if the message was a Ping and was handled, false otherwise.
|
/// Returns true if the message was a Ping and was handled, false otherwise.
|
||||||
pub async fn handle_ping(
|
pub async fn handle_ping(transport: &wzp_transport::QuinnTransport, msg: &SignalMessage) -> bool {
|
||||||
transport: &wzp_transport::QuinnTransport,
|
if let SignalMessage::Ping { timestamp_ms, .. } = msg {
|
||||||
msg: &SignalMessage,
|
|
||||||
) -> bool {
|
|
||||||
if let SignalMessage::Ping { timestamp_ms } = msg {
|
|
||||||
if let Err(e) = transport
|
if let Err(e) = transport
|
||||||
.send_signal(&SignalMessage::Pong {
|
.send_signal(&SignalMessage::Pong {
|
||||||
|
version: default_signal_version(),
|
||||||
timestamp_ms: *timestamp_ms,
|
timestamp_ms: *timestamp_ms,
|
||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
@@ -456,9 +464,18 @@ mod tests {
|
|||||||
encoder.encode(&families, &mut buf).unwrap();
|
encoder.encode(&families, &mut buf).unwrap();
|
||||||
let output = String::from_utf8(buf).unwrap();
|
let output = String::from_utf8(buf).unwrap();
|
||||||
|
|
||||||
assert!(output.contains("wzp_probe_rtt_ms"), "missing wzp_probe_rtt_ms");
|
assert!(
|
||||||
assert!(output.contains("wzp_probe_loss_pct"), "missing wzp_probe_loss_pct");
|
output.contains("wzp_probe_rtt_ms"),
|
||||||
assert!(output.contains("wzp_probe_jitter_ms"), "missing wzp_probe_jitter_ms");
|
"missing wzp_probe_rtt_ms"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
output.contains("wzp_probe_loss_pct"),
|
||||||
|
"missing wzp_probe_loss_pct"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
output.contains("wzp_probe_jitter_ms"),
|
||||||
|
"missing wzp_probe_jitter_ms"
|
||||||
|
);
|
||||||
assert!(output.contains("wzp_probe_up"), "missing wzp_probe_up");
|
assert!(output.contains("wzp_probe_up"), "missing wzp_probe_up");
|
||||||
assert!(
|
assert!(
|
||||||
output.contains("target=\"127.0.0.1:4433\""),
|
output.contains("target=\"127.0.0.1:4433\""),
|
||||||
|
|||||||
@@ -40,10 +40,7 @@ impl RelayLink {
|
|||||||
/// should skip normal client auth/handshake for relay-SNI connections.
|
/// should skip normal client auth/handshake for relay-SNI connections.
|
||||||
pub async fn connect(target: SocketAddr) -> Result<Self, anyhow::Error> {
|
pub async fn connect(target: SocketAddr) -> Result<Self, anyhow::Error> {
|
||||||
// Create a client-only endpoint on an OS-assigned port.
|
// Create a client-only endpoint on an OS-assigned port.
|
||||||
let endpoint = wzp_transport::create_endpoint(
|
let endpoint = wzp_transport::create_endpoint("0.0.0.0:0".parse().unwrap(), None)?;
|
||||||
"0.0.0.0:0".parse().unwrap(),
|
|
||||||
None,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let client_cfg = wzp_transport::client_config();
|
let client_cfg = wzp_transport::client_config();
|
||||||
let conn = wzp_transport::connect(&endpoint, target, "_relay", client_cfg).await?;
|
let conn = wzp_transport::connect(&endpoint, target, "_relay", client_cfg).await?;
|
||||||
@@ -336,10 +333,11 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn session_forward_signal_roundtrip() {
|
fn session_forward_signal_roundtrip() {
|
||||||
use wzp_proto::SignalMessage;
|
use wzp_proto::{SignalMessage, default_signal_version};
|
||||||
|
|
||||||
// SessionForward roundtrip
|
// SessionForward roundtrip
|
||||||
let msg = SignalMessage::SessionForward {
|
let msg = SignalMessage::SessionForward {
|
||||||
|
version: default_signal_version(),
|
||||||
session_id: "abcd1234".to_string(),
|
session_id: "abcd1234".to_string(),
|
||||||
target_fingerprint: "deadbeef".to_string(),
|
target_fingerprint: "deadbeef".to_string(),
|
||||||
source_relay: "10.0.0.1:4433".to_string(),
|
source_relay: "10.0.0.1:4433".to_string(),
|
||||||
@@ -351,6 +349,7 @@ mod tests {
|
|||||||
session_id,
|
session_id,
|
||||||
target_fingerprint,
|
target_fingerprint,
|
||||||
source_relay,
|
source_relay,
|
||||||
|
..
|
||||||
} => {
|
} => {
|
||||||
assert_eq!(session_id, "abcd1234");
|
assert_eq!(session_id, "abcd1234");
|
||||||
assert_eq!(target_fingerprint, "deadbeef");
|
assert_eq!(target_fingerprint, "deadbeef");
|
||||||
@@ -361,6 +360,7 @@ mod tests {
|
|||||||
|
|
||||||
// SessionForwardAck roundtrip
|
// SessionForwardAck roundtrip
|
||||||
let ack = SignalMessage::SessionForwardAck {
|
let ack = SignalMessage::SessionForwardAck {
|
||||||
|
version: default_signal_version(),
|
||||||
session_id: "abcd1234".to_string(),
|
session_id: "abcd1234".to_string(),
|
||||||
room_name: "relay-room-42".to_string(),
|
room_name: "relay-room-42".to_string(),
|
||||||
};
|
};
|
||||||
@@ -370,6 +370,7 @@ mod tests {
|
|||||||
SignalMessage::SessionForwardAck {
|
SignalMessage::SessionForwardAck {
|
||||||
session_id,
|
session_id,
|
||||||
room_name,
|
room_name,
|
||||||
|
..
|
||||||
} => {
|
} => {
|
||||||
assert_eq!(session_id, "abcd1234");
|
assert_eq!(session_id, "abcd1234");
|
||||||
assert_eq!(room_name, "relay-room-42");
|
assert_eq!(room_name, "relay-room-42");
|
||||||
@@ -457,17 +458,15 @@ mod tests {
|
|||||||
|
|
||||||
let pkt = MediaPacket {
|
let pkt = MediaPacket {
|
||||||
header: wzp_proto::packet::MediaHeader {
|
header: wzp_proto::packet::MediaHeader {
|
||||||
version: 0,
|
version: 2,
|
||||||
is_repair: false,
|
flags: 0,
|
||||||
|
media_type: wzp_proto::MediaType::Audio,
|
||||||
codec_id: wzp_proto::CodecId::Opus16k,
|
codec_id: wzp_proto::CodecId::Opus16k,
|
||||||
has_quality_report: false,
|
stream_id: 0,
|
||||||
fec_ratio_encoded: 0,
|
fec_ratio: 0,
|
||||||
seq: 1,
|
seq: 1,
|
||||||
timestamp: 100,
|
timestamp: 100,
|
||||||
fec_block: 0,
|
fec_block: 0,
|
||||||
fec_symbol: 0,
|
|
||||||
reserved: 0,
|
|
||||||
csrc_count: 0,
|
|
||||||
},
|
},
|
||||||
payload: bytes::Bytes::from_static(b"test"),
|
payload: bytes::Bytes::from_static(b"test"),
|
||||||
quality_report: None,
|
quality_report: None,
|
||||||
|
|||||||
207
crates/wzp-relay/src/response_policy.rs
Normal file
207
crates/wzp-relay/src/response_policy.rs
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
//! Tier G response policy — maps conformance verdicts to enforcement actions.
|
||||||
|
//!
|
||||||
|
//! Actions:
|
||||||
|
//! - `Legitimate` → no action
|
||||||
|
//! - `Suspect` → tighten Tier E quota, emit metric
|
||||||
|
//! - `Abusive` → typed Hangup + 1 h fingerprint cool-down
|
||||||
|
//! - `RepeatAbusive` → relay-local block 24 h
|
||||||
|
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
|
use wzp_proto::packet::{HangupReason, ViolationCode};
|
||||||
|
|
||||||
|
use crate::verdict::Verdict;
|
||||||
|
|
||||||
|
/// Enforcement action recommended by the response policy.
|
||||||
|
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||||
|
pub enum Action {
|
||||||
|
/// Pass through unchanged.
|
||||||
|
Allow,
|
||||||
|
/// Throttle to tighter quota (Tier E).
|
||||||
|
Throttle,
|
||||||
|
/// Close the session with a typed Hangup signal.
|
||||||
|
Close { reason: HangupReason },
|
||||||
|
/// Block the fingerprint from joining any room for 24 h.
|
||||||
|
Block,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Tracks fingerprint-level abuse history and applies escalation.
|
||||||
|
pub struct ResponsePolicy {
|
||||||
|
/// `(fingerprint, violation_code)` → last abusive instant.
|
||||||
|
cooldowns: HashMap<(String, ViolationCode), Instant>,
|
||||||
|
/// Block duration for repeat abuse.
|
||||||
|
block_duration: Duration,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ResponsePolicy {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
cooldowns: HashMap::new(),
|
||||||
|
block_duration: Duration::from_secs(86400), // 24 h
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Evaluate a verdict and produce the corresponding [`Action`].
|
||||||
|
///
|
||||||
|
/// `fingerprint` is the participant's identity string (or IP as fallback).
|
||||||
|
/// `code` is the specific violation type that triggered the verdict.
|
||||||
|
pub fn evaluate(&mut self, fingerprint: &str, code: ViolationCode, verdict: Verdict) -> Action {
|
||||||
|
match verdict {
|
||||||
|
Verdict::Legitimate => Action::Allow,
|
||||||
|
Verdict::Suspect => Action::Throttle,
|
||||||
|
Verdict::Abusive => {
|
||||||
|
let key = (fingerprint.to_string(), code);
|
||||||
|
let now = Instant::now();
|
||||||
|
|
||||||
|
// Check if this fingerprint was already abusive recently.
|
||||||
|
let is_repeat = self
|
||||||
|
.cooldowns
|
||||||
|
.get(&key)
|
||||||
|
.map(|last| now.duration_since(*last) < self.block_duration)
|
||||||
|
.unwrap_or(false);
|
||||||
|
|
||||||
|
if is_repeat {
|
||||||
|
Action::Block
|
||||||
|
} else {
|
||||||
|
self.cooldowns.insert(key, now);
|
||||||
|
Action::Close {
|
||||||
|
reason: HangupReason::PolicyViolation {
|
||||||
|
code,
|
||||||
|
reason: format!("Tier G enforcement: {code:?}"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns true if the fingerprint is currently blocked (repeat abuse).
|
||||||
|
pub fn is_blocked(&self, fingerprint: &str) -> bool {
|
||||||
|
let now = Instant::now();
|
||||||
|
self.cooldowns.iter().any(|((fp, _), last)| {
|
||||||
|
fp == fingerprint && now.duration_since(*last) < self.block_duration
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clean up expired cooldown entries.
|
||||||
|
pub fn prune(&mut self) {
|
||||||
|
let now = Instant::now();
|
||||||
|
self.cooldowns
|
||||||
|
.retain(|_, last| now.duration_since(*last) < self.block_duration);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Number of tracked cooldown entries.
|
||||||
|
pub fn len(&self) -> usize {
|
||||||
|
self.cooldowns.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_empty(&self) -> bool {
|
||||||
|
self.cooldowns.is_empty()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for ResponsePolicy {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn legitimate_allowed() {
|
||||||
|
let mut policy = ResponsePolicy::new();
|
||||||
|
assert_eq!(
|
||||||
|
policy.evaluate("alice", ViolationCode::Bitrate, Verdict::Legitimate),
|
||||||
|
Action::Allow
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn suspect_throttled() {
|
||||||
|
let mut policy = ResponsePolicy::new();
|
||||||
|
assert_eq!(
|
||||||
|
policy.evaluate("alice", ViolationCode::Entropy, Verdict::Suspect),
|
||||||
|
Action::Throttle
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn abusive_gets_close() {
|
||||||
|
let mut policy = ResponsePolicy::new();
|
||||||
|
let action = policy.evaluate("alice", ViolationCode::Bitrate, Verdict::Abusive);
|
||||||
|
assert!(
|
||||||
|
matches!(action, Action::Close { .. }),
|
||||||
|
"first-time abuse should close session"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn repeat_abusive_gets_block() {
|
||||||
|
let mut policy = ResponsePolicy::new();
|
||||||
|
// First abuse
|
||||||
|
let _ = policy.evaluate("alice", ViolationCode::Bitrate, Verdict::Abusive);
|
||||||
|
// Second abuse within window → block
|
||||||
|
let action = policy.evaluate("alice", ViolationCode::Bitrate, Verdict::Abusive);
|
||||||
|
assert_eq!(action, Action::Block, "repeat abuse should block");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn different_violation_codes_are_independent() {
|
||||||
|
let mut policy = ResponsePolicy::new();
|
||||||
|
// Abuse on bitrate
|
||||||
|
let _ = policy.evaluate("alice", ViolationCode::Bitrate, Verdict::Abusive);
|
||||||
|
// Abuse on entropy is treated as first-time for that code
|
||||||
|
let action = policy.evaluate("alice", ViolationCode::Entropy, Verdict::Abusive);
|
||||||
|
assert!(
|
||||||
|
matches!(action, Action::Close { .. }),
|
||||||
|
"different violation code should not trigger repeat"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn is_blocked_true_after_repeat() {
|
||||||
|
let mut policy = ResponsePolicy::new();
|
||||||
|
let _ = policy.evaluate("alice", ViolationCode::Bitrate, Verdict::Abusive);
|
||||||
|
let _ = policy.evaluate("alice", ViolationCode::Bitrate, Verdict::Abusive);
|
||||||
|
assert!(policy.is_blocked("alice"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn is_blocked_false_for_legitimate() {
|
||||||
|
let policy = ResponsePolicy::new();
|
||||||
|
assert!(!policy.is_blocked("alice"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn prune_removes_expired() {
|
||||||
|
let mut policy = ResponsePolicy::new();
|
||||||
|
let _ = policy.evaluate("alice", ViolationCode::Bitrate, Verdict::Abusive);
|
||||||
|
assert_eq!(policy.len(), 1);
|
||||||
|
// Manually expire by moving cooldown back
|
||||||
|
policy.cooldowns.insert(
|
||||||
|
("alice".to_string(), ViolationCode::Bitrate),
|
||||||
|
Instant::now() - Duration::from_secs(90000),
|
||||||
|
);
|
||||||
|
policy.prune();
|
||||||
|
assert!(policy.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn close_reason_contains_code() {
|
||||||
|
let mut policy = ResponsePolicy::new();
|
||||||
|
let action = policy.evaluate("alice", ViolationCode::Entropy, Verdict::Abusive);
|
||||||
|
match action {
|
||||||
|
Action::Close { reason } => match reason {
|
||||||
|
HangupReason::PolicyViolation { code, .. } => {
|
||||||
|
assert_eq!(code, ViolationCode::Entropy);
|
||||||
|
}
|
||||||
|
other => panic!("expected PolicyViolation, got {other:?}"),
|
||||||
|
},
|
||||||
|
other => panic!("expected Close, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user