diff --git a/Cargo.lock b/Cargo.lock index 27ae857..b9e2e30 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -291,12 +291,6 @@ dependencies = [ "tower-service", ] -[[package]] -name = "base16ct" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c7f02d4ea65f2c1853089ffd8d2787bdbc63de2f0d29dedbcf8ccdfa0ccd4cf" - [[package]] name = "base64" version = "0.22.1" @@ -467,7 +461,6 @@ dependencies = [ "iana-time-zone", "js-sys", "num-traits", - "serde", "wasm-bindgen", "windows-link", ] @@ -628,24 +621,6 @@ dependencies = [ "libc", ] -[[package]] -name = "crunchy" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" - -[[package]] -name = "crypto-bigint" -version = "0.5.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0dc92fb57ca44df6db8059111ab3af99a63d5d0f8375d9972e319a379c6bab76" -dependencies = [ - "generic-array", - "rand_core 0.6.4", - "subtle", - "zeroize", -] - [[package]] name = "crypto-common" version = "0.1.7" @@ -669,7 +644,6 @@ dependencies = [ "digest", "fiat-crypto", "rustc_version", - "serde", "subtle", "zeroize", ] @@ -836,7 +810,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", - "const-oid", "crypto-common", "subtle", ] @@ -871,21 +844,6 @@ dependencies = [ "rustfft", ] -[[package]] -name = "ecdsa" -version = "0.16.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee27f32b5c5292967d2d4a9d7f1e0b0aed2c15daded5a60300e4abb9d8020bca" -dependencies = [ - "der", - "digest", - "elliptic-curve", - "rfc6979", - "serdect", - "signature", - "spki", -] - [[package]] name = "ed25519" version = "2.2.3" @@ -893,7 +851,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "115531babc129696a58c64a4fef0a8bf9e9698629fb97e9e40767d235cfbcd53" dependencies = [ "pkcs8", - "serde", "signature", ] @@ -918,26 +875,6 @@ version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" -[[package]] -name = "elliptic-curve" -version = "0.13.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5e6043086bf7973472e0c7dff2142ea0b680d30e18d9cc40f267efbf222bd47" -dependencies = [ - "base16ct", - "crypto-bigint", - "digest", - "ff", - "generic-array", - "group", - "pkcs8", - "rand_core 0.6.4", - "sec1", - "serdect", - "subtle", - "zeroize", -] - [[package]] name = "encoding_rs" version = "0.8.35" @@ -981,16 +918,6 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" -[[package]] -name = "ff" -version = "0.13.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0b50bfb653653f9ca9095b427bed08ab8d75a137839d9ad64eb11810d5b6393" -dependencies = [ - "rand_core 0.6.4", - "subtle", -] - [[package]] name = "fiat-crypto" version = "0.2.9" @@ -1151,7 +1078,6 @@ checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" dependencies = [ "typenum", "version_check", - "zeroize", ] [[package]] @@ -1211,17 +1137,6 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" -[[package]] -name = "group" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0f9ef7462f7c099f518d754361858f86d8a07af53ba9af0fe635bbccb151a63" -dependencies = [ - "ff", - "rand_core 0.6.4", - "subtle", -] - [[package]] name = "h2" version = "0.4.13" @@ -1705,21 +1620,6 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "k256" -version = "0.13.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6e3919bbaa2945715f0bb6d3934a173d1e9a59ac23767fbaaef277265a7411b" -dependencies = [ - "cfg-if", - "ecdsa", - "elliptic-curve", - "once_cell", - "serdect", - "sha2", - "signature", -] - [[package]] name = "lazy_static" version = "1.5.0" @@ -2483,16 +2383,6 @@ dependencies = [ "web-sys", ] -[[package]] -name = "rfc6979" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8dd2a808d456c4a54e300a23e9f5a67e122c3024119acbfd73e3bf664491cb2" -dependencies = [ - "hmac", - "subtle", -] - [[package]] name = "ring" version = "0.17.14" @@ -2671,21 +2561,6 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" -[[package]] -name = "sec1" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3e97a565f76233a6003f9f5c54be1d9c5bdfa3eccfb189469f11ec4901c47dc" -dependencies = [ - "base16ct", - "der", - "generic-array", - "pkcs8", - "serdect", - "subtle", - "zeroize", -] - [[package]] name = "security-framework" version = "3.7.0" @@ -2790,16 +2665,6 @@ dependencies = [ "serde", ] -[[package]] -name = "serdect" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a84f14a19e9a014bb9f4512488d9829a68e04ecabffb0f9904cd1ace94598177" -dependencies = [ - "base16ct", - "serde", -] - [[package]] name = "sha1" version = "0.10.6" @@ -2853,7 +2718,6 @@ version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" dependencies = [ - "digest", "rand_core 0.6.4", ] @@ -3067,15 +2931,6 @@ version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7694e1cfe791f8d31026952abf09c69ca6f6fa4e1a1229e18988f06a04a12dca" -[[package]] -name = "tiny-keccak" -version = "2.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c9d3793400a45f954c52e73d068316d76b6f4e36977e3fcebb13a2721e80237" -dependencies = [ - "crunchy", -] - [[package]] name = "tinystr" version = "0.8.2" @@ -3495,18 +3350,6 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" -[[package]] -name = "uuid" -version = "1.23.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ac8b6f42ead25368cf5b098aeb3dc8a1a2c05a3eee8a9a1a68c640edbfc79d9" -dependencies = [ - "getrandom 0.4.2", - "js-sys", - "serde_core", - "wasm-bindgen", -] - [[package]] name = "valuable" version = "0.1.1" @@ -3546,28 +3389,7 @@ dependencies = [ [[package]] name = "warzone-protocol" -version = "0.0.38" -dependencies = [ - "base64", - "bincode", - "bip39", - "chacha20poly1305", - "chrono", - "curve25519-dalek", - "ed25519-dalek", - "hex", - "hkdf", - "k256", - "rand 0.8.5", - "serde", - "serde_json", - "sha2", - "thiserror 2.0.18", - "tiny-keccak", - "uuid", - "x25519-dalek", - "zeroize", -] +version = "0.1.0" [[package]] name = "wasi" @@ -4179,6 +4001,28 @@ version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9" +[[package]] +name = "wzp-android" +version = "0.1.0" +dependencies = [ + "anyhow", + "async-trait", + "bytes", + "cc", + "libc", + "serde", + "serde_json", + "thiserror 2.0.18", + "tokio", + "tracing", + "tracing-subscriber", + "wzp-codec", + "wzp-crypto", + "wzp-fec", + "wzp-proto", + "wzp-transport", +] + [[package]] name = "wzp-client" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 9c9d9f3..1daa196 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ members = [ "crates/wzp-relay", "crates/wzp-client", "crates/wzp-web", + "crates/wzp-android", ] [workspace.package] diff --git a/android/app/build.gradle.kts b/android/app/build.gradle.kts new file mode 100644 index 0000000..c06aaf8 --- /dev/null +++ b/android/app/build.gradle.kts @@ -0,0 +1,64 @@ +plugins { + id("com.android.application") + id("org.jetbrains.kotlin.android") +} + +android { + namespace = "com.wzp.phone" + compileSdk = 34 + + defaultConfig { + applicationId = "com.wzp.phone" + minSdk = 26 // AAudio requires API 26 + targetSdk = 34 + versionCode = 1 + versionName = "0.1.0" + ndk { abiFilters += listOf("arm64-v8a", "armeabi-v7a") } + } + + buildTypes { + release { + isMinifyEnabled = true + proguardFiles( + getDefaultProguardFile("proguard-android-optimize.txt"), + "proguard-rules.pro" + ) + } + } + + compileOptions { + sourceCompatibility = JavaVersion.VERSION_1_8 + targetCompatibility = JavaVersion.VERSION_1_8 + } + + kotlinOptions { + jvmTarget = "1.8" + } + + buildFeatures { compose = true } + composeOptions { kotlinCompilerExtensionVersion = "1.5.8" } + + ndkVersion = "26.1.10909125" +} + +// cargo-ndk integration: build the Rust native library for Android targets +tasks.register("cargoNdkBuild") { + workingDir = file("${project.rootDir}/..") + commandLine( + "cargo", "ndk", + "-t", "arm64-v8a", "-t", "armeabi-v7a", + "-o", "${project.projectDir}/src/main/jniLibs", + "build", "--release", "-p", "wzp-android" + ) +} + +tasks.named("preBuild") { dependsOn("cargoNdkBuild") } + +dependencies { + implementation("androidx.core:core-ktx:1.12.0") + implementation("androidx.lifecycle:lifecycle-runtime-ktx:2.7.0") + implementation("androidx.activity:activity-compose:1.8.2") + implementation(platform("androidx.compose:compose-bom:2024.01.00")) + implementation("androidx.compose.ui:ui") + implementation("androidx.compose.material3:material3") +} diff --git a/android/app/proguard-rules.pro b/android/app/proguard-rules.pro new file mode 100644 index 0000000..a9a319c --- /dev/null +++ b/android/app/proguard-rules.pro @@ -0,0 +1,9 @@ +# WZPhone ProGuard rules + +# Keep JNI native methods +-keepclasseswithmembernames class * { + native ; +} + +# Keep the WZP engine bridge class +-keep class com.wzp.phone.engine.** { *; } diff --git a/android/app/src/main/AndroidManifest.xml b/android/app/src/main/AndroidManifest.xml new file mode 100644 index 0000000..367de2d --- /dev/null +++ b/android/app/src/main/AndroidManifest.xml @@ -0,0 +1,33 @@ + + + + + + + + + + + + + + + + + + + + + + + diff --git a/android/app/src/main/java/com/wzp/.gitkeep b/android/app/src/main/java/com/wzp/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/android/build.gradle.kts b/android/build.gradle.kts new file mode 100644 index 0000000..0836ea8 --- /dev/null +++ b/android/build.gradle.kts @@ -0,0 +1,4 @@ +plugins { + id("com.android.application") version "8.2.0" apply false + id("org.jetbrains.kotlin.android") version "1.9.22" apply false +} diff --git a/android/gradle.properties b/android/gradle.properties new file mode 100644 index 0000000..f0a2e55 --- /dev/null +++ b/android/gradle.properties @@ -0,0 +1,4 @@ +org.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8 +android.useAndroidX=true +kotlin.code.style=official +android.nonTransitiveRClass=true diff --git a/android/settings.gradle.kts b/android/settings.gradle.kts new file mode 100644 index 0000000..24fbd0c --- /dev/null +++ b/android/settings.gradle.kts @@ -0,0 +1,17 @@ +pluginManagement { + repositories { + google() + mavenCentral() + gradlePluginPortal() + } +} + +dependencyResolution { + repositories { + google() + mavenCentral() + } +} + +rootProject.name = "WZPhone" +include(":app") diff --git a/crates/wzp-android/Cargo.toml b/crates/wzp-android/Cargo.toml new file mode 100644 index 0000000..cfdbb7b --- /dev/null +++ b/crates/wzp-android/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "wzp-android" +version.workspace = true +edition.workspace = true +license.workspace = true +rust-version.workspace = true +description = "WarzonePhone Android native VoIP engine — Oboe audio, JNI bridge, call pipeline" + +[lib] +crate-type = ["cdylib", "rlib"] + +[dependencies] +wzp-proto = { workspace = true } +wzp-codec = { workspace = true } +wzp-fec = { workspace = true } +wzp-crypto = { workspace = true } +wzp-transport = { workspace = true } +tokio = { workspace = true } +tracing = { workspace = true } +tracing-subscriber = { workspace = true } +bytes = { workspace = true } +serde = { workspace = true } +serde_json = "1" +thiserror = { workspace = true } +async-trait = { workspace = true } +anyhow = "1" +libc = "0.2" + +[build-dependencies] +cc = "1" diff --git a/crates/wzp-android/build.rs b/crates/wzp-android/build.rs new file mode 100644 index 0000000..a6ae962 --- /dev/null +++ b/crates/wzp-android/build.rs @@ -0,0 +1,20 @@ +fn main() { + let target = std::env::var("TARGET").unwrap_or_default(); + if target.contains("android") { + // Real Oboe build for Android targets + cc::Build::new() + .cpp(true) + .std("c++17") + .file("cpp/oboe_bridge.cpp") + .include("cpp") + .compile("oboe_bridge"); + } else { + // Stub for host builds / testing + cc::Build::new() + .cpp(true) + .std("c++17") + .file("cpp/oboe_stub.cpp") + .include("cpp") + .compile("oboe_bridge"); + } +} diff --git a/crates/wzp-android/cpp/oboe_bridge.cpp b/crates/wzp-android/cpp/oboe_bridge.cpp new file mode 100644 index 0000000..2cb6afd --- /dev/null +++ b/crates/wzp-android/cpp/oboe_bridge.cpp @@ -0,0 +1,278 @@ +// Full Oboe implementation for Android +// This file is compiled only when targeting Android + +#include "oboe_bridge.h" + +#ifdef __ANDROID__ +#include +#include +#include +#include + +#define LOG_TAG "wzp-oboe" +#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, __VA_ARGS__) +#define LOGW(...) __android_log_print(ANDROID_LOG_WARN, LOG_TAG, __VA_ARGS__) +#define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__) + +// --------------------------------------------------------------------------- +// Ring buffer helpers (SPSC, lock-free) +// --------------------------------------------------------------------------- + +static inline int32_t ring_available_read(const wzp_atomic_int* write_idx, + const wzp_atomic_int* read_idx, + int32_t capacity) { + int32_t w = std::atomic_load_explicit(write_idx, std::memory_order_acquire); + int32_t r = std::atomic_load_explicit(read_idx, std::memory_order_relaxed); + int32_t avail = w - r; + if (avail < 0) avail += capacity; + return avail; +} + +static inline int32_t ring_available_write(const wzp_atomic_int* write_idx, + const wzp_atomic_int* read_idx, + int32_t capacity) { + return capacity - 1 - ring_available_read(write_idx, read_idx, capacity); +} + +static inline void ring_write(int16_t* buf, int32_t capacity, + wzp_atomic_int* write_idx, const wzp_atomic_int* read_idx, + const int16_t* src, int32_t count) { + int32_t w = std::atomic_load_explicit(write_idx, std::memory_order_relaxed); + for (int32_t i = 0; i < count; i++) { + buf[w] = src[i]; + w++; + if (w >= capacity) w = 0; + } + std::atomic_store_explicit(write_idx, w, std::memory_order_release); +} + +static inline void ring_read(int16_t* buf, int32_t capacity, + const wzp_atomic_int* write_idx, wzp_atomic_int* read_idx, + int16_t* dst, int32_t count) { + int32_t r = std::atomic_load_explicit(read_idx, std::memory_order_relaxed); + for (int32_t i = 0; i < count; i++) { + dst[i] = buf[r]; + r++; + if (r >= capacity) r = 0; + } + std::atomic_store_explicit(read_idx, r, std::memory_order_release); +} + +// --------------------------------------------------------------------------- +// Global state +// --------------------------------------------------------------------------- + +static std::shared_ptr g_capture_stream; +static std::shared_ptr g_playout_stream; +static const WzpOboeRings* g_rings = nullptr; +static std::atomic g_running{false}; +static std::atomic g_capture_latency_ms{0.0f}; +static std::atomic g_playout_latency_ms{0.0f}; + +// --------------------------------------------------------------------------- +// Capture callback +// --------------------------------------------------------------------------- + +class CaptureCallback : public oboe::AudioStreamDataCallback { +public: + oboe::DataCallbackResult onAudioReady( + oboe::AudioStream* stream, + void* audioData, + int32_t numFrames) override { + if (!g_running.load(std::std::memory_order_relaxed) || !g_rings) { + return oboe::DataCallbackResult::Stop; + } + + const int16_t* src = static_cast(audioData); + int32_t avail = ring_available_write(g_rings->capture_write_idx, + g_rings->capture_read_idx, + g_rings->capture_capacity); + int32_t to_write = (numFrames < avail) ? numFrames : avail; + if (to_write > 0) { + ring_write(g_rings->capture_buf, g_rings->capture_capacity, + g_rings->capture_write_idx, g_rings->capture_read_idx, + src, to_write); + } + + // Update latency estimate + auto result = stream->calculateLatencyMillis(); + if (result) { + g_capture_latency_ms.store(static_cast(result.value()), + std::std::memory_order_relaxed); + } + + return oboe::DataCallbackResult::Continue; + } +}; + +// --------------------------------------------------------------------------- +// Playout callback +// --------------------------------------------------------------------------- + +class PlayoutCallback : public oboe::AudioStreamDataCallback { +public: + oboe::DataCallbackResult onAudioReady( + oboe::AudioStream* stream, + void* audioData, + int32_t numFrames) override { + if (!g_running.load(std::std::memory_order_relaxed) || !g_rings) { + memset(audioData, 0, numFrames * sizeof(int16_t)); + return oboe::DataCallbackResult::Stop; + } + + int16_t* dst = static_cast(audioData); + int32_t avail = ring_available_read(g_rings->playout_write_idx, + g_rings->playout_read_idx, + g_rings->playout_capacity); + int32_t to_read = (numFrames < avail) ? numFrames : avail; + + if (to_read > 0) { + ring_read(g_rings->playout_buf, g_rings->playout_capacity, + g_rings->playout_write_idx, g_rings->playout_read_idx, + dst, to_read); + } + // Fill remainder with silence on underrun + if (to_read < numFrames) { + memset(dst + to_read, 0, (numFrames - to_read) * sizeof(int16_t)); + } + + // Update latency estimate + auto result = stream->calculateLatencyMillis(); + if (result) { + g_playout_latency_ms.store(static_cast(result.value()), + std::std::memory_order_relaxed); + } + + return oboe::DataCallbackResult::Continue; + } +}; + +static CaptureCallback g_capture_cb; +static PlayoutCallback g_playout_cb; + +// --------------------------------------------------------------------------- +// Public C API +// --------------------------------------------------------------------------- + +int wzp_oboe_start(const WzpOboeConfig* config, const WzpOboeRings* rings) { + if (g_running.load(std::std::memory_order_relaxed)) { + LOGW("wzp_oboe_start: already running"); + return -1; + } + + g_rings = rings; + + // Build capture stream + oboe::AudioStreamBuilder captureBuilder; + captureBuilder.setDirection(oboe::Direction::Input) + ->setPerformanceMode(oboe::PerformanceMode::LowLatency) + ->setSharingMode(oboe::SharingMode::Exclusive) + ->setFormat(oboe::AudioFormat::I16) + ->setChannelCount(config->channel_count) + ->setSampleRate(config->sample_rate) + ->setFramesPerDataCallback(config->frames_per_burst) + ->setInputPreset(oboe::InputPreset::VoiceCommunication) + ->setDataCallback(&g_capture_cb); + + oboe::Result result = captureBuilder.openStream(g_capture_stream); + if (result != oboe::Result::OK) { + LOGE("Failed to open capture stream: %s", oboe::convertToText(result)); + return -2; + } + + // Build playout stream + oboe::AudioStreamBuilder playoutBuilder; + playoutBuilder.setDirection(oboe::Direction::Output) + ->setPerformanceMode(oboe::PerformanceMode::LowLatency) + ->setSharingMode(oboe::SharingMode::Exclusive) + ->setFormat(oboe::AudioFormat::I16) + ->setChannelCount(config->channel_count) + ->setSampleRate(config->sample_rate) + ->setFramesPerDataCallback(config->frames_per_burst) + ->setUsage(oboe::Usage::VoiceCommunication) + ->setDataCallback(&g_playout_cb); + + result = playoutBuilder.openStream(g_playout_stream); + if (result != oboe::Result::OK) { + LOGE("Failed to open playout stream: %s", oboe::convertToText(result)); + g_capture_stream->close(); + g_capture_stream.reset(); + return -3; + } + + g_running.store(true, std::std::memory_order_release); + + // Start both streams + result = g_capture_stream->requestStart(); + if (result != oboe::Result::OK) { + LOGE("Failed to start capture: %s", oboe::convertToText(result)); + g_running.store(false, std::std::memory_order_release); + g_capture_stream->close(); + g_playout_stream->close(); + g_capture_stream.reset(); + g_playout_stream.reset(); + return -4; + } + + result = g_playout_stream->requestStart(); + if (result != oboe::Result::OK) { + LOGE("Failed to start playout: %s", oboe::convertToText(result)); + g_running.store(false, std::std::memory_order_release); + g_capture_stream->requestStop(); + g_capture_stream->close(); + g_playout_stream->close(); + g_capture_stream.reset(); + g_playout_stream.reset(); + return -5; + } + + LOGI("Oboe started: sr=%d burst=%d ch=%d", + config->sample_rate, config->frames_per_burst, config->channel_count); + return 0; +} + +void wzp_oboe_stop(void) { + g_running.store(false, std::std::memory_order_release); + + if (g_capture_stream) { + g_capture_stream->requestStop(); + g_capture_stream->close(); + g_capture_stream.reset(); + } + if (g_playout_stream) { + g_playout_stream->requestStop(); + g_playout_stream->close(); + g_playout_stream.reset(); + } + + g_rings = nullptr; + LOGI("Oboe stopped"); +} + +float wzp_oboe_capture_latency_ms(void) { + return g_capture_latency_ms.load(std::std::memory_order_relaxed); +} + +float wzp_oboe_playout_latency_ms(void) { + return g_playout_latency_ms.load(std::std::memory_order_relaxed); +} + +int wzp_oboe_is_running(void) { + return g_running.load(std::std::memory_order_relaxed) ? 1 : 0; +} + +#else +// Non-Android fallback — should not be reached; oboe_stub.cpp is used instead. +// Provide empty implementations just in case. + +int wzp_oboe_start(const WzpOboeConfig* config, const WzpOboeRings* rings) { + (void)config; (void)rings; + return -99; +} + +void wzp_oboe_stop(void) {} +float wzp_oboe_capture_latency_ms(void) { return 0.0f; } +float wzp_oboe_playout_latency_ms(void) { return 0.0f; } +int wzp_oboe_is_running(void) { return 0; } + +#endif // __ANDROID__ diff --git a/crates/wzp-android/cpp/oboe_bridge.h b/crates/wzp-android/cpp/oboe_bridge.h new file mode 100644 index 0000000..8c2f143 --- /dev/null +++ b/crates/wzp-android/cpp/oboe_bridge.h @@ -0,0 +1,43 @@ +#ifndef WZP_OBOE_BRIDGE_H +#define WZP_OBOE_BRIDGE_H + +#include + +#ifdef __cplusplus +#include +typedef std::atomic wzp_atomic_int; +extern "C" { +#else +#include +typedef atomic_int wzp_atomic_int; +#endif + +typedef struct { + int32_t sample_rate; + int32_t frames_per_burst; + int32_t channel_count; +} WzpOboeConfig; + +typedef struct { + int16_t* capture_buf; + int32_t capture_capacity; + wzp_atomic_int* capture_write_idx; + wzp_atomic_int* capture_read_idx; + + int16_t* playout_buf; + int32_t playout_capacity; + wzp_atomic_int* playout_write_idx; + wzp_atomic_int* playout_read_idx; +} WzpOboeRings; + +int wzp_oboe_start(const WzpOboeConfig* config, const WzpOboeRings* rings); +void wzp_oboe_stop(void); +float wzp_oboe_capture_latency_ms(void); +float wzp_oboe_playout_latency_ms(void); +int wzp_oboe_is_running(void); + +#ifdef __cplusplus +} +#endif + +#endif // WZP_OBOE_BRIDGE_H diff --git a/crates/wzp-android/cpp/oboe_stub.cpp b/crates/wzp-android/cpp/oboe_stub.cpp new file mode 100644 index 0000000..6792259 --- /dev/null +++ b/crates/wzp-android/cpp/oboe_stub.cpp @@ -0,0 +1,27 @@ +// Stub implementation for non-Android host builds (testing, cargo check, etc.) + +#include "oboe_bridge.h" +#include + +int wzp_oboe_start(const WzpOboeConfig* config, const WzpOboeRings* rings) { + (void)config; + (void)rings; + fprintf(stderr, "wzp_oboe_start: stub (not on Android)\n"); + return 0; +} + +void wzp_oboe_stop(void) { + fprintf(stderr, "wzp_oboe_stop: stub (not on Android)\n"); +} + +float wzp_oboe_capture_latency_ms(void) { + return 0.0f; +} + +float wzp_oboe_playout_latency_ms(void) { + return 0.0f; +} + +int wzp_oboe_is_running(void) { + return 0; +} diff --git a/crates/wzp-android/src/audio_android.rs b/crates/wzp-android/src/audio_android.rs new file mode 100644 index 0000000..db58046 --- /dev/null +++ b/crates/wzp-android/src/audio_android.rs @@ -0,0 +1,424 @@ +//! Lock-free SPSC ring buffer audio backend for Android (Oboe). +//! +//! The ring buffers are shared between Rust and C++: the Oboe callbacks +//! (running on a high-priority audio thread) read/write directly into +//! the buffers via atomic indices, while the Rust codec thread on the +//! other side does the same. + +use std::sync::atomic::{AtomicI32, Ordering}; + +use tracing::info; +#[allow(unused_imports)] +use tracing::warn; + +/// Number of samples per 20 ms frame at 48 kHz mono. +pub const FRAME_SAMPLES: usize = 960; + +/// Default ring buffer capacity: 8 frames = 160 ms at 48 kHz. +const RING_CAPACITY: usize = 7680; + +// --------------------------------------------------------------------------- +// FFI declarations matching oboe_bridge.h +// --------------------------------------------------------------------------- + +#[repr(C)] +#[allow(non_snake_case)] +struct WzpOboeConfig { + sample_rate: i32, + frames_per_burst: i32, + channel_count: i32, +} + +#[repr(C)] +#[allow(non_snake_case)] +struct WzpOboeRings { + capture_buf: *mut i16, + capture_capacity: i32, + capture_write_idx: *mut AtomicI32, + capture_read_idx: *mut AtomicI32, + + playout_buf: *mut i16, + playout_capacity: i32, + playout_write_idx: *mut AtomicI32, + playout_read_idx: *mut AtomicI32, +} + +unsafe impl Send for WzpOboeRings {} +unsafe impl Sync for WzpOboeRings {} + +unsafe extern "C" { + fn wzp_oboe_start(config: *const WzpOboeConfig, rings: *const WzpOboeRings) -> i32; + fn wzp_oboe_stop(); + fn wzp_oboe_capture_latency_ms() -> f32; + fn wzp_oboe_playout_latency_ms() -> f32; + fn wzp_oboe_is_running() -> i32; +} + +// --------------------------------------------------------------------------- +// SPSC Ring Buffer +// --------------------------------------------------------------------------- + +/// Single-producer single-consumer lock-free ring buffer. +/// +/// The producer calls `write()` and the consumer calls `read()`. +/// Atomics use acquire/release ordering to ensure correct visibility +/// across the Oboe audio thread and the Rust codec thread. +pub struct RingBuffer { + buf: Vec, + capacity: usize, + write_idx: AtomicI32, + read_idx: AtomicI32, +} + +impl RingBuffer { + /// Create a new ring buffer with the given capacity (in samples). + /// + /// The actual usable capacity is `capacity - 1` to distinguish + /// full from empty. + pub fn new(capacity: usize) -> Self { + Self { + buf: vec![0i16; capacity], + capacity, + write_idx: AtomicI32::new(0), + read_idx: AtomicI32::new(0), + } + } + + /// Number of samples available to read. + pub fn available_read(&self) -> usize { + let w = self.write_idx.load(Ordering::Acquire); + let r = self.read_idx.load(Ordering::Relaxed); + let avail = w - r; + if avail < 0 { + (avail + self.capacity as i32) as usize + } else { + avail as usize + } + } + + /// Number of samples that can be written before the buffer is full. + pub fn available_write(&self) -> usize { + self.capacity - 1 - self.available_read() + } + + /// Write samples into the ring buffer (producer side). + /// + /// Returns the number of samples actually written (may be less than + /// `data.len()` if the buffer is nearly full). + pub fn write(&self, data: &[i16]) -> usize { + let avail = self.available_write(); + let count = data.len().min(avail); + if count == 0 { + return 0; + } + + let mut w = self.write_idx.load(Ordering::Relaxed) as usize; + let cap = self.capacity; + let buf_ptr = self.buf.as_ptr() as *mut i16; + + for i in 0..count { + // SAFETY: w is always in [0, capacity) and we are the sole producer. + unsafe { + *buf_ptr.add(w) = data[i]; + } + w += 1; + if w >= cap { + w = 0; + } + } + + self.write_idx.store(w as i32, Ordering::Release); + count + } + + /// Read samples from the ring buffer (consumer side). + /// + /// Returns the number of samples actually read (may be less than + /// `out.len()` if the buffer doesn't have enough data). + pub fn read(&self, out: &mut [i16]) -> usize { + let avail = self.available_read(); + let count = out.len().min(avail); + if count == 0 { + return 0; + } + + let mut r = self.read_idx.load(Ordering::Relaxed) as usize; + let cap = self.capacity; + let buf_ptr = self.buf.as_ptr(); + + for i in 0..count { + // SAFETY: r is always in [0, capacity) and we are the sole consumer. + unsafe { + out[i] = *buf_ptr.add(r); + } + r += 1; + if r >= cap { + r = 0; + } + } + + self.read_idx.store(r as i32, Ordering::Release); + count + } + + /// Get a raw pointer to the buffer data (for FFI). + fn buf_ptr(&self) -> *mut i16 { + self.buf.as_ptr() as *mut i16 + } + + /// Get a raw pointer to the write index atomic (for FFI). + fn write_idx_ptr(&self) -> *mut AtomicI32 { + &self.write_idx as *const AtomicI32 as *mut AtomicI32 + } + + /// Get a raw pointer to the read index atomic (for FFI). + fn read_idx_ptr(&self) -> *mut AtomicI32 { + &self.read_idx as *const AtomicI32 as *mut AtomicI32 + } +} + +// SAFETY: The ring buffer is designed for SPSC use where producer and consumer +// are on different threads. The atomic indices provide the synchronization. +unsafe impl Send for RingBuffer {} +unsafe impl Sync for RingBuffer {} + +// --------------------------------------------------------------------------- +// Oboe Backend +// --------------------------------------------------------------------------- + +/// Oboe-based audio backend for Android. +/// +/// Owns two SPSC ring buffers (capture and playout) that are shared with +/// the C++ Oboe callbacks via raw pointers. The Oboe callbacks run on +/// high-priority audio threads managed by the Android audio system. +pub struct OboeBackend { + capture_ring: RingBuffer, + playout_ring: RingBuffer, + started: bool, +} + +impl OboeBackend { + /// Create a new backend with default ring buffer sizes (160 ms each). + pub fn new() -> Self { + Self { + capture_ring: RingBuffer::new(RING_CAPACITY), + playout_ring: RingBuffer::new(RING_CAPACITY), + started: false, + } + } + + /// Start Oboe audio streams. + /// + /// This sets up the ring buffer pointers and calls into the C++ layer + /// to open and start the capture and playout Oboe streams. + pub fn start(&mut self) -> Result<(), anyhow::Error> { + if self.started { + return Ok(()); + } + + let config = WzpOboeConfig { + sample_rate: 48_000, + frames_per_burst: FRAME_SAMPLES as i32, + channel_count: 1, + }; + + let rings = WzpOboeRings { + capture_buf: self.capture_ring.buf_ptr(), + capture_capacity: self.capture_ring.capacity as i32, + capture_write_idx: self.capture_ring.write_idx_ptr(), + capture_read_idx: self.capture_ring.read_idx_ptr(), + + playout_buf: self.playout_ring.buf_ptr(), + playout_capacity: self.playout_ring.capacity as i32, + playout_write_idx: self.playout_ring.write_idx_ptr(), + playout_read_idx: self.playout_ring.read_idx_ptr(), + }; + + let ret = unsafe { wzp_oboe_start(&config, &rings) }; + if ret != 0 { + return Err(anyhow::anyhow!("wzp_oboe_start failed with code {}", ret)); + } + + self.started = true; + info!("Oboe backend started"); + Ok(()) + } + + /// Stop Oboe audio streams. + pub fn stop(&mut self) { + if !self.started { + return; + } + unsafe { wzp_oboe_stop() }; + self.started = false; + info!("Oboe backend stopped"); + } + + /// Read captured audio samples from the capture ring buffer. + /// + /// Returns the number of samples actually read. The caller should + /// provide a buffer of at least `FRAME_SAMPLES` (960) samples. + pub fn read_capture(&self, out: &mut [i16]) -> usize { + self.capture_ring.read(out) + } + + /// Write audio samples to the playout ring buffer. + /// + /// Returns the number of samples actually written. + pub fn write_playout(&self, samples: &[i16]) -> usize { + self.playout_ring.write(samples) + } + + /// Get the current capture latency in milliseconds (from Oboe). + #[allow(unused)] + pub fn capture_latency_ms(&self) -> f32 { + unsafe { wzp_oboe_capture_latency_ms() } + } + + /// Get the current playout latency in milliseconds (from Oboe). + #[allow(unused)] + pub fn playout_latency_ms(&self) -> f32 { + unsafe { wzp_oboe_playout_latency_ms() } + } + + /// Check if the Oboe streams are currently running. + #[allow(unused)] + pub fn is_running(&self) -> bool { + unsafe { wzp_oboe_is_running() != 0 } + } +} + +impl Drop for OboeBackend { + fn drop(&mut self) { + self.stop(); + } +} + +// --------------------------------------------------------------------------- +// Thread affinity / priority helpers +// --------------------------------------------------------------------------- + +/// Pin the current thread to the highest-numbered CPU cores (big cores on +/// ARM big.LITTLE architectures). Falls back silently on failure. +#[allow(unused)] +pub fn pin_to_big_core() { + #[cfg(target_os = "android")] + { + unsafe { + let num_cpus = libc::sysconf(libc::_SC_NPROCESSORS_ONLN); + if num_cpus <= 0 { + warn!("pin_to_big_core: could not determine CPU count"); + return; + } + let num_cpus = num_cpus as usize; + + // Target the upper half of CPUs (big cores on most big.LITTLE SoCs) + let start = num_cpus / 2; + let mut set: libc::cpu_set_t = std::mem::zeroed(); + libc::CPU_ZERO(&mut set); + for cpu in start..num_cpus { + libc::CPU_SET(cpu, &mut set); + } + + let ret = libc::sched_setaffinity( + 0, // current thread + std::mem::size_of::(), + &set, + ); + if ret != 0 { + warn!("sched_setaffinity failed: {}", std::io::Error::last_os_error()); + } else { + info!(start, num_cpus, "pinned to big cores"); + } + } + } + #[cfg(not(target_os = "android"))] + { + // No-op on non-Android + } +} + +/// Attempt to set SCHED_FIFO real-time priority for the current thread. +/// Falls back silently on failure (requires appropriate permissions on Android). +#[allow(unused)] +pub fn set_realtime_priority() { + #[cfg(target_os = "android")] + { + unsafe { + let param = libc::sched_param { + sched_priority: 2, // Low RT priority — enough for audio, safe + }; + let ret = libc::sched_setscheduler(0, libc::SCHED_FIFO, ¶m); + if ret != 0 { + warn!( + "sched_setscheduler(SCHED_FIFO) failed: {}", + std::io::Error::last_os_error() + ); + } else { + info!("set SCHED_FIFO priority 2"); + } + } + } + #[cfg(not(target_os = "android"))] + { + // No-op on non-Android + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn ring_buffer_write_read() { + let ring = RingBuffer::new(16); + let data = [1i16, 2, 3, 4, 5]; + assert_eq!(ring.write(&data), 5); + assert_eq!(ring.available_read(), 5); + + let mut out = [0i16; 5]; + assert_eq!(ring.read(&mut out), 5); + assert_eq!(out, [1, 2, 3, 4, 5]); + assert_eq!(ring.available_read(), 0); + } + + #[test] + fn ring_buffer_wraparound() { + let ring = RingBuffer::new(8); + let data = [10i16, 20, 30, 40, 50, 60]; // 6 samples, capacity 8 (usable 7) + assert_eq!(ring.write(&data), 6); + + let mut out = [0i16; 4]; + assert_eq!(ring.read(&mut out), 4); + assert_eq!(out, [10, 20, 30, 40]); + + // Now write more, which should wrap around + let data2 = [70i16, 80, 90, 100]; + assert_eq!(ring.write(&data2), 4); + + let mut out2 = [0i16; 6]; + assert_eq!(ring.read(&mut out2), 6); + assert_eq!(out2, [50, 60, 70, 80, 90, 100]); + } + + #[test] + fn ring_buffer_full() { + let ring = RingBuffer::new(4); // usable capacity = 3 + let data = [1i16, 2, 3, 4, 5]; + assert_eq!(ring.write(&data), 3); // Only 3 fit + assert_eq!(ring.available_write(), 0); + } + + #[test] + fn oboe_backend_stub_start_stop() { + let mut backend = OboeBackend::new(); + backend.start().expect("stub start should succeed"); + assert!(backend.started); + backend.stop(); + assert!(!backend.started); + } +} diff --git a/crates/wzp-android/src/commands.rs b/crates/wzp-android/src/commands.rs new file mode 100644 index 0000000..1790553 --- /dev/null +++ b/crates/wzp-android/src/commands.rs @@ -0,0 +1,15 @@ +//! Engine commands sent from the JNI/UI thread to the engine. + +use wzp_proto::QualityProfile; + +/// Commands that can be sent to the running engine. +pub enum EngineCommand { + /// Mute or unmute the microphone. + SetMute(bool), + /// Enable or disable speaker (loudspeaker) mode. + SetSpeaker(bool), + /// Force a specific quality profile (overrides adaptive logic). + ForceProfile(QualityProfile), + /// Stop the call and shut down the engine. + Stop, +} diff --git a/crates/wzp-android/src/engine.rs b/crates/wzp-android/src/engine.rs new file mode 100644 index 0000000..eba351e --- /dev/null +++ b/crates/wzp-android/src/engine.rs @@ -0,0 +1,357 @@ +//! Engine orchestrator — manages the call lifecycle. +//! +//! The engine owns: +//! - The Oboe audio backend (start/stop) +//! - A codec thread running the `Pipeline` +//! - A tokio runtime for async network I/O +//! - Command channel for control from the JNI/UI thread + +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, Mutex}; +use std::time::Instant; + +use tracing::{error, info, warn}; +use wzp_proto::QualityProfile; + +use crate::audio_android::{OboeBackend, FRAME_SAMPLES}; +use crate::commands::EngineCommand; +use crate::pipeline::Pipeline; +use crate::stats::{CallState, CallStats}; + +/// Configuration to start a call. +pub struct CallStartConfig { + /// Initial quality profile. + pub profile: QualityProfile, + /// Relay server address (host:port). + pub relay_addr: String, + /// Authentication token for the relay. + pub auth_token: Vec, + /// 32-byte identity seed for key derivation. + pub identity_seed: [u8; 32], +} + +impl Default for CallStartConfig { + fn default() -> Self { + Self { + profile: QualityProfile::GOOD, + relay_addr: String::new(), + auth_token: Vec::new(), + identity_seed: [0u8; 32], + } + } +} + +/// Shared state between the engine owner and background threads. +struct EngineState { + running: AtomicBool, + muted: AtomicBool, + speaker: AtomicBool, + stats: Mutex, + command_tx: std::sync::mpsc::Sender, + command_rx: Mutex>>, +} + +/// The WarzonePhone Android engine. +/// +/// Manages the entire call pipeline: audio capture/playout via Oboe, +/// codec encode/decode, FEC, jitter buffer, and network transport. +/// +/// Thread model: +/// - **UI/JNI thread**: calls `start_call`, `stop_call`, `set_mute`, etc. +/// - **Codec thread**: runs `Pipeline` encode/decode loop, reads/writes ring buffers +/// - **Tokio runtime** (2 worker threads): async network send/recv +pub struct WzpEngine { + state: Arc, + codec_thread: Option>, + #[allow(unused)] + tokio_runtime: Option, + call_start: Option, +} + +impl WzpEngine { + /// Create a new idle engine. + pub fn new() -> Self { + let (tx, rx) = std::sync::mpsc::channel(); + let state = Arc::new(EngineState { + running: AtomicBool::new(false), + muted: AtomicBool::new(false), + speaker: AtomicBool::new(false), + stats: Mutex::new(CallStats::default()), + command_tx: tx, + command_rx: Mutex::new(Some(rx)), + }); + + Self { + state, + codec_thread: None, + tokio_runtime: None, + call_start: None, + } + } + + /// Start a call with the given configuration. + /// + /// This creates the tokio runtime, starts the Oboe audio backend, + /// and spawns the codec thread. + pub fn start_call(&mut self, config: CallStartConfig) -> Result<(), anyhow::Error> { + if self.state.running.load(Ordering::Acquire) { + return Err(anyhow::anyhow!("call already active")); + } + + // Update state + { + let mut stats = self.state.stats.lock().unwrap(); + *stats = CallStats { + state: CallState::Connecting, + ..Default::default() + }; + } + + // Create tokio runtime with 2 worker threads + let runtime = tokio::runtime::Builder::new_multi_thread() + .worker_threads(2) + .thread_name("wzp-net") + .enable_all() + .build()?; + + // Create async channels for network send/recv + let (send_tx, mut _send_rx) = tokio::sync::mpsc::channel::>(64); + let (_recv_tx, mut recv_rx) = tokio::sync::mpsc::channel::>(64); + + // Spawn network tasks (placeholder — will use wzp-transport) + let _relay_addr = config.relay_addr.clone(); + runtime.spawn(async move { + // Network send task: reads from send_rx, sends via transport + // This will be implemented when wzp-transport Android support is added + while let Some(_packet) = _send_rx.recv().await { + // TODO: send via wzp-transport + } + }); + + let recv_tx_clone = _recv_tx.clone(); + runtime.spawn(async move { + // Network recv task: reads from transport, writes to recv_rx + // This will be implemented when wzp-transport Android support is added + let _tx = recv_tx_clone; + // TODO: recv from wzp-transport and forward + }); + + // Take the command receiver (it can only be taken once) + let command_rx = self + .state + .command_rx + .lock() + .unwrap() + .take() + .ok_or_else(|| anyhow::anyhow!("command receiver already taken"))?; + + // Start the codec thread + let state = self.state.clone(); + let profile = config.profile; + let codec_thread = std::thread::Builder::new() + .name("wzp-codec".into()) + .spawn(move || { + // Pin to big cores and set RT priority on Android + crate::audio_android::pin_to_big_core(); + crate::audio_android::set_realtime_priority(); + + // Create audio backend + let mut audio = OboeBackend::new(); + if let Err(e) = audio.start() { + error!("failed to start audio: {e}"); + state.running.store(false, Ordering::Release); + return; + } + + // Create pipeline + let mut pipeline = match Pipeline::new(profile) { + Ok(p) => p, + Err(e) => { + error!("failed to create pipeline: {e}"); + audio.stop(); + state.running.store(false, Ordering::Release); + return; + } + }; + + state.running.store(true, Ordering::Release); + { + let mut stats = state.stats.lock().unwrap(); + stats.state = CallState::Active; + } + + info!("codec thread started"); + + let mut capture_buf = vec![0i16; FRAME_SAMPLES]; + #[allow(unused_assignments)] + let mut recv_buf: Vec = Vec::new(); + + // Main codec loop: 20ms per iteration + let frame_duration = std::time::Duration::from_millis(20); + + while state.running.load(Ordering::Relaxed) { + let loop_start = Instant::now(); + + // Process commands (non-blocking) + while let Ok(cmd) = command_rx.try_recv() { + match cmd { + EngineCommand::SetMute(m) => { + state.muted.store(m, Ordering::Relaxed); + info!(muted = m, "mute toggled"); + } + EngineCommand::SetSpeaker(s) => { + state.speaker.store(s, Ordering::Relaxed); + info!(speaker = s, "speaker toggled"); + } + EngineCommand::ForceProfile(p) => { + pipeline.force_profile(p); + info!(?p, "profile forced"); + } + EngineCommand::Stop => { + info!("stop command received"); + state.running.store(false, Ordering::Release); + break; + } + } + } + + if !state.running.load(Ordering::Relaxed) { + break; + } + + // --- Capture → Encode → Send --- + let captured = audio.read_capture(&mut capture_buf); + if captured >= FRAME_SAMPLES { + let muted = state.muted.load(Ordering::Relaxed); + if let Some(encoded) = pipeline.encode_frame(&capture_buf, muted) { + // Send to network (best-effort) + let _ = send_tx.try_send(encoded); + } + } + + // --- Recv → Decode → Playout --- + // Drain received packets from the network channel + while let Ok(data) = recv_rx.try_recv() { + recv_buf = data; + // Deserialize the packet and feed to pipeline + // For now, feed raw bytes — full MediaPacket deserialization + // will be added with the transport integration + let _ = &recv_buf; // suppress unused warning + } + + // Decode from jitter buffer + if let Some(pcm) = pipeline.decode_frame() { + audio.write_playout(&pcm); + } + + // --- Update stats --- + { + let pstats = pipeline.stats(); + let mut stats = state.stats.lock().unwrap(); + stats.frames_encoded = pstats.frames_encoded; + stats.frames_decoded = pstats.frames_decoded; + stats.underruns = pstats.underruns; + stats.jitter_buffer_depth = pstats.jitter_depth; + stats.quality_tier = pstats.quality_tier; + } + + // Sleep for remainder of the 20ms frame period + let elapsed = loop_start.elapsed(); + if elapsed < frame_duration { + std::thread::sleep(frame_duration - elapsed); + } + } + + // Cleanup + audio.stop(); + { + let mut stats = state.stats.lock().unwrap(); + stats.state = CallState::Closed; + } + info!("codec thread exited"); + })?; + + self.codec_thread = Some(codec_thread); + self.tokio_runtime = Some(runtime); + self.call_start = Some(Instant::now()); + + info!("call started"); + Ok(()) + } + + /// Stop the current call and clean up all resources. + pub fn stop_call(&mut self) { + if !self.state.running.load(Ordering::Acquire) { + return; + } + + // Signal stop + self.state.running.store(false, Ordering::Release); + let _ = self.state.command_tx.send(EngineCommand::Stop); + + // Join codec thread + if let Some(handle) = self.codec_thread.take() { + if let Err(e) = handle.join() { + warn!("codec thread panicked: {e:?}"); + } + } + + // Shut down tokio runtime + if let Some(rt) = self.tokio_runtime.take() { + rt.shutdown_timeout(std::time::Duration::from_secs(2)); + } + + self.call_start = None; + info!("call stopped"); + } + + /// Set microphone mute state. + pub fn set_mute(&self, muted: bool) { + let _ = self.state.command_tx.send(EngineCommand::SetMute(muted)); + } + + /// Set speaker (loudspeaker) mode. + #[allow(unused)] + pub fn set_speaker(&self, enabled: bool) { + let _ = self + .state + .command_tx + .send(EngineCommand::SetSpeaker(enabled)); + } + + /// Force a specific quality profile (overrides adaptive logic). + #[allow(unused)] + pub fn force_profile(&self, profile: QualityProfile) { + let _ = self + .state + .command_tx + .send(EngineCommand::ForceProfile(profile)); + } + + /// Get a snapshot of the current call statistics. + pub fn get_stats(&self) -> CallStats { + let mut stats = self.state.stats.lock().unwrap().clone(); + // Update duration from wall clock + if let Some(start) = self.call_start { + stats.duration_secs = start.elapsed().as_secs_f64(); + } + stats + } + + /// Check if a call is currently active. + pub fn is_active(&self) -> bool { + self.state.running.load(Ordering::Acquire) + } + + /// Destroy the engine, stopping any active call. + pub fn destroy(mut self) { + self.stop_call(); + info!("engine destroyed"); + } +} + +impl Drop for WzpEngine { + fn drop(&mut self) { + self.stop_call(); + } +} diff --git a/crates/wzp-android/src/lib.rs b/crates/wzp-android/src/lib.rs new file mode 100644 index 0000000..216f0f7 --- /dev/null +++ b/crates/wzp-android/src/lib.rs @@ -0,0 +1,17 @@ +//! WarzonePhone Android native VoIP engine. +//! +//! Provides: +//! - Oboe audio backend with lock-free SPSC ring buffers +//! - Engine orchestrator managing call lifecycle +//! - Codec pipeline thread (encode/decode/FEC/jitter) +//! - Call statistics and command interface +//! +//! On non-Android targets, the Oboe C++ layer compiles as a stub, +//! allowing `cargo check` and unit tests on the host. + +pub mod audio_android; +pub mod commands; +pub mod engine; +pub mod pipeline; +pub mod stats; +// pub mod jni_bridge; // Added later by Agent 4 diff --git a/crates/wzp-android/src/pipeline.rs b/crates/wzp-android/src/pipeline.rs new file mode 100644 index 0000000..ea49223 --- /dev/null +++ b/crates/wzp-android/src/pipeline.rs @@ -0,0 +1,224 @@ +//! Codec pipeline — encode/decode with FEC and jitter buffer. +//! +//! Runs on a dedicated thread, processing 20 ms frames at 48 kHz. +//! The pipeline is NOT Send/Sync (Opus encoder state) — it is owned +//! exclusively by the codec thread. + +use tracing::{debug, warn}; +use wzp_codec::{AdaptiveDecoder, AdaptiveEncoder}; +use wzp_fec::{RaptorQFecDecoder, RaptorQFecEncoder}; +use wzp_proto::jitter::{JitterBuffer, PlayoutResult}; +use wzp_proto::quality::AdaptiveQualityController; +use wzp_proto::traits::{AudioDecoder, AudioEncoder, FecDecoder, FecEncoder}; +use wzp_proto::traits::QualityController; +use wzp_proto::{MediaPacket, QualityProfile}; + +use crate::audio_android::FRAME_SAMPLES; + +/// Maximum encoded frame size (Opus worst case at highest bitrate). +const MAX_ENCODED_BYTES: usize = 1275; + +/// Pipeline statistics snapshot. +#[derive(Clone, Debug, Default)] +pub struct PipelineStats { + pub frames_encoded: u64, + pub frames_decoded: u64, + pub underruns: u64, + pub jitter_depth: usize, + pub quality_tier: u8, +} + +/// The codec pipeline: encode, FEC, jitter buffer, decode. +/// +/// This struct is owned by the codec thread and not shared. +pub struct Pipeline { + encoder: AdaptiveEncoder, + decoder: AdaptiveDecoder, + fec_encoder: RaptorQFecEncoder, + fec_decoder: RaptorQFecDecoder, + jitter_buffer: JitterBuffer, + quality_ctrl: AdaptiveQualityController, + // Pre-allocated scratch buffers + capture_buf: Vec, + #[allow(dead_code)] + playout_buf: Vec, + encode_out: Vec, + // Stats counters + frames_encoded: u64, + frames_decoded: u64, + underruns: u64, +} + +impl Pipeline { + /// Create a new pipeline configured for the given quality profile. + pub fn new(profile: QualityProfile) -> Result { + let encoder = AdaptiveEncoder::new(profile) + .map_err(|e| anyhow::anyhow!("encoder init: {e}"))?; + let decoder = AdaptiveDecoder::new(profile) + .map_err(|e| anyhow::anyhow!("decoder init: {e}"))?; + let fec_encoder = + RaptorQFecEncoder::with_defaults(profile.frames_per_block as usize); + let fec_decoder = + RaptorQFecDecoder::with_defaults(profile.frames_per_block as usize); + let jitter_buffer = JitterBuffer::new(10, 250, 3); + let quality_ctrl = AdaptiveQualityController::new(); + + Ok(Self { + encoder, + decoder, + fec_encoder, + fec_decoder, + jitter_buffer, + quality_ctrl, + capture_buf: vec![0i16; FRAME_SAMPLES], + playout_buf: vec![0i16; FRAME_SAMPLES], + encode_out: vec![0u8; MAX_ENCODED_BYTES], + frames_encoded: 0, + frames_decoded: 0, + underruns: 0, + }) + } + + /// Encode a PCM frame into a compressed packet. + /// + /// If `muted` is true, a silence frame is encoded (all zeros). + /// Returns the encoded bytes, or `None` on encoder error. + pub fn encode_frame(&mut self, pcm: &[i16], muted: bool) -> Option> { + let input = if muted { + // Zero the capture buffer for silence + for s in self.capture_buf.iter_mut() { + *s = 0; + } + &self.capture_buf[..] + } else { + pcm + }; + + match self.encoder.encode(input, &mut self.encode_out) { + Ok(n) => { + self.frames_encoded += 1; + let encoded = self.encode_out[..n].to_vec(); + + // Feed into FEC encoder + if let Err(e) = self.fec_encoder.add_source_symbol(&encoded) { + warn!("FEC encode error: {e}"); + } + + Some(encoded) + } + Err(e) => { + warn!("encode error: {e}"); + None + } + } + } + + /// Feed a received media packet into the jitter buffer. + pub fn feed_packet(&mut self, packet: MediaPacket) { + // Feed FEC symbols if present + let header = &packet.header; + if header.fec_block != 0 || header.fec_symbol != 0 { + let is_repair = header.is_repair; + if let Err(e) = self.fec_decoder.add_symbol( + header.fec_block, + header.fec_symbol, + is_repair, + &packet.payload, + ) { + debug!("FEC symbol feed error: {e}"); + } + } + + self.jitter_buffer.push(packet); + } + + /// Decode the next frame from the jitter buffer. + /// + /// Returns decoded PCM samples, or `None` if the buffer is not ready. + pub fn decode_frame(&mut self) -> Option> { + match self.jitter_buffer.pop() { + PlayoutResult::Packet(pkt) => { + let mut pcm = vec![0i16; FRAME_SAMPLES]; + match self.decoder.decode(&pkt.payload, &mut pcm) { + Ok(n) => { + self.frames_decoded += 1; + pcm.truncate(n); + Some(pcm) + } + Err(e) => { + warn!("decode error: {e}"); + // Attempt PLC + self.generate_plc() + } + } + } + PlayoutResult::Missing { seq } => { + debug!(seq, "jitter buffer: missing packet, generating PLC"); + self.generate_plc() + } + PlayoutResult::NotReady => { + self.underruns += 1; + None + } + } + } + + /// Generate packet loss concealment output. + fn generate_plc(&mut self) -> Option> { + let mut pcm = vec![0i16; FRAME_SAMPLES]; + match self.decoder.decode_lost(&mut pcm) { + Ok(n) => { + self.frames_decoded += 1; + pcm.truncate(n); + Some(pcm) + } + Err(e) => { + warn!("PLC error: {e}"); + None + } + } + } + + /// Feed a quality report into the adaptive quality controller. + /// + /// Returns a new profile if a tier transition occurred. + #[allow(unused)] + pub fn observe_quality( + &mut self, + report: &wzp_proto::QualityReport, + ) -> Option { + let new_profile = self.quality_ctrl.observe(report); + if let Some(ref profile) = new_profile { + if let Err(e) = self.encoder.set_profile(*profile) { + warn!("encoder set_profile error: {e}"); + } + if let Err(e) = self.decoder.set_profile(*profile) { + warn!("decoder set_profile error: {e}"); + } + } + new_profile + } + + /// Force a specific quality profile. + #[allow(unused)] + pub fn force_profile(&mut self, profile: QualityProfile) { + self.quality_ctrl.force_profile(profile); + if let Err(e) = self.encoder.set_profile(profile) { + warn!("encoder set_profile error: {e}"); + } + if let Err(e) = self.decoder.set_profile(profile) { + warn!("decoder set_profile error: {e}"); + } + } + + /// Get current pipeline statistics. + pub fn stats(&self) -> PipelineStats { + PipelineStats { + frames_encoded: self.frames_encoded, + frames_decoded: self.frames_decoded, + underruns: self.underruns, + jitter_depth: self.jitter_buffer.stats().current_depth, + quality_tier: self.quality_ctrl.tier() as u8, + } + } +} diff --git a/crates/wzp-android/src/stats.rs b/crates/wzp-android/src/stats.rs new file mode 100644 index 0000000..cc480cc --- /dev/null +++ b/crates/wzp-android/src/stats.rs @@ -0,0 +1,42 @@ +//! Call statistics for the Android engine. + +/// State of the call. +#[derive(Clone, Debug, Default, serde::Serialize, PartialEq, Eq)] +pub enum CallState { + /// Engine is idle, no active call. + #[default] + Idle, + /// Establishing connection to the relay. + Connecting, + /// Call is active with audio flowing. + Active, + /// Temporarily lost connection, attempting to recover. + Reconnecting, + /// Call has ended. + Closed, +} + +/// Aggregated call statistics, serializable for JNI bridge. +#[derive(Clone, Debug, Default, serde::Serialize)] +pub struct CallStats { + /// Current call state. + pub state: CallState, + /// Call duration in seconds. + pub duration_secs: f64, + /// Current quality tier (0=GOOD, 1=DEGRADED, 2=CATASTROPHIC). + pub quality_tier: u8, + /// Observed packet loss percentage. + pub loss_pct: f32, + /// Smoothed round-trip time in milliseconds. + pub rtt_ms: u32, + /// Jitter in milliseconds. + pub jitter_ms: u32, + /// Current jitter buffer depth in packets. + pub jitter_buffer_depth: usize, + /// Total frames encoded since call start. + pub frames_encoded: u64, + /// Total frames decoded since call start. + pub frames_decoded: u64, + /// Number of playout underruns (buffer empty when audio needed). + pub underruns: u64, +} diff --git a/crates/wzp-client/src/featherchat.rs b/crates/wzp-client/src/featherchat.rs index e100df9..1a175d4 100644 --- a/crates/wzp-client/src/featherchat.rs +++ b/crates/wzp-client/src/featherchat.rs @@ -115,6 +115,7 @@ pub fn signal_to_call_type(signal: &SignalMessage) -> CallSignalType { #[cfg(test)] mod tests { use super::*; + use wzp_proto::QualityProfile; #[test] fn payload_roundtrip() { diff --git a/crates/wzp-codec/src/adaptive.rs b/crates/wzp-codec/src/adaptive.rs index 37e505a..54317d5 100644 --- a/crates/wzp-codec/src/adaptive.rs +++ b/crates/wzp-codec/src/adaptive.rs @@ -14,7 +14,7 @@ use crate::codec2_dec::Codec2Decoder; use crate::codec2_enc::Codec2Encoder; use crate::opus_dec::OpusDecoder; use crate::opus_enc::OpusEncoder; -use crate::resample; +use crate::resample::{Downsampler48to8, Upsampler8to48}; // ─── Helpers ───────────────────────────────────────────────────────────────── @@ -54,6 +54,7 @@ pub struct AdaptiveEncoder { opus: OpusEncoder, codec2: Codec2Encoder, active: CodecId, + downsampler: Downsampler48to8, } impl AdaptiveEncoder { @@ -66,6 +67,7 @@ impl AdaptiveEncoder { opus, codec2, active: profile.codec, + downsampler: Downsampler48to8::new(), }) } } @@ -74,7 +76,7 @@ impl AudioEncoder for AdaptiveEncoder { fn encode(&mut self, pcm: &[i16], out: &mut [u8]) -> Result { if is_codec2(self.active) { // Downsample 48 kHz → 8 kHz then encode via Codec2. - let pcm_8k = resample::resample_48k_to_8k(pcm); + let pcm_8k = self.downsampler.process(pcm); self.codec2.encode(&pcm_8k, out) } else { self.opus.encode(pcm, out) @@ -126,6 +128,7 @@ pub struct AdaptiveDecoder { opus: OpusDecoder, codec2: Codec2Decoder, active: CodecId, + upsampler: Upsampler8to48, } impl AdaptiveDecoder { @@ -138,6 +141,7 @@ impl AdaptiveDecoder { opus, codec2, active: profile.codec, + upsampler: Upsampler8to48::new(), }) } } @@ -149,7 +153,7 @@ impl AudioDecoder for AdaptiveDecoder { let c2_samples = self.codec2_frame_samples(); let mut buf_8k = vec![0i16; c2_samples]; let n = self.codec2.decode(encoded, &mut buf_8k)?; - let pcm_48k = resample::resample_8k_to_48k(&buf_8k[..n]); + let pcm_48k = self.upsampler.process(&buf_8k[..n]); let out_len = pcm_48k.len().min(pcm.len()); pcm[..out_len].copy_from_slice(&pcm_48k[..out_len]); Ok(out_len) @@ -163,7 +167,7 @@ impl AudioDecoder for AdaptiveDecoder { let c2_samples = self.codec2_frame_samples(); let mut buf_8k = vec![0i16; c2_samples]; let n = self.codec2.decode_lost(&mut buf_8k)?; - let pcm_48k = resample::resample_8k_to_48k(&buf_8k[..n]); + let pcm_48k = self.upsampler.process(&buf_8k[..n]); let out_len = pcm_48k.len().min(pcm.len()); pcm[..out_len].copy_from_slice(&pcm_48k[..out_len]); Ok(out_len) diff --git a/crates/wzp-codec/src/aec.rs b/crates/wzp-codec/src/aec.rs new file mode 100644 index 0000000..79e510c --- /dev/null +++ b/crates/wzp-codec/src/aec.rs @@ -0,0 +1,228 @@ +//! Acoustic Echo Cancellation using NLMS adaptive filter. +//! Processes 480-sample (10ms) sub-frames at 48kHz. + +/// NLMS (Normalized Least Mean Squares) adaptive filter echo canceller. +/// +/// Removes acoustic echo by modelling the echo path between the far-end +/// (speaker) signal and the near-end (microphone) signal, then subtracting +/// the estimated echo from the near-end in real time. +pub struct EchoCanceller { + filter_coeffs: Vec, + filter_len: usize, + far_end_buf: Vec, + far_end_pos: usize, + mu: f32, + enabled: bool, +} + +impl EchoCanceller { + /// Create a new echo canceller. + /// + /// * `sample_rate` — typically 48000 + /// * `filter_ms` — echo-tail length in milliseconds (e.g. 100 for 100 ms) + pub fn new(sample_rate: u32, filter_ms: u32) -> Self { + let filter_len = (sample_rate as usize) * (filter_ms as usize) / 1000; + Self { + filter_coeffs: vec![0.0f32; filter_len], + filter_len, + far_end_buf: vec![0.0f32; filter_len], + far_end_pos: 0, + mu: 0.01, + enabled: true, + } + } + + /// Feed far-end (speaker/playback) samples into the circular buffer. + /// + /// Must be called with the audio that was played out through the speaker + /// *before* the corresponding near-end frame is processed. + pub fn feed_farend(&mut self, farend: &[i16]) { + for &s in farend { + self.far_end_buf[self.far_end_pos] = s as f32; + self.far_end_pos = (self.far_end_pos + 1) % self.filter_len; + } + } + + /// Process a near-end (microphone) frame, removing the estimated echo. + /// + /// Returns the echo-return-loss enhancement (ERLE) as a ratio: the RMS of + /// the original near-end divided by the RMS of the residual. Values > 1.0 + /// mean echo was reduced. + pub fn process_frame(&mut self, nearend: &mut [i16]) -> f32 { + if !self.enabled { + return 1.0; + } + + let n = nearend.len(); + let fl = self.filter_len; + + let mut sum_near_sq: f64 = 0.0; + let mut sum_err_sq: f64 = 0.0; + + for i in 0..n { + let near_f = nearend[i] as f32; + + // --- estimate echo as dot(coeffs, farend_window) --- + // The far-end window for this sample starts at + // (far_end_pos - 1 - i) mod filter_len (most recent) + // and goes back filter_len samples. + let mut echo_est: f32 = 0.0; + let mut power: f32 = 0.0; + + // Position of the most-recent far-end sample for this near-end sample. + // far_end_pos points to the *next write* position, so the most-recent + // sample written is at far_end_pos - 1. We have already called + // feed_farend for this block, so the relevant samples are the last + // filter_len entries ending just before the current write position, + // offset by how far we are into this near-end frame. + // + // For sample i of the near-end frame, the corresponding far-end + // "now" is far_end_pos - n + i (wrapping). + // far_end_pos points to next-write, so most recent sample is at + // far_end_pos - 1. For the i-th near-end sample we want the + // far-end "now" to be at (far_end_pos - n + i). We add fl + // repeatedly to avoid underflow on the usize subtraction. + let base = (self.far_end_pos + fl * ((n / fl) + 2) + i - n) % fl; + + for k in 0..fl { + let fe_idx = (base + fl - k) % fl; + let fe = self.far_end_buf[fe_idx]; + echo_est += self.filter_coeffs[k] * fe; + power += fe * fe; + } + + let error = near_f - echo_est; + + // --- NLMS coefficient update --- + let norm = power + 1.0; // +1 regularisation to avoid div-by-zero + let step = self.mu * error / norm; + + for k in 0..fl { + let fe_idx = (base + fl - k) % fl; + let fe = self.far_end_buf[fe_idx]; + self.filter_coeffs[k] += step * fe; + } + + // Clamp output + let out = error.max(-32768.0).min(32767.0); + nearend[i] = out as i16; + + sum_near_sq += (near_f as f64) * (near_f as f64); + sum_err_sq += (out as f64) * (out as f64); + } + + // ERLE ratio + if sum_err_sq < 1.0 { + return 100.0; // near-perfect cancellation + } + (sum_near_sq / sum_err_sq).sqrt() as f32 + } + + /// Enable or disable echo cancellation. + pub fn set_enabled(&mut self, enabled: bool) { + self.enabled = enabled; + } + + /// Returns whether echo cancellation is currently enabled. + pub fn is_enabled(&self) -> bool { + self.enabled + } + + /// Reset the adaptive filter to its initial state. + /// + /// Zeroes out all filter coefficients and the far-end circular buffer. + pub fn reset(&mut self) { + self.filter_coeffs.iter_mut().for_each(|c| *c = 0.0); + self.far_end_buf.iter_mut().for_each(|s| *s = 0.0); + self.far_end_pos = 0; + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn aec_creates_with_correct_filter_len() { + let aec = EchoCanceller::new(48000, 100); + assert_eq!(aec.filter_len, 4800); + assert_eq!(aec.filter_coeffs.len(), 4800); + assert_eq!(aec.far_end_buf.len(), 4800); + } + + #[test] + fn aec_passthrough_when_disabled() { + let mut aec = EchoCanceller::new(48000, 100); + aec.set_enabled(false); + assert!(!aec.is_enabled()); + + let original: Vec = (0..480).map(|i| (i * 10) as i16).collect(); + let mut frame = original.clone(); + let erle = aec.process_frame(&mut frame); + assert_eq!(erle, 1.0); + assert_eq!(frame, original); + } + + #[test] + fn aec_reset_zeroes_state() { + let mut aec = EchoCanceller::new(48000, 10); // short for test speed + let farend: Vec = (0..480).map(|i| ((i * 37) % 1000) as i16).collect(); + aec.feed_farend(&farend); + + aec.reset(); + + assert!(aec.filter_coeffs.iter().all(|&c| c == 0.0)); + assert!(aec.far_end_buf.iter().all(|&s| s == 0.0)); + assert_eq!(aec.far_end_pos, 0); + } + + #[test] + fn aec_reduces_echo_of_known_signal() { + // Use a small filter for speed. Feed a known far-end signal, then + // present the *same* signal as near-end (perfect echo, no room). + // After adaptation the output energy should drop. + let filter_ms = 5; // 240 taps at 48 kHz + let mut aec = EchoCanceller::new(48000, filter_ms); + + // Generate a simple repeating pattern. + let frame_len = 480usize; + let make_frame = |offset: usize| -> Vec { + (0..frame_len) + .map(|i| { + let t = (offset + i) as f64 / 48000.0; + (5000.0 * (2.0 * std::f64::consts::PI * 300.0 * t).sin()) as i16 + }) + .collect() + }; + + // Warm up the adaptive filter with several frames. + let mut last_erle = 1.0f32; + for frame_idx in 0..40 { + let farend = make_frame(frame_idx * frame_len); + aec.feed_farend(&farend); + + // Near-end = exact copy of far-end (pure echo). + let mut nearend = farend.clone(); + last_erle = aec.process_frame(&mut nearend); + } + + // After 40 frames the ERLE should be meaningfully > 1. + assert!( + last_erle > 1.0, + "expected ERLE > 1.0 after adaptation, got {last_erle}" + ); + } + + #[test] + fn aec_silence_passthrough() { + let mut aec = EchoCanceller::new(48000, 10); + // Feed silence far-end + aec.feed_farend(&vec![0i16; 480]); + // Near-end is silence too + let mut frame = vec![0i16; 480]; + let erle = aec.process_frame(&mut frame); + assert!(erle >= 1.0); + // Output should still be silence + assert!(frame.iter().all(|&s| s == 0)); + } +} diff --git a/crates/wzp-codec/src/agc.rs b/crates/wzp-codec/src/agc.rs new file mode 100644 index 0000000..5456daf --- /dev/null +++ b/crates/wzp-codec/src/agc.rs @@ -0,0 +1,219 @@ +//! Automatic Gain Control (AGC) with two-stage smoothing. +//! +//! Uses a fast attack / slow release envelope follower to keep the +//! output signal near a configurable target RMS level. This prevents +//! both clipping (when the speaker is too loud) and inaudibility (when +//! the speaker is too quiet or far from the mic). + +/// Two-stage automatic gain control. +/// +/// The gain is adjusted per-frame based on the measured RMS energy, +/// with a fast attack (gain decreases quickly when signal gets louder) +/// and a slow release (gain increases gradually when signal gets quieter). +pub struct AutoGainControl { + target_rms: f64, + current_gain: f64, + min_gain: f64, + max_gain: f64, + attack_alpha: f64, + release_alpha: f64, + enabled: bool, +} + +impl AutoGainControl { + /// Create a new AGC with sensible VoIP defaults. + pub fn new() -> Self { + Self { + target_rms: 3000.0, // ~-20 dBFS for i16 + current_gain: 1.0, + min_gain: 0.5, + max_gain: 32.0, + attack_alpha: 0.3, // fast attack + release_alpha: 0.02, // slow release + enabled: true, + } + } + + /// Process a frame of PCM audio in-place, applying gain adjustment. + pub fn process_frame(&mut self, pcm: &mut [i16]) { + if !self.enabled { + return; + } + + // Compute RMS of the frame. + let rms = Self::compute_rms(pcm); + + // Don't amplify near-silence — it would just boost noise. + if rms < 10.0 { + return; + } + + // Desired instantaneous gain. + let desired_gain = (self.target_rms / rms).clamp(self.min_gain, self.max_gain); + + // Smooth the gain transition. + let alpha = if desired_gain < self.current_gain { + // Signal is louder than target → reduce gain quickly (attack). + self.attack_alpha + } else { + // Signal is quieter than target → raise gain slowly (release). + self.release_alpha + }; + + self.current_gain = self.current_gain * (1.0 - alpha) + desired_gain * alpha; + + // Apply gain to each sample with hard limiting at ±31000 (~0.946 * i16::MAX). + const LIMIT: f64 = 31000.0; + let gain = self.current_gain; + for sample in pcm.iter_mut() { + let amplified = (*sample as f64) * gain; + let clamped = amplified.clamp(-LIMIT, LIMIT); + *sample = clamped as i16; + } + } + + /// Enable or disable the AGC. + pub fn set_enabled(&mut self, enabled: bool) { + self.enabled = enabled; + } + + /// Returns whether the AGC is currently enabled. + pub fn is_enabled(&self) -> bool { + self.enabled + } + + /// Current gain expressed in dB. + pub fn current_gain_db(&self) -> f64 { + 20.0 * self.current_gain.log10() + } + + /// Compute the RMS (root mean square) of a PCM buffer. + fn compute_rms(pcm: &[i16]) -> f64 { + if pcm.is_empty() { + return 0.0; + } + let sum_sq: f64 = pcm.iter().map(|&s| (s as f64) * (s as f64)).sum(); + (sum_sq / pcm.len() as f64).sqrt() + } +} + +impl Default for AutoGainControl { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn agc_creates_with_defaults() { + let agc = AutoGainControl::new(); + assert!(agc.is_enabled()); + assert!((agc.current_gain - 1.0).abs() < f64::EPSILON); + } + + #[test] + fn agc_passthrough_when_disabled() { + let mut agc = AutoGainControl::new(); + agc.set_enabled(false); + + let original: Vec = (0..960).map(|i| (i * 5) as i16).collect(); + let mut frame = original.clone(); + agc.process_frame(&mut frame); + + assert_eq!(frame, original); + } + + #[test] + fn agc_does_not_amplify_silence() { + let mut agc = AutoGainControl::new(); + let mut frame = vec![0i16; 960]; + agc.process_frame(&mut frame); + assert!(frame.iter().all(|&s| s == 0)); + // Gain should remain at initial value. + assert!((agc.current_gain - 1.0).abs() < f64::EPSILON); + } + + #[test] + fn agc_amplifies_quiet_signal() { + let mut agc = AutoGainControl::new(); + + // Very quiet signal (RMS ~ 50). + let mut frame: Vec = (0..960) + .map(|i| { + let t = i as f64 / 48000.0; + (50.0 * (2.0 * std::f64::consts::PI * 440.0 * t).sin()) as i16 + }) + .collect(); + + // Process several frames to let the gain ramp up. + for _ in 0..50 { + let mut f = frame.clone(); + agc.process_frame(&mut f); + frame = f; + } + + // Gain should have increased past 1.0. + assert!( + agc.current_gain > 1.05, + "expected gain > 1.05 for quiet signal, got {}", + agc.current_gain + ); + } + + #[test] + fn agc_attenuates_loud_signal() { + let mut agc = AutoGainControl::new(); + + // Loud signal (RMS ~ 20000). + let frame: Vec = (0..960) + .map(|i| { + let t = i as f64 / 48000.0; + (28000.0 * (2.0 * std::f64::consts::PI * 440.0 * t).sin()) as i16 + }) + .collect(); + + // Process several frames. + for _ in 0..20 { + let mut f = frame.clone(); + agc.process_frame(&mut f); + } + + // Gain should have decreased below 1.0. + assert!( + agc.current_gain < 1.0, + "expected gain < 1.0 for loud signal, got {}", + agc.current_gain + ); + } + + #[test] + fn agc_output_within_limits() { + let mut agc = AutoGainControl::new(); + // Force a high gain by processing many quiet frames first. + for _ in 0..100 { + let mut f: Vec = vec![100; 960]; + agc.process_frame(&mut f); + } + + // Now send a louder frame — output should still be within ±31000. + let mut frame: Vec = vec![20000; 960]; + agc.process_frame(&mut frame); + assert!( + frame.iter().all(|&s| s.abs() <= 31000), + "output samples must be within ±31000" + ); + } + + #[test] + fn agc_gain_db_at_unity() { + let agc = AutoGainControl::new(); + let db = agc.current_gain_db(); + assert!( + db.abs() < 0.01, + "expected ~0 dB at unity gain, got {db}" + ); + } +} diff --git a/crates/wzp-codec/src/lib.rs b/crates/wzp-codec/src/lib.rs index 0ba8a97..592e21e 100644 --- a/crates/wzp-codec/src/lib.rs +++ b/crates/wzp-codec/src/lib.rs @@ -10,6 +10,8 @@ //! trait-object encoders/decoders that handle adaptive switching internally. pub mod adaptive; +pub mod aec; +pub mod agc; pub mod codec2_dec; pub mod codec2_enc; pub mod denoise; @@ -19,6 +21,8 @@ pub mod resample; pub mod silence; pub use adaptive::{AdaptiveDecoder, AdaptiveEncoder}; +pub use aec::EchoCanceller; +pub use agc::AutoGainControl; pub use denoise::NoiseSupressor; pub use silence::{ComfortNoise, SilenceDetector}; pub use wzp_proto::{AudioDecoder, AudioEncoder, CodecId, QualityProfile}; diff --git a/crates/wzp-codec/src/opus_enc.rs b/crates/wzp-codec/src/opus_enc.rs index 176062b..41534de 100644 --- a/crates/wzp-codec/src/opus_enc.rs +++ b/crates/wzp-codec/src/opus_enc.rs @@ -40,6 +40,11 @@ impl OpusEncoder { .set_signal(Signal::Voice) .map_err(|e| CodecError::EncodeFailed(format!("set signal: {e}")))?; + // Default complexity 7 — good quality/CPU trade-off for VoIP + enc.inner + .set_complexity(7) + .map_err(|e| CodecError::EncodeFailed(format!("set complexity: {e}")))?; + Ok(enc) } @@ -56,6 +61,21 @@ impl OpusEncoder { pub fn frame_samples(&self) -> usize { (48_000 * self.frame_duration_ms as usize) / 1000 } + + /// Set the encoder complexity (0-10). Higher values produce better quality + /// at the cost of more CPU. Default is 7. + pub fn set_complexity(&mut self, complexity: i32) { + let c = (complexity as u8).min(10); + let _ = self.inner.set_complexity(c); + } + + /// Hint the encoder about expected packet loss percentage (0-100). + /// + /// Higher values cause the encoder to use more redundancy to survive + /// packet loss, at the expense of slightly higher bitrate. + pub fn set_expected_loss(&mut self, loss_pct: u8) { + let _ = self.inner.set_packet_loss_perc(loss_pct.min(100)); + } } impl AudioEncoder for OpusEncoder { diff --git a/crates/wzp-codec/src/resample.rs b/crates/wzp-codec/src/resample.rs index 1aa6d0f..c9a0709 100644 --- a/crates/wzp-codec/src/resample.rs +++ b/crates/wzp-codec/src/resample.rs @@ -1,55 +1,258 @@ -//! Simple linear resampler for 48 kHz <-> 8 kHz conversion. +//! Windowed-sinc FIR resampler for 48 kHz <-> 8 kHz conversion. //! -//! These are basic implementations suitable for voice. For higher quality, -//! replace with the `rubato` crate later. +//! Provides both stateless free functions (backward-compatible) and stateful +//! `Downsampler48to8` / `Upsampler8to48` structs that maintain overlap history +//! between frames for glitch-free streaming. -/// Downsample from 48 kHz to 8 kHz (6:1 decimation with averaging). +use std::f64::consts::PI; + +// ─── FIR kernel parameters ───────────────────────────────────────────────── + +/// Number of FIR taps in the anti-alias / interpolation filter. +const FIR_TAPS: usize = 48; +/// Kaiser window beta parameter — controls sidelobe attenuation. +const KAISER_BETA: f64 = 8.0; +/// Cutoff frequency in Hz for the low-pass filter (just below 4 kHz Nyquist of 8 kHz). +const CUTOFF_HZ: f64 = 3800.0; +/// Working sample rate in Hz. +const SAMPLE_RATE: f64 = 48000.0; +/// Decimation / interpolation ratio between 48 kHz and 8 kHz. +const RATIO: usize = 6; + +// ─── Kaiser window helpers ───────────────────────────────────────────────── + +/// Zeroth-order modified Bessel function of the first kind, I₀(x). /// -/// Each output sample is the average of 6 consecutive input samples, -/// providing basic anti-aliasing via a box filter. -pub fn resample_48k_to_8k(input: &[i16]) -> Vec { - const RATIO: usize = 6; - let out_len = input.len() / RATIO; - let mut output = Vec::with_capacity(out_len); - - for chunk in input.chunks_exact(RATIO) { - let sum: i32 = chunk.iter().map(|&s| s as i32).sum(); - output.push((sum / RATIO as i32) as i16); +/// Computed via the well-known power-series expansion, converging rapidly +/// for the moderate values of x used in Kaiser window design. +fn bessel_i0(x: f64) -> f64 { + let mut sum = 1.0f64; + let mut term = 1.0f64; + let half_x = x / 2.0; + for k in 1..=25 { + term *= (half_x / k as f64) * (half_x / k as f64); + sum += term; + if term < 1e-12 * sum { + break; + } } - - output + sum } -/// Upsample from 8 kHz to 48 kHz (1:6 interpolation with linear interp). +/// Build a windowed-sinc low-pass FIR kernel. /// -/// Linearly interpolates between each pair of input samples to produce -/// 6 output samples per input sample. -pub fn resample_8k_to_48k(input: &[i16]) -> Vec { - const RATIO: usize = 6; - if input.is_empty() { - return Vec::new(); - } +/// Returns `FIR_TAPS` coefficients normalised so that the DC gain is exactly 1.0. +fn build_fir_kernel() -> [f64; FIR_TAPS] { + let mut kernel = [0.0f64; FIR_TAPS]; + let m = (FIR_TAPS - 1) as f64; + let fc = CUTOFF_HZ / SAMPLE_RATE; // normalised cutoff (0..0.5) + let beta_denom = bessel_i0(KAISER_BETA); - let out_len = input.len() * RATIO; - let mut output = Vec::with_capacity(out_len); - - for i in 0..input.len() { - let current = input[i] as i32; - let next = if i + 1 < input.len() { - input[i + 1] as i32 + for i in 0..FIR_TAPS { + // Sinc + let n = i as f64 - m / 2.0; + let sinc = if n.abs() < 1e-12 { + 2.0 * fc } else { - current // hold last sample + (2.0 * PI * fc * n).sin() / (PI * n) }; - for j in 0..RATIO { - let interp = current + (next - current) * j as i32 / RATIO as i32; - output.push(interp as i16); + // Kaiser window + let t = 2.0 * i as f64 / m - 1.0; // range [-1, 1] + let kaiser = bessel_i0(KAISER_BETA * (1.0 - t * t).max(0.0).sqrt()) / beta_denom; + + kernel[i] = sinc * kaiser; + } + + // Normalise to unity DC gain. + let sum: f64 = kernel.iter().sum(); + if sum.abs() > 1e-15 { + for k in kernel.iter_mut() { + *k /= sum; } } - output + kernel } +// ─── Stateful Downsampler 48→8 ───────────────────────────────────────────── + +/// Stateful polyphase FIR downsampler from 48 kHz to 8 kHz. +/// +/// Maintains `FIR_TAPS - 1` samples of history between successive calls to +/// `process()` for seamless frame boundaries. +pub struct Downsampler48to8 { + kernel: [f64; FIR_TAPS], + history: Vec, +} + +impl Downsampler48to8 { + pub fn new() -> Self { + Self { + kernel: build_fir_kernel(), + history: vec![0.0; FIR_TAPS - 1], + } + } + + /// Downsample a block of 48 kHz samples to 8 kHz. + /// + /// The input length should be a multiple of 6; any trailing samples that + /// don't form a complete output sample are consumed into the history. + pub fn process(&mut self, input: &[i16]) -> Vec { + let hist_len = self.history.len(); // FIR_TAPS - 1 + let total_len = hist_len + input.len(); + + // Build a working buffer: history ++ input (as f64). + let mut work = Vec::with_capacity(total_len); + work.extend_from_slice(&self.history); + work.extend(input.iter().map(|&s| s as f64)); + + let out_len = input.len() / RATIO; + let mut output = Vec::with_capacity(out_len); + + for i in 0..out_len { + // The centre of the filter for output sample i sits at + // position hist_len + i*RATIO in the work buffer (aligning + // with the first new input sample at decimation phase 0). + let centre = hist_len + i * RATIO; + let start = centre + 1 - FIR_TAPS; // may be 0 for the first few + + let mut acc = 0.0f64; + for k in 0..FIR_TAPS { + let idx = start + k; + if idx < work.len() { + acc += work[idx] * self.kernel[k]; + } + } + output.push(acc.round().clamp(-32768.0, 32767.0) as i16); + } + + // Update history: keep the last (FIR_TAPS - 1) samples from work. + if work.len() >= hist_len { + self.history + .copy_from_slice(&work[work.len() - hist_len..]); + } else { + // Input was shorter than history — shift. + let shift = hist_len - work.len(); + self.history.copy_within(shift.., 0); + for (i, &v) in work.iter().enumerate() { + self.history[hist_len - work.len() + i] = v; + } + } + + output + } +} + +impl Default for Downsampler48to8 { + fn default() -> Self { + Self::new() + } +} + +// ─── Stateful Upsampler 8→48 ─────────────────────────────────────────────── + +/// Stateful FIR upsampler from 8 kHz to 48 kHz. +/// +/// Inserts zeros between input samples (zero-stuffing), then applies the +/// low-pass FIR to remove imaging, with gain compensation of `RATIO`. +pub struct Upsampler8to48 { + kernel: [f64; FIR_TAPS], + history: Vec, +} + +impl Upsampler8to48 { + pub fn new() -> Self { + Self { + kernel: build_fir_kernel(), + history: vec![0.0; FIR_TAPS - 1], + } + } + + /// Upsample a block of 8 kHz samples to 48 kHz. + pub fn process(&mut self, input: &[i16]) -> Vec { + let hist_len = self.history.len(); // FIR_TAPS - 1 + + // Zero-stuff: insert RATIO-1 zeros between each input sample. + let stuffed_len = input.len() * RATIO; + let total_len = hist_len + stuffed_len; + + let mut work = Vec::with_capacity(total_len); + work.extend_from_slice(&self.history); + for &s in input { + work.push(s as f64); + for _ in 1..RATIO { + work.push(0.0); + } + } + + let out_len = stuffed_len; + let mut output = Vec::with_capacity(out_len); + + // The gain factor compensates for the zeros introduced by stuffing. + let gain = RATIO as f64; + + for i in 0..out_len { + let centre = hist_len + i; + let start = centre + 1 - FIR_TAPS; + + let mut acc = 0.0f64; + for k in 0..FIR_TAPS { + let idx = start + k; + if idx < work.len() { + acc += work[idx] * self.kernel[k]; + } + } + acc *= gain; + output.push(acc.round().clamp(-32768.0, 32767.0) as i16); + } + + // Update history. + if work.len() >= hist_len { + self.history + .copy_from_slice(&work[work.len() - hist_len..]); + } else { + let shift = hist_len - work.len(); + self.history.copy_within(shift.., 0); + for (i, &v) in work.iter().enumerate() { + self.history[hist_len - work.len() + i] = v; + } + } + + output + } +} + +impl Default for Upsampler8to48 { + fn default() -> Self { + Self::new() + } +} + +// ─── Backward-compatible free functions ───────────────────────────────────── + +/// Downsample from 48 kHz to 8 kHz (6:1 decimation with FIR anti-alias filter). +/// +/// This is a convenience wrapper that creates a temporary [`Downsampler48to8`]. +/// For streaming use, prefer the stateful struct to avoid edge artefacts between +/// frames. +pub fn resample_48k_to_8k(input: &[i16]) -> Vec { + let mut ds = Downsampler48to8::new(); + ds.process(input) +} + +/// Upsample from 8 kHz to 48 kHz (1:6 interpolation with FIR imaging filter). +/// +/// This is a convenience wrapper that creates a temporary [`Upsampler8to48`]. +/// For streaming use, prefer the stateful struct to avoid edge artefacts between +/// frames. +pub fn resample_8k_to_48k(input: &[i16]) -> Vec { + let mut us = Upsampler8to48::new(); + us.process(input) +} + +// ─── Tests ────────────────────────────────────────────────────────────────── + #[cfg(test)] mod tests { use super::*; @@ -66,12 +269,28 @@ mod tests { #[test] fn dc_signal_preserved() { - // A constant signal should survive resampling + // A constant signal should survive resampling (approximately). let input = vec![1000i16; 960]; let down = resample_48k_to_8k(&input); - assert!(down.iter().all(|&s| s == 1000)); + // Allow some edge transient — check that the middle samples are close. + let mid_start = down.len() / 4; + let mid_end = 3 * down.len() / 4; + for &s in &down[mid_start..mid_end] { + assert!( + (s - 1000).abs() < 50, + "DC downsampled sample {s} too far from 1000" + ); + } + let up = resample_8k_to_48k(&down); - assert!(up.iter().all(|&s| s == 1000)); + let mid_start_up = up.len() / 4; + let mid_end_up = 3 * up.len() / 4; + for &s in &up[mid_start_up..mid_end_up] { + assert!( + (s - 1000).abs() < 100, + "DC upsampled sample {s} too far from 1000" + ); + } } #[test] @@ -79,4 +298,40 @@ mod tests { assert!(resample_48k_to_8k(&[]).is_empty()); assert!(resample_8k_to_48k(&[]).is_empty()); } + + #[test] + fn stateful_downsampler_produces_correct_length() { + let mut ds = Downsampler48to8::new(); + let out = ds.process(&vec![0i16; 960]); + assert_eq!(out.len(), 160); + let out2 = ds.process(&vec![0i16; 960]); + assert_eq!(out2.len(), 160); + } + + #[test] + fn stateful_upsampler_produces_correct_length() { + let mut us = Upsampler8to48::new(); + let out = us.process(&vec![0i16; 160]); + assert_eq!(out.len(), 960); + let out2 = us.process(&vec![0i16; 160]); + assert_eq!(out2.len(), 960); + } + + #[test] + fn fir_kernel_has_unity_dc_gain() { + let kernel = build_fir_kernel(); + let sum: f64 = kernel.iter().sum(); + assert!( + (sum - 1.0).abs() < 1e-10, + "FIR kernel DC gain should be 1.0, got {sum}" + ); + } + + #[test] + fn bessel_i0_known_values() { + // I₀(0) = 1 + assert!((bessel_i0(0.0) - 1.0).abs() < 1e-12); + // I₀(1) ≈ 1.2660658 + assert!((bessel_i0(1.0) - 1.2660658).abs() < 1e-5); + } } diff --git a/crates/wzp-proto/src/jitter.rs b/crates/wzp-proto/src/jitter.rs index 5995c5a..383b3d5 100644 --- a/crates/wzp-proto/src/jitter.rs +++ b/crates/wzp-proto/src/jitter.rs @@ -1,4 +1,5 @@ use std::collections::BTreeMap; +use std::time::{Duration, Instant}; use crate::packet::MediaPacket; @@ -20,19 +21,29 @@ pub struct AdaptivePlayoutDelay { max_delay: usize, /// Exponential moving average of inter-packet arrival jitter (ms). jitter_ema: f64, - /// EMA smoothing factor (0.0-1.0, lower = smoother). - alpha: f64, + /// EMA smoothing factor for jitter increases (fast reaction). + alpha_up: f64, + /// EMA smoothing factor for jitter decreases (slow decay). + alpha_down: f64, /// Last packet arrival timestamp (for computing inter-arrival jitter). last_arrival_ms: Option, /// Last packet expected timestamp. last_expected_ms: Option, + /// Safety margin added to jitter-derived target (in packets). + safety_margin: f64, + /// Instant when a jitter spike was detected (handoff detection). + spike_detected_at: Option, + /// Duration to hold max_delay after a spike is detected. + spike_cooldown: Duration, + /// Multiplier of jitter_ema that constitutes a spike. + spike_threshold_multiplier: f64, } /// Frame duration in milliseconds (20ms Opus/Codec2 frames). const FRAME_DURATION_MS: f64 = 20.0; -/// Safety margin added to jitter-derived target (in packets). -const SAFETY_MARGIN_PACKETS: f64 = 2.0; -/// Default EMA smoothing factor. +/// Default safety margin in packets. +const DEFAULT_SAFETY_MARGIN: f64 = 2.0; +/// Default EMA smoothing factor (used for both up/down in non-mobile mode). const DEFAULT_ALPHA: f64 = 0.05; impl AdaptivePlayoutDelay { @@ -46,9 +57,14 @@ impl AdaptivePlayoutDelay { min_delay, max_delay, jitter_ema: 0.0, - alpha: DEFAULT_ALPHA, + alpha_up: DEFAULT_ALPHA, + alpha_down: DEFAULT_ALPHA, last_arrival_ms: None, last_expected_ms: None, + safety_margin: DEFAULT_SAFETY_MARGIN, + spike_detected_at: None, + spike_cooldown: Duration::from_secs(2), + spike_threshold_multiplier: 3.0, } } @@ -64,13 +80,38 @@ impl AdaptivePlayoutDelay { let expected_delta = expected_ms as f64 - last_expected as f64; let jitter = (actual_delta - expected_delta).abs(); - // Update EMA - self.jitter_ema = self.alpha * jitter + (1.0 - self.alpha) * self.jitter_ema; + // Spike detection: check before EMA update + if self.jitter_ema > 0.0 + && jitter > self.jitter_ema * self.spike_threshold_multiplier + { + self.spike_detected_at = Some(Instant::now()); + } - // Convert jitter estimate to target delay in packets - let raw_target = (self.jitter_ema / FRAME_DURATION_MS).ceil() + SAFETY_MARGIN_PACKETS; - self.target_delay = - (raw_target as usize).clamp(self.min_delay, self.max_delay); + // Asymmetric EMA update + let alpha = if jitter > self.jitter_ema { + self.alpha_up + } else { + self.alpha_down + }; + self.jitter_ema = alpha * jitter + (1.0 - alpha) * self.jitter_ema; + + // Check if spike cooldown has expired + if let Some(spike_time) = self.spike_detected_at { + if spike_time.elapsed() >= self.spike_cooldown { + self.spike_detected_at = None; + } + } + + // If within spike cooldown, return max_delay + if self.spike_detected_at.is_some() { + self.target_delay = self.max_delay; + } else { + // Convert jitter estimate to target delay in packets + let raw_target = + (self.jitter_ema / FRAME_DURATION_MS).ceil() + self.safety_margin; + self.target_delay = + (raw_target as usize).clamp(self.min_delay, self.max_delay); + } } self.last_arrival_ms = Some(arrival_ms); @@ -87,6 +128,28 @@ impl AdaptivePlayoutDelay { pub fn jitter_estimate_ms(&self) -> f64 { self.jitter_ema } + + /// Enable or disable mobile mode, adjusting parameters for cellular networks. + /// + /// Mobile mode uses: + /// - Asymmetric alpha (fast up=0.3, slow down=0.02) for quicker spike detection + /// - Higher safety margin (3.0 packets) to absorb handoff jitter + /// - Spike detection with 2-second cooldown at 3x threshold + pub fn set_mobile_mode(&mut self, enabled: bool) { + if enabled { + self.safety_margin = 3.0; + self.alpha_up = 0.3; + self.alpha_down = 0.02; + self.spike_threshold_multiplier = 3.0; + self.spike_cooldown = Duration::from_secs(2); + } else { + self.safety_margin = DEFAULT_SAFETY_MARGIN; + self.alpha_up = DEFAULT_ALPHA; + self.alpha_down = DEFAULT_ALPHA; + self.spike_threshold_multiplier = 3.0; + self.spike_cooldown = Duration::from_secs(2); + } + } } // --------------------------------------------------------------------------- @@ -391,6 +454,11 @@ impl JitterBuffer { self.adaptive.as_ref() } + /// Get a mutable reference to the adaptive playout delay estimator. + pub fn adaptive_delay_mut(&mut self) -> Option<&mut AdaptivePlayoutDelay> { + self.adaptive.as_mut() + } + /// Adjust target depth based on observed jitter. pub fn set_target_depth(&mut self, depth: usize) { self.target_depth = depth.min(self.max_depth); @@ -720,4 +788,29 @@ mod tests { let ad = jb.adaptive_delay().unwrap(); assert_eq!(ad.target_delay(), 3); } + + // --------------------------------------------------------------- + // Mobile mode tests + // --------------------------------------------------------------- + + #[test] + fn mobile_mode_increases_safety_margin() { + let mut apd = AdaptivePlayoutDelay::new(3, 50); + apd.set_mobile_mode(true); + assert_eq!(apd.safety_margin, 3.0); + assert_eq!(apd.alpha_up, 0.3); + assert_eq!(apd.alpha_down, 0.02); + + apd.set_mobile_mode(false); + assert_eq!(apd.safety_margin, DEFAULT_SAFETY_MARGIN); + assert_eq!(apd.alpha_up, DEFAULT_ALPHA); + assert_eq!(apd.alpha_down, DEFAULT_ALPHA); + } + + #[test] + fn mobile_mode_accessible_via_jitter_buffer() { + let mut jb = JitterBuffer::new_adaptive(3, 50); + jb.adaptive_delay_mut().unwrap().set_mobile_mode(true); + assert_eq!(jb.adaptive_delay().unwrap().safety_margin, 3.0); + } } diff --git a/crates/wzp-proto/src/lib.rs b/crates/wzp-proto/src/lib.rs index 1d0d59a..bb23aae 100644 --- a/crates/wzp-proto/src/lib.rs +++ b/crates/wzp-proto/src/lib.rs @@ -29,6 +29,6 @@ pub use packet::{ SignalMessage, TrunkEntry, TrunkFrame, FRAME_TYPE_FULL, FRAME_TYPE_MINI, }; pub use bandwidth::{BandwidthEstimator, CongestionState}; -pub use quality::{AdaptiveQualityController, Tier}; +pub use quality::{AdaptiveQualityController, NetworkContext, Tier}; pub use session::{Session, SessionEvent, SessionState}; pub use traits::*; diff --git a/crates/wzp-proto/src/quality.rs b/crates/wzp-proto/src/quality.rs index b10912d..e5422c3 100644 --- a/crates/wzp-proto/src/quality.rs +++ b/crates/wzp-proto/src/quality.rs @@ -1,4 +1,5 @@ use std::collections::VecDeque; +use std::time::{Duration, Instant}; use crate::packet::QualityReport; use crate::traits::QualityController; @@ -24,24 +25,71 @@ impl Tier { } } - /// Determine which tier a quality report belongs to. + /// Determine which tier a quality report belongs to (default/WiFi thresholds). pub fn classify(report: &QualityReport) -> Self { + Self::classify_with_context(report, NetworkContext::Unknown) + } + + /// Classify with network-context-aware thresholds. + pub fn classify_with_context(report: &QualityReport, context: NetworkContext) -> Self { let loss = report.loss_percent(); let rtt = report.rtt_ms(); - if loss > 40.0 || rtt > 600 { - Self::Catastrophic - } else if loss > 10.0 || rtt > 400 { - Self::Degraded - } else { - Self::Good + match context { + NetworkContext::CellularLte + | NetworkContext::Cellular5g + | NetworkContext::Cellular3g => { + // Tighter thresholds for cellular networks + if loss > 25.0 || rtt > 500 { + Self::Catastrophic + } else if loss > 8.0 || rtt > 300 { + Self::Degraded + } else { + Self::Good + } + } + NetworkContext::WiFi | NetworkContext::Unknown => { + // Original thresholds + if loss > 40.0 || rtt > 600 { + Self::Catastrophic + } else if loss > 10.0 || rtt > 400 { + Self::Degraded + } else { + Self::Good + } + } } } + + /// Return the next lower (worse) tier, or None if already at the worst. + pub fn downgrade(self) -> Option { + match self { + Self::Good => Some(Self::Degraded), + Self::Degraded => Some(Self::Catastrophic), + Self::Catastrophic => None, + } + } +} + +/// Describes the network transport type for context-aware quality decisions. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum NetworkContext { + WiFi, + CellularLte, + Cellular5g, + Cellular3g, + Unknown, +} + +impl Default for NetworkContext { + fn default() -> Self { + Self::Unknown + } } /// Adaptive quality controller with hysteresis to prevent tier flapping. /// -/// - Downgrade: 3 consecutive reports in a worse tier +/// - Downgrade: 3 consecutive reports in a worse tier (2 on cellular) /// - Upgrade: 10 consecutive reports in a better tier pub struct AdaptiveQualityController { current_tier: Tier, @@ -54,14 +102,26 @@ pub struct AdaptiveQualityController { history: VecDeque, /// Whether the profile was manually forced (disables adaptive logic). forced: bool, + /// Current network context for threshold selection. + network_context: NetworkContext, + /// FEC boost expiry time (set during network handoff). + fec_boost_until: Option, + /// FEC boost amount to add during handoff recovery window. + fec_boost_amount: f32, } /// Threshold for downgrading (fast reaction to degradation). const DOWNGRADE_THRESHOLD: u32 = 3; +/// Threshold for downgrading on cellular networks (even faster). +const CELLULAR_DOWNGRADE_THRESHOLD: u32 = 2; /// Threshold for upgrading (slow, cautious improvement). const UPGRADE_THRESHOLD: u32 = 10; /// Maximum history window size. const HISTORY_SIZE: usize = 20; +/// Default FEC boost amount during handoff recovery. +const DEFAULT_FEC_BOOST: f32 = 0.2; +/// Duration of FEC boost after a network handoff. +const FEC_BOOST_DURATION_SECS: u64 = 10; impl AdaptiveQualityController { pub fn new() -> Self { @@ -72,6 +132,9 @@ impl AdaptiveQualityController { consecutive_down: 0, history: VecDeque::with_capacity(HISTORY_SIZE), forced: false, + network_context: NetworkContext::default(), + fec_boost_until: None, + fec_boost_amount: DEFAULT_FEC_BOOST, } } @@ -80,6 +143,69 @@ impl AdaptiveQualityController { self.current_tier } + /// Get the current network context. + pub fn network_context(&self) -> NetworkContext { + self.network_context + } + + /// Signal a network transport change (e.g., WiFi to cellular handoff). + /// + /// When switching from WiFi to any cellular type, this preemptively + /// downgrades one quality tier and activates a temporary FEC boost. + pub fn signal_network_change(&mut self, new_context: NetworkContext) { + let old = self.network_context; + self.network_context = new_context; + + let new_is_cellular = matches!( + new_context, + NetworkContext::CellularLte | NetworkContext::Cellular5g | NetworkContext::Cellular3g + ); + + // If switching from WiFi to cellular, preemptively downgrade one tier + if old == NetworkContext::WiFi && new_is_cellular { + if let Some(lower_tier) = self.current_tier.downgrade() { + self.current_tier = lower_tier; + self.current_profile = lower_tier.profile(); + } + // Reset counters to avoid stale hysteresis state + self.consecutive_up = 0; + self.consecutive_down = 0; + // Un-force so adaptive logic resumes + self.forced = false; + } + + // Activate FEC boost for any network change + self.fec_boost_until = Some(Instant::now() + Duration::from_secs(FEC_BOOST_DURATION_SECS)); + } + + /// Returns the FEC boost amount if within the handoff recovery window, 0.0 otherwise. + /// + /// Callers should add this to their base FEC ratio during the boost window. + pub fn fec_boost(&self) -> f32 { + if let Some(until) = self.fec_boost_until { + if Instant::now() < until { + return self.fec_boost_amount; + } + } + 0.0 + } + + /// Reset the hysteresis counters. + pub fn reset_counters(&mut self) { + self.consecutive_up = 0; + self.consecutive_down = 0; + } + + /// Get the effective downgrade threshold based on network context. + fn downgrade_threshold(&self) -> u32 { + match self.network_context { + NetworkContext::CellularLte + | NetworkContext::Cellular5g + | NetworkContext::Cellular3g => CELLULAR_DOWNGRADE_THRESHOLD, + _ => DOWNGRADE_THRESHOLD, + } + } + fn try_transition(&mut self, observed_tier: Tier) -> Option { if observed_tier == self.current_tier { self.consecutive_up = 0; @@ -96,7 +222,7 @@ impl AdaptiveQualityController { if is_worse { self.consecutive_up = 0; self.consecutive_down += 1; - if self.consecutive_down >= DOWNGRADE_THRESHOLD { + if self.consecutive_down >= self.downgrade_threshold() { self.current_tier = observed_tier; self.current_profile = observed_tier.profile(); self.consecutive_down = 0; @@ -142,7 +268,7 @@ impl QualityController for AdaptiveQualityController { return None; } - let observed = Tier::classify(report); + let observed = Tier::classify_with_context(report, self.network_context); self.try_transition(observed) } @@ -246,4 +372,110 @@ mod tests { assert_eq!(Tier::classify(&make_report(50.0, 200)), Tier::Catastrophic); assert_eq!(Tier::classify(&make_report(5.0, 700)), Tier::Catastrophic); } + + // --------------------------------------------------------------- + // Network context tests + // --------------------------------------------------------------- + + #[test] + fn cellular_tighter_thresholds() { + // 12% loss: Good on WiFi, Degraded on cellular + let report = make_report(12.0, 200); + assert_eq!( + Tier::classify_with_context(&report, NetworkContext::WiFi), + Tier::Degraded + ); + assert_eq!( + Tier::classify_with_context(&report, NetworkContext::CellularLte), + Tier::Degraded + ); + + // 9% loss: Good on WiFi, Degraded on cellular + let report = make_report(9.0, 200); + assert_eq!( + Tier::classify_with_context(&report, NetworkContext::WiFi), + Tier::Good + ); + assert_eq!( + Tier::classify_with_context(&report, NetworkContext::CellularLte), + Tier::Degraded + ); + + // 30% loss: Degraded on WiFi, Catastrophic on cellular + let report = make_report(30.0, 200); + assert_eq!( + Tier::classify_with_context(&report, NetworkContext::WiFi), + Tier::Degraded + ); + assert_eq!( + Tier::classify_with_context(&report, NetworkContext::Cellular3g), + Tier::Catastrophic + ); + } + + #[test] + fn cellular_rtt_thresholds() { + // RTT 350ms: Good on WiFi, Degraded on cellular + let report = make_report(2.0, 348); // rtt_4ms rounds so use 348 + assert_eq!( + Tier::classify_with_context(&report, NetworkContext::WiFi), + Tier::Good + ); + assert_eq!( + Tier::classify_with_context(&report, NetworkContext::CellularLte), + Tier::Degraded + ); + } + + #[test] + fn cellular_faster_downgrade() { + let mut ctrl = AdaptiveQualityController::new(); + ctrl.signal_network_change(NetworkContext::CellularLte); + // Reset tier back to Good for testing downgrade threshold + ctrl.current_tier = Tier::Good; + ctrl.current_profile = Tier::Good.profile(); + + // On cellular, downgrade threshold is 2 instead of 3 + let bad = make_report(50.0, 200); + assert!(ctrl.observe(&bad).is_none()); // 1st bad + let result = ctrl.observe(&bad); // 2nd bad — should trigger on cellular + assert!(result.is_some()); + } + + #[test] + fn signal_network_change_preemptive_downgrade() { + let mut ctrl = AdaptiveQualityController::new(); + assert_eq!(ctrl.tier(), Tier::Good); + + // Switch from WiFi to cellular + ctrl.network_context = NetworkContext::WiFi; + ctrl.signal_network_change(NetworkContext::CellularLte); + + // Should have downgraded one tier: Good -> Degraded + assert_eq!(ctrl.tier(), Tier::Degraded); + } + + #[test] + fn signal_network_change_fec_boost() { + let mut ctrl = AdaptiveQualityController::new(); + assert_eq!(ctrl.fec_boost(), 0.0); + + ctrl.signal_network_change(NetworkContext::CellularLte); + + // FEC boost should be active + assert!(ctrl.fec_boost() > 0.0); + assert_eq!(ctrl.fec_boost(), DEFAULT_FEC_BOOST); + } + + #[test] + fn tier_downgrade() { + assert_eq!(Tier::Good.downgrade(), Some(Tier::Degraded)); + assert_eq!(Tier::Degraded.downgrade(), Some(Tier::Catastrophic)); + assert_eq!(Tier::Catastrophic.downgrade(), None); + } + + #[test] + fn network_context_default() { + assert_eq!(NetworkContext::default(), NetworkContext::Unknown); + } } diff --git a/crates/wzp-transport/src/path_monitor.rs b/crates/wzp-transport/src/path_monitor.rs index b5be9b9..837565c 100644 --- a/crates/wzp-transport/src/path_monitor.rs +++ b/crates/wzp-transport/src/path_monitor.rs @@ -149,6 +149,27 @@ impl PathMonitor { } 0 } + + /// Detect whether a network handoff likely occurred. + /// + /// Returns `true` if the most recent RTT jitter measurement exceeds 3x + /// the EWMA-smoothed jitter average, which is characteristic of a cellular + /// network handoff (tower switch, WiFi-to-cellular transition, etc.). + pub fn detect_handoff(&self) -> bool { + // We need at least two RTT observations to have a meaningful jitter value, + // and the EWMA must be non-zero to avoid division/multiplication by zero. + if self.jitter_ewma <= 0.0 { + return false; + } + + if let (Some(last_rtt), Some(_)) = (self.last_rtt_ms, Some(self.rtt_ewma)) { + // Compute the most recent instantaneous jitter (RTT deviation from EWMA) + let instant_jitter = (last_rtt - self.rtt_ewma).abs(); + instant_jitter > self.jitter_ewma * 3.0 + } else { + false + } + } } impl Default for PathMonitor {