Compare commits
100 Commits
237adbbf21
...
debug/code
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
073756ed4b | ||
|
|
2fcc2d77cf | ||
|
|
f7ccb67b02 | ||
|
|
4df08eadbd | ||
|
|
6d776097c8 | ||
|
|
9f7962a6cd | ||
|
|
8c9befb15d | ||
|
|
3f869a4cd7 | ||
|
|
2263e898e5 | ||
|
|
9ab57ba037 | ||
|
|
7806d4ec04 | ||
|
|
d31b81a21d | ||
|
|
4d54b6f9e4 | ||
|
|
c268ce419a | ||
|
|
61b6e67610 | ||
|
|
dddf5d2e2d | ||
|
|
ed272d29f8 | ||
|
|
2b3bdae440 | ||
|
|
21f5b24cbf | ||
|
|
9b733010ab | ||
|
|
80d5bd7628 | ||
|
|
4a195a923a | ||
|
|
f726f8cfa4 | ||
|
|
20922455bd | ||
|
|
e468454464 | ||
|
|
d1c96cd71f | ||
|
|
1b00b5e2a4 | ||
|
|
e6564bab57 | ||
|
|
cfb48df1ef | ||
|
|
aebf9156c0 | ||
|
|
9bbaec6b35 | ||
|
|
ba29d8354f | ||
|
|
0908507a7a | ||
|
|
860c90394d | ||
|
|
dc66b60d18 | ||
|
|
a9c4260b4e | ||
|
|
7eb136fcb3 | ||
|
|
550a124972 | ||
|
|
0835c36d0f | ||
|
|
6228ab32c1 | ||
|
|
bd258f432a | ||
|
|
8bf073aa80 | ||
|
|
72e834b45e | ||
|
|
673ffd498c | ||
|
|
2d4b8eebd5 | ||
|
|
a23d9f5e41 | ||
|
|
b3e56ecbd8 | ||
|
|
2fa07286c3 | ||
|
|
bf91cf25bd | ||
|
|
81c756c076 | ||
|
|
af85a49e86 | ||
|
|
bae03365da | ||
|
|
9d9ce4706d | ||
|
|
9098e28a1f | ||
|
|
f6d51fce61 | ||
|
|
a8dd0c2f57 | ||
|
|
64566e9acb | ||
|
|
10eb19cd24 | ||
|
|
778f4dd428 | ||
|
|
622fdee51f | ||
|
|
b204213a01 | ||
|
|
e751af7e38 | ||
|
|
8d5f6fe044 | ||
|
|
780309fede | ||
|
|
73ebcdd869 | ||
|
|
e7b1c3372a | ||
|
|
26e9c55f1f | ||
|
|
aa09275015 | ||
|
|
59bf3f6587 | ||
|
|
4fb15fe7a3 | ||
|
|
e595fe6591 | ||
|
|
326aa491cc | ||
|
|
464e95a4bd | ||
|
|
fd95167705 | ||
|
|
9e7fea7633 | ||
|
|
993cf9ab7f | ||
|
|
6f4e8eb9f6 | ||
|
|
634cd40fdc | ||
|
|
6310864b0b | ||
|
|
4d2c9838c5 | ||
|
|
ab8a7f7a96 | ||
|
|
59268f0391 | ||
|
|
a833694568 | ||
|
|
6d5ee55393 | ||
|
|
0dc381e948 | ||
|
|
34cd1017c1 | ||
|
|
a64b79d953 | ||
|
|
216ebf4a25 | ||
|
|
39f6908478 | ||
|
|
3f813cd510 | ||
|
|
59a00d371b | ||
|
|
524d1145bb | ||
|
|
bf56d84ef0 | ||
|
|
59069bfba2 | ||
|
|
26dc848081 | ||
|
|
ad16ddb903 | ||
|
|
d870c9e08a | ||
|
|
616505e8a9 | ||
|
|
12cdfe6c8a | ||
|
|
97402f6e60 |
5
.cargo/config.toml
Normal file
5
.cargo/config.toml
Normal file
@@ -0,0 +1,5 @@
|
||||
[target.aarch64-linux-android]
|
||||
linker = "aarch64-linux-android26-clang"
|
||||
|
||||
[target.armv7-linux-androideabi]
|
||||
linker = "armv7a-linux-androideabi26-clang"
|
||||
@@ -2,187 +2,57 @@ name: Build Release Binaries
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- 'feat/*'
|
||||
tags:
|
||||
- 'v*'
|
||||
paths-ignore:
|
||||
- '.gitea/**'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
targets:
|
||||
description: 'Targets to build (comma-separated: amd64,arm64,armv7,mac-arm64)'
|
||||
required: false
|
||||
default: 'amd64'
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
# Always builds on push tags. On manual dispatch, reads inputs.
|
||||
build-amd64:
|
||||
if: >-
|
||||
github.event_name == 'push' ||
|
||||
contains(github.event.inputs.targets, 'amd64')
|
||||
runs-on: ubuntu-latest
|
||||
container:
|
||||
image: rust:1-bookworm
|
||||
image: catthehacker/ubuntu:act-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Install dependencies
|
||||
run: apt-get update && apt-get install -y cmake pkg-config libasound2-dev
|
||||
|
||||
- name: Cache cargo
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/.cargo/registry
|
||||
~/.cargo/git
|
||||
target
|
||||
key: cargo-amd64-${{ hashFiles('Cargo.lock') }}
|
||||
restore-keys: cargo-amd64-
|
||||
|
||||
- name: Build headless binaries
|
||||
run: cargo build --release --bin wzp-relay --bin wzp-client --bin wzp-bench --bin wzp-web
|
||||
|
||||
- name: Build audio client
|
||||
- name: Init submodules
|
||||
run: |
|
||||
cargo build --release --bin wzp-client --features audio
|
||||
cp target/release/wzp-client target/release/wzp-client-audio
|
||||
cargo build --release --bin wzp-client
|
||||
git config --global url."https://git.manko.yoga/".insteadOf "ssh://git@git.manko.yoga:222/"
|
||||
git submodule update --init --recursive
|
||||
|
||||
- name: Install Rust + dependencies
|
||||
run: |
|
||||
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||
source "$HOME/.cargo/env"
|
||||
apt-get update && apt-get install -y cmake pkg-config libasound2-dev ninja-build
|
||||
rustc --version
|
||||
|
||||
- name: Build relay + tools
|
||||
run: |
|
||||
source "$HOME/.cargo/env"
|
||||
cargo build --release --bin wzp-relay --bin wzp-client --bin wzp-bench --bin wzp-web
|
||||
|
||||
- name: Run tests
|
||||
run: cargo test --workspace --lib
|
||||
|
||||
- name: Package
|
||||
run: |
|
||||
mkdir -p dist/wzp-linux-amd64
|
||||
cp target/release/wzp-relay dist/wzp-linux-amd64/
|
||||
cp target/release/wzp-client dist/wzp-linux-amd64/
|
||||
cp target/release/wzp-client-audio dist/wzp-linux-amd64/
|
||||
cp target/release/wzp-web dist/wzp-linux-amd64/
|
||||
cp target/release/wzp-bench dist/wzp-linux-amd64/
|
||||
cp -r crates/wzp-web/static dist/wzp-linux-amd64/
|
||||
cd dist && tar czf wzp-linux-amd64.tar.gz wzp-linux-amd64/
|
||||
source "$HOME/.cargo/env"
|
||||
cargo test --workspace --lib
|
||||
|
||||
- name: Upload artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: wzp-linux-amd64
|
||||
path: dist/wzp-linux-amd64.tar.gz
|
||||
|
||||
build-arm64:
|
||||
if: >-
|
||||
github.event_name == 'push' ||
|
||||
contains(github.event.inputs.targets, 'arm64')
|
||||
runs-on: ubuntu-latest
|
||||
container:
|
||||
image: rust:1-bookworm
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Install cross-compilation tools
|
||||
run: |
|
||||
dpkg --add-architecture arm64
|
||||
apt-get update
|
||||
apt-get install -y cmake pkg-config gcc-aarch64-linux-gnu libc6-dev-arm64-cross
|
||||
rustup target add aarch64-unknown-linux-gnu
|
||||
|
||||
- name: Cache cargo
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/.cargo/registry
|
||||
~/.cargo/git
|
||||
target
|
||||
key: cargo-arm64-${{ hashFiles('Cargo.lock') }}
|
||||
restore-keys: cargo-arm64-
|
||||
|
||||
- name: Build
|
||||
- name: Upload to rustypaste
|
||||
env:
|
||||
CARGO_TARGET_AARCH64_UNKNOWN_LINUX_GNU_LINKER: aarch64-linux-gnu-gcc
|
||||
CC_aarch64_unknown_linux_gnu: aarch64-linux-gnu-gcc
|
||||
PASTE_AUTH: ${{ secrets.PASTE_AUTH }}
|
||||
PASTE_URL: ${{ secrets.PASTE_URL }}
|
||||
run: |
|
||||
cargo build --release --target aarch64-unknown-linux-gnu \
|
||||
--bin wzp-relay --bin wzp-client --bin wzp-bench --bin wzp-web
|
||||
|
||||
- name: Package
|
||||
run: |
|
||||
mkdir -p dist/wzp-linux-arm64
|
||||
cp target/aarch64-unknown-linux-gnu/release/wzp-relay dist/wzp-linux-arm64/
|
||||
cp target/aarch64-unknown-linux-gnu/release/wzp-client dist/wzp-linux-arm64/
|
||||
cp target/aarch64-unknown-linux-gnu/release/wzp-web dist/wzp-linux-arm64/
|
||||
cp target/aarch64-unknown-linux-gnu/release/wzp-bench dist/wzp-linux-arm64/
|
||||
cp -r crates/wzp-web/static dist/wzp-linux-arm64/
|
||||
cd dist && tar czf wzp-linux-arm64.tar.gz wzp-linux-arm64/
|
||||
|
||||
- name: Upload artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: wzp-linux-arm64
|
||||
path: dist/wzp-linux-arm64.tar.gz
|
||||
|
||||
build-armv7:
|
||||
if: >-
|
||||
github.event_name == 'push' ||
|
||||
contains(github.event.inputs.targets, 'armv7')
|
||||
runs-on: ubuntu-latest
|
||||
container:
|
||||
image: rust:1-bookworm
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Install cross-compilation tools
|
||||
run: |
|
||||
dpkg --add-architecture armhf
|
||||
apt-get update
|
||||
apt-get install -y cmake pkg-config gcc-arm-linux-gnueabihf libc6-dev-armhf-cross
|
||||
rustup target add armv7-unknown-linux-gnueabihf
|
||||
|
||||
- name: Cache cargo
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/.cargo/registry
|
||||
~/.cargo/git
|
||||
target
|
||||
key: cargo-armv7-${{ hashFiles('Cargo.lock') }}
|
||||
restore-keys: cargo-armv7-
|
||||
|
||||
- name: Build
|
||||
env:
|
||||
CARGO_TARGET_ARMV7_UNKNOWN_LINUX_GNUEABIHF_LINKER: arm-linux-gnueabihf-gcc
|
||||
CC_armv7_unknown_linux_gnueabihf: arm-linux-gnueabihf-gcc
|
||||
run: |
|
||||
cargo build --release --target armv7-unknown-linux-gnueabihf \
|
||||
--bin wzp-relay --bin wzp-client --bin wzp-bench --bin wzp-web
|
||||
|
||||
- name: Package
|
||||
run: |
|
||||
mkdir -p dist/wzp-linux-armv7
|
||||
cp target/armv7-unknown-linux-gnueabihf/release/wzp-relay dist/wzp-linux-armv7/
|
||||
cp target/armv7-unknown-linux-gnueabihf/release/wzp-client dist/wzp-linux-armv7/
|
||||
cp target/armv7-unknown-linux-gnueabihf/release/wzp-web dist/wzp-linux-armv7/
|
||||
cp target/armv7-unknown-linux-gnueabihf/release/wzp-bench dist/wzp-linux-armv7/
|
||||
cp -r crates/wzp-web/static dist/wzp-linux-armv7/
|
||||
cd dist && tar czf wzp-linux-armv7.tar.gz wzp-linux-armv7/
|
||||
|
||||
- name: Upload artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: wzp-linux-armv7
|
||||
path: dist/wzp-linux-armv7.tar.gz
|
||||
|
||||
# Release job — creates a release with all artifacts when a tag is pushed
|
||||
release:
|
||||
if: startsWith(github.ref, 'refs/tags/v')
|
||||
needs: [build-amd64]
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Download all artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
path: artifacts
|
||||
|
||||
- name: Create release
|
||||
uses: softprops/action-gh-release@v2
|
||||
with:
|
||||
files: artifacts/**/*.tar.gz
|
||||
generate_release_notes: true
|
||||
tar czf /tmp/wzp-linux-amd64.tar.gz \
|
||||
-C target/release wzp-relay wzp-client wzp-web wzp-bench
|
||||
ls -lh /tmp/wzp-linux-amd64.tar.gz
|
||||
LINK=$(curl -sF "file=@/tmp/wzp-linux-amd64.tar.gz" \
|
||||
-H "Authorization: ${PASTE_AUTH}" \
|
||||
"https://${PASTE_URL}")
|
||||
echo "Download: ${LINK}"
|
||||
|
||||
43
.gitea/workflows/mirror-github.yml
Normal file
43
.gitea/workflows/mirror-github.yml
Normal file
@@ -0,0 +1,43 @@
|
||||
name: Mirror to GitHub
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- 'feat/*'
|
||||
- 'feature/*'
|
||||
tags:
|
||||
- '*'
|
||||
|
||||
jobs:
|
||||
mirror:
|
||||
runs-on: ubuntu-latest
|
||||
container:
|
||||
image: catthehacker/ubuntu:act-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Push to GitHub
|
||||
env:
|
||||
GH_SSH_KEY: ${{ secrets.GH_SSH_KEY }}
|
||||
run: |
|
||||
mkdir -p ~/.ssh
|
||||
echo "${GH_SSH_KEY}" > ~/.ssh/id_ed25519
|
||||
chmod 600 ~/.ssh/id_ed25519
|
||||
ssh-keyscan github.com >> ~/.ssh/known_hosts 2>/dev/null
|
||||
|
||||
git remote add github git@github.com:manawenuz/wzp.git
|
||||
|
||||
# Push the current branch
|
||||
BRANCH="${GITHUB_REF#refs/heads/}"
|
||||
TAG="${GITHUB_REF#refs/tags/}"
|
||||
|
||||
if [ "${GITHUB_REF}" != "${GITHUB_REF#refs/tags/}" ]; then
|
||||
echo "Pushing tag: ${TAG}"
|
||||
git push github "refs/tags/${TAG}" --force
|
||||
else
|
||||
echo "Pushing branch: ${BRANCH}"
|
||||
git push github "HEAD:refs/heads/${BRANCH}" --force
|
||||
fi
|
||||
3
.gitmodules
vendored
Normal file
3
.gitmodules
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
[submodule "deps/featherchat"]
|
||||
path = deps/featherchat
|
||||
url = ssh://git@git.manko.yoga:222/manawenuz/featherChat.git
|
||||
4694
Cargo.lock
generated
4694
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
24
Cargo.toml
24
Cargo.toml
@@ -9,6 +9,8 @@ members = [
|
||||
"crates/wzp-relay",
|
||||
"crates/wzp-client",
|
||||
"crates/wzp-web",
|
||||
"crates/wzp-android",
|
||||
"desktop/src-tauri",
|
||||
]
|
||||
|
||||
[workspace.package]
|
||||
@@ -51,3 +53,25 @@ wzp-codec = { path = "crates/wzp-codec" }
|
||||
wzp-fec = { path = "crates/wzp-fec" }
|
||||
wzp-crypto = { path = "crates/wzp-crypto" }
|
||||
wzp-transport = { path = "crates/wzp-transport" }
|
||||
wzp-client = { path = "crates/wzp-client" }
|
||||
|
||||
# Fast dev profile: optimized but with debug info and incremental compilation.
|
||||
# Use with: cargo run --profile dev-fast
|
||||
[profile.dev-fast]
|
||||
inherits = "dev"
|
||||
opt-level = 2
|
||||
|
||||
# Optimize heavy compute deps even in debug builds —
|
||||
# real-time audio needs < 20ms per frame, impossible unoptimized.
|
||||
[profile.dev.package.nnnoiseless]
|
||||
opt-level = 3
|
||||
[profile.dev.package.audiopus_sys]
|
||||
opt-level = 3
|
||||
[profile.dev.package.audiopus]
|
||||
opt-level = 3
|
||||
[profile.dev.package.raptorq]
|
||||
opt-level = 3
|
||||
[profile.dev.package.wzp-codec]
|
||||
opt-level = 3
|
||||
[profile.dev.package.wzp-fec]
|
||||
opt-level = 3
|
||||
|
||||
87
README.md
Normal file
87
README.md
Normal file
@@ -0,0 +1,87 @@
|
||||
# WarzonePhone
|
||||
|
||||
Custom lossy VoIP protocol built in Rust. E2E encrypted, FEC-protected, adaptive quality, designed for hostile network conditions.
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
# Build
|
||||
cargo build --release
|
||||
|
||||
# Run relay
|
||||
./target/release/wzp-relay --listen 0.0.0.0:4433
|
||||
|
||||
# Send a test tone
|
||||
./target/release/wzp-client --send-tone 5 relay-addr:4433
|
||||
|
||||
# Web bridge (browser calls)
|
||||
./target/release/wzp-web --port 8080 --relay 127.0.0.1:4433 --tls
|
||||
# Open https://localhost:8080/room-name in two browser tabs
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
See [docs/ARCHITECTURE.md](docs/ARCHITECTURE.md) for the full system architecture with Mermaid diagrams covering:
|
||||
|
||||
- System overview and data flow
|
||||
- Crate dependency graph (8 crates)
|
||||
- Wire formats (MediaHeader, MiniHeader, TrunkFrame, SignalMessage)
|
||||
- Cryptographic handshake (X25519 + Ed25519 + ChaCha20-Poly1305)
|
||||
- Identity model (BIP39 seed, featherChat compatible)
|
||||
- Quality profiles (GOOD/DEGRADED/CATASTROPHIC)
|
||||
- FEC protection (RaptorQ with interleaving)
|
||||
- Adaptive jitter buffer (NetEq-inspired)
|
||||
- Telemetry stack (Prometheus + Grafana)
|
||||
- Deployment topology
|
||||
|
||||
## Features
|
||||
|
||||
- **3 quality tiers**: Opus 24k (28.8 kbps) / Opus 6k (9 kbps) / Codec2 1200 (2.4 kbps)
|
||||
- **RaptorQ FEC**: Recovers from 20-100% packet loss depending on tier
|
||||
- **E2E encryption**: ChaCha20-Poly1305 with X25519 key exchange
|
||||
- **Adaptive jitter buffer**: EMA-based playout delay tracking
|
||||
- **Silence suppression**: VAD + comfort noise (~50% bandwidth savings)
|
||||
- **ML noise removal**: RNNoise (nnnoiseless pure Rust port)
|
||||
- **Mini-frames**: 67% header compression for steady-state packets
|
||||
- **Trunking**: Multiplex sessions into batched datagrams
|
||||
- **featherChat integration**: Shared BIP39 identity, token auth, call signaling
|
||||
- **Prometheus metrics**: Relay, web bridge, inter-relay probes
|
||||
- **Grafana dashboard**: Pre-built JSON with 18 panels
|
||||
|
||||
## Documentation
|
||||
|
||||
| Document | Description |
|
||||
|----------|-------------|
|
||||
| [ARCHITECTURE.md](docs/ARCHITECTURE.md) | Full system architecture with diagrams |
|
||||
| [TELEMETRY.md](docs/TELEMETRY.md) | Prometheus metrics specification |
|
||||
| [INTEGRATION_TASKS.md](docs/INTEGRATION_TASKS.md) | featherChat integration tracker |
|
||||
| [WZP-FC-SHARED-CRATES.md](docs/WZP-FC-SHARED-CRATES.md) | Shared crate strategy |
|
||||
| [grafana-dashboard.json](docs/grafana-dashboard.json) | Importable Grafana dashboard |
|
||||
|
||||
## Binaries
|
||||
|
||||
| Binary | Description |
|
||||
|--------|-------------|
|
||||
| `wzp-relay` | Relay daemon (SFU room mode, forward mode, probes) |
|
||||
| `wzp-client` | CLI client (send-tone, record, live mic, echo-test, drift-test, sweep) |
|
||||
| `wzp-web` | Browser bridge (HTTPS + WebSocket + AudioWorklet) |
|
||||
| `wzp-bench` | Component benchmarks |
|
||||
|
||||
## Linux Build
|
||||
|
||||
```bash
|
||||
./scripts/build-linux.sh --prepare # Create Hetzner VM + install deps
|
||||
./scripts/build-linux.sh --build # Build release binaries
|
||||
./scripts/build-linux.sh --transfer # Download to target/linux-x86_64/
|
||||
./scripts/build-linux.sh --destroy # Delete VM
|
||||
```
|
||||
|
||||
## Tests
|
||||
|
||||
```bash
|
||||
cargo test --workspace # 272 tests
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
MIT OR Apache-2.0
|
||||
6
android/.gitignore
vendored
Normal file
6
android/.gitignore
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
.gradle/
|
||||
build/
|
||||
app/build/
|
||||
app/src/main/jniLibs/
|
||||
local.properties
|
||||
keystore/*.jks
|
||||
BIN
android/android/app/src/main/jniLibs/arm64-v8a/libwzp_android.so
Executable file
BIN
android/android/app/src/main/jniLibs/arm64-v8a/libwzp_android.so
Executable file
Binary file not shown.
85
android/app/build.gradle.kts
Normal file
85
android/app/build.gradle.kts
Normal file
@@ -0,0 +1,85 @@
|
||||
plugins {
|
||||
id("com.android.application")
|
||||
id("org.jetbrains.kotlin.android")
|
||||
}
|
||||
|
||||
android {
|
||||
namespace = "com.wzp.phone"
|
||||
compileSdk = 34
|
||||
|
||||
defaultConfig {
|
||||
applicationId = "com.wzp.phone"
|
||||
minSdk = 26 // AAudio requires API 26
|
||||
targetSdk = 34
|
||||
versionCode = 1
|
||||
versionName = "0.1.0"
|
||||
ndk { abiFilters += listOf("arm64-v8a") }
|
||||
}
|
||||
|
||||
signingConfigs {
|
||||
create("release") {
|
||||
storeFile = file("${project.rootDir}/keystore/wzp-release.jks")
|
||||
storePassword = "wzphone2024"
|
||||
keyAlias = "wzp-release"
|
||||
keyPassword = "wzphone2024"
|
||||
}
|
||||
getByName("debug") {
|
||||
storeFile = file("${project.rootDir}/keystore/wzp-debug.jks")
|
||||
storePassword = "android"
|
||||
keyAlias = "wzp-debug"
|
||||
keyPassword = "android"
|
||||
}
|
||||
}
|
||||
|
||||
buildTypes {
|
||||
debug {
|
||||
signingConfig = signingConfigs.getByName("debug")
|
||||
isDebuggable = true
|
||||
}
|
||||
release {
|
||||
signingConfig = signingConfigs.getByName("release")
|
||||
isMinifyEnabled = false
|
||||
proguardFiles(
|
||||
getDefaultProguardFile("proguard-android-optimize.txt"),
|
||||
"proguard-rules.pro"
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
compileOptions {
|
||||
sourceCompatibility = JavaVersion.VERSION_1_8
|
||||
targetCompatibility = JavaVersion.VERSION_1_8
|
||||
}
|
||||
|
||||
kotlinOptions {
|
||||
jvmTarget = "1.8"
|
||||
}
|
||||
|
||||
buildFeatures { compose = true }
|
||||
composeOptions { kotlinCompilerExtensionVersion = "1.5.8" }
|
||||
|
||||
ndkVersion = "26.1.10909125"
|
||||
}
|
||||
|
||||
// cargo-ndk integration: build the Rust native library for Android targets
|
||||
tasks.register<Exec>("cargoNdkBuild") {
|
||||
workingDir = file("${project.rootDir}/..")
|
||||
commandLine(
|
||||
"cargo", "ndk",
|
||||
"-t", "arm64-v8a",
|
||||
"-o", "${project.projectDir}/src/main/jniLibs",
|
||||
"build", "--release", "-p", "wzp-android"
|
||||
)
|
||||
}
|
||||
|
||||
// Skip cargo-ndk in CI/Docker — .so is pre-built into jniLibs
|
||||
// tasks.named("preBuild") { dependsOn("cargoNdkBuild") }
|
||||
|
||||
dependencies {
|
||||
implementation("androidx.core:core-ktx:1.12.0")
|
||||
implementation("androidx.lifecycle:lifecycle-runtime-ktx:2.7.0")
|
||||
implementation("androidx.activity:activity-compose:1.8.2")
|
||||
implementation(platform("androidx.compose:compose-bom:2024.01.00"))
|
||||
implementation("androidx.compose.ui:ui")
|
||||
implementation("androidx.compose.material3:material3")
|
||||
}
|
||||
9
android/app/proguard-rules.pro
vendored
Normal file
9
android/app/proguard-rules.pro
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
# WZPhone ProGuard rules
|
||||
|
||||
# Keep JNI native methods
|
||||
-keepclasseswithmembernames class * {
|
||||
native <methods>;
|
||||
}
|
||||
|
||||
# Keep the WZP engine bridge class
|
||||
-keep class com.wzp.phone.engine.** { *; }
|
||||
43
android/app/src/main/AndroidManifest.xml
Normal file
43
android/app/src/main/AndroidManifest.xml
Normal file
@@ -0,0 +1,43 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android">
|
||||
<uses-permission android:name="android.permission.INTERNET" />
|
||||
<uses-permission android:name="android.permission.RECORD_AUDIO" />
|
||||
<uses-permission android:name="android.permission.FOREGROUND_SERVICE" />
|
||||
<uses-permission android:name="android.permission.FOREGROUND_SERVICE_MICROPHONE" />
|
||||
<uses-permission android:name="android.permission.WAKE_LOCK" />
|
||||
<uses-permission android:name="android.permission.ACCESS_NETWORK_STATE" />
|
||||
<uses-permission android:name="android.permission.BLUETOOTH_CONNECT" />
|
||||
<uses-permission android:name="android.permission.MODIFY_AUDIO_SETTINGS" />
|
||||
|
||||
<application
|
||||
android:name="com.wzp.WzpApplication"
|
||||
android:label="WZ Phone"
|
||||
android:supportsRtl="true"
|
||||
android:theme="@android:style/Theme.Material.Light.NoActionBar">
|
||||
|
||||
<activity
|
||||
android:name="com.wzp.ui.call.CallActivity"
|
||||
android:exported="true"
|
||||
android:launchMode="singleTask">
|
||||
<intent-filter>
|
||||
<action android:name="android.intent.action.MAIN" />
|
||||
<category android:name="android.intent.category.LAUNCHER" />
|
||||
</intent-filter>
|
||||
</activity>
|
||||
|
||||
<service
|
||||
android:name="com.wzp.service.CallService"
|
||||
android:foregroundServiceType="microphone"
|
||||
android:exported="false" />
|
||||
|
||||
<provider
|
||||
android:name="androidx.core.content.FileProvider"
|
||||
android:authorities="${applicationId}.fileprovider"
|
||||
android:exported="false"
|
||||
android:grantUriPermissions="true">
|
||||
<meta-data
|
||||
android:name="android.support.FILE_PROVIDER_PATHS"
|
||||
android:resource="@xml/file_paths" />
|
||||
</provider>
|
||||
</application>
|
||||
</manifest>
|
||||
0
android/app/src/main/java/com/wzp/.gitkeep
Normal file
0
android/app/src/main/java/com/wzp/.gitkeep
Normal file
38
android/app/src/main/java/com/wzp/WzpApplication.kt
Normal file
38
android/app/src/main/java/com/wzp/WzpApplication.kt
Normal file
@@ -0,0 +1,38 @@
|
||||
package com.wzp
|
||||
|
||||
import android.app.Application
|
||||
import android.app.NotificationChannel
|
||||
import android.app.NotificationManager
|
||||
import android.os.Build
|
||||
|
||||
/**
|
||||
* Application entry point for WarzonePhone.
|
||||
*
|
||||
* Creates the notification channel required for the foreground [com.wzp.service.CallService].
|
||||
*/
|
||||
class WzpApplication : Application() {
|
||||
|
||||
override fun onCreate() {
|
||||
super.onCreate()
|
||||
createNotificationChannel()
|
||||
}
|
||||
|
||||
private fun createNotificationChannel() {
|
||||
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) {
|
||||
val channel = NotificationChannel(
|
||||
CHANNEL_ID,
|
||||
"Active Call",
|
||||
NotificationManager.IMPORTANCE_LOW
|
||||
).apply {
|
||||
description = "Shown while a VoIP call is in progress"
|
||||
setShowBadge(false)
|
||||
}
|
||||
val nm = getSystemService(NotificationManager::class.java)
|
||||
nm.createNotificationChannel(channel)
|
||||
}
|
||||
}
|
||||
|
||||
companion object {
|
||||
const val CHANNEL_ID = "wzp_call_channel"
|
||||
}
|
||||
}
|
||||
324
android/app/src/main/java/com/wzp/audio/AudioPipeline.kt
Normal file
324
android/app/src/main/java/com/wzp/audio/AudioPipeline.kt
Normal file
@@ -0,0 +1,324 @@
|
||||
package com.wzp.audio
|
||||
|
||||
import android.Manifest
|
||||
import android.content.Context
|
||||
import android.content.pm.PackageManager
|
||||
import android.media.AudioAttributes
|
||||
import android.media.AudioFormat
|
||||
import android.media.AudioRecord
|
||||
import android.media.AudioTrack
|
||||
import android.media.MediaRecorder
|
||||
import android.media.audiofx.AcousticEchoCanceler
|
||||
import android.media.audiofx.NoiseSuppressor
|
||||
import android.util.Log
|
||||
import androidx.core.content.ContextCompat
|
||||
import com.wzp.engine.WzpEngine
|
||||
import java.io.BufferedOutputStream
|
||||
import java.io.File
|
||||
import java.io.FileOutputStream
|
||||
import java.io.OutputStreamWriter
|
||||
import java.nio.ByteBuffer
|
||||
import java.nio.ByteOrder
|
||||
import kotlin.math.pow
|
||||
import kotlin.math.sqrt
|
||||
|
||||
/**
|
||||
* Audio pipeline that captures mic audio and plays received audio using
|
||||
* Android AudioRecord/AudioTrack APIs running on JVM threads.
|
||||
*
|
||||
* PCM samples are shuttled to/from the Rust engine via JNI ring buffers:
|
||||
* - Capture: AudioRecord → WzpEngine.writeAudio() → Rust encoder → network
|
||||
* - Playout: network → Rust decoder → WzpEngine.readAudio() → AudioTrack
|
||||
*
|
||||
* All audio is 48kHz, mono, 16-bit PCM (matching Opus codec requirements).
|
||||
*/
|
||||
class AudioPipeline(private val context: Context) {
|
||||
|
||||
companion object {
|
||||
private const val TAG = "AudioPipeline"
|
||||
private const val SAMPLE_RATE = 48000
|
||||
private const val CHANNEL_IN = AudioFormat.CHANNEL_IN_MONO
|
||||
private const val CHANNEL_OUT = AudioFormat.CHANNEL_OUT_MONO
|
||||
private const val ENCODING = AudioFormat.ENCODING_PCM_16BIT
|
||||
/** 20ms frame at 48kHz = 960 samples */
|
||||
private const val FRAME_SAMPLES = 960
|
||||
}
|
||||
|
||||
@Volatile
|
||||
private var running = false
|
||||
/** Playout (incoming voice) gain in dB. 0 = unity. */
|
||||
@Volatile
|
||||
var playoutGainDb: Float = 0f
|
||||
/** Capture (mic) gain in dB. 0 = unity. */
|
||||
@Volatile
|
||||
var captureGainDb: Float = 0f
|
||||
/** Whether to attach hardware AEC. Must be set before start(). */
|
||||
var aecEnabled: Boolean = true
|
||||
/** Enable debug recording of PCM + RMS histogram to cache dir. */
|
||||
var debugRecording: Boolean = true
|
||||
private var captureThread: Thread? = null
|
||||
private var playoutThread: Thread? = null
|
||||
|
||||
private val debugDir: File by lazy {
|
||||
File(context.cacheDir, "wzp_debug").also { it.mkdirs() }
|
||||
}
|
||||
|
||||
fun start(engine: WzpEngine) {
|
||||
if (running) return
|
||||
running = true
|
||||
|
||||
captureThread = Thread({
|
||||
runCapture(engine)
|
||||
// Park thread forever — exiting triggers a libcrypto TLS destructor
|
||||
// crash (SIGSEGV in OPENSSL_free) on Android when a JNI-calling thread exits.
|
||||
parkThread()
|
||||
}, "wzp-capture").apply {
|
||||
isDaemon = true
|
||||
priority = Thread.MAX_PRIORITY
|
||||
start()
|
||||
}
|
||||
|
||||
playoutThread = Thread({
|
||||
runPlayout(engine)
|
||||
parkThread()
|
||||
}, "wzp-playout").apply {
|
||||
isDaemon = true
|
||||
priority = Thread.MAX_PRIORITY
|
||||
start()
|
||||
}
|
||||
|
||||
Log.i(TAG, "audio pipeline started")
|
||||
}
|
||||
|
||||
fun stop() {
|
||||
running = false
|
||||
// Don't join — threads are parked as daemons to avoid native TLS crash
|
||||
captureThread = null
|
||||
playoutThread = null
|
||||
Log.i(TAG, "audio pipeline stopped")
|
||||
}
|
||||
|
||||
private fun applyGain(pcm: ShortArray, count: Int, db: Float) {
|
||||
if (db == 0f) return
|
||||
val linear = 10f.pow(db / 20f)
|
||||
for (i in 0 until count) {
|
||||
pcm[i] = (pcm[i] * linear).toInt().coerceIn(-32000, 32000).toShort()
|
||||
}
|
||||
}
|
||||
|
||||
private fun computeRms(pcm: ShortArray, count: Int): Int {
|
||||
var sumSq = 0.0
|
||||
for (i in 0 until count) {
|
||||
val s = pcm[i].toDouble()
|
||||
sumSq += s * s
|
||||
}
|
||||
return sqrt(sumSq / count).toInt()
|
||||
}
|
||||
|
||||
private fun parkThread() {
|
||||
try {
|
||||
Thread.sleep(Long.MAX_VALUE)
|
||||
} catch (_: InterruptedException) {
|
||||
// process exiting
|
||||
}
|
||||
}
|
||||
|
||||
private fun runCapture(engine: WzpEngine) {
|
||||
if (ContextCompat.checkSelfPermission(context, Manifest.permission.RECORD_AUDIO)
|
||||
!= PackageManager.PERMISSION_GRANTED
|
||||
) {
|
||||
Log.e(TAG, "RECORD_AUDIO permission not granted, capture disabled")
|
||||
return
|
||||
}
|
||||
|
||||
val minBuf = AudioRecord.getMinBufferSize(SAMPLE_RATE, CHANNEL_IN, ENCODING)
|
||||
val bufSize = maxOf(minBuf, FRAME_SAMPLES * 2 * 4) // at least 4 frames
|
||||
|
||||
val recorder = try {
|
||||
AudioRecord(
|
||||
MediaRecorder.AudioSource.VOICE_COMMUNICATION,
|
||||
SAMPLE_RATE,
|
||||
CHANNEL_IN,
|
||||
ENCODING,
|
||||
bufSize
|
||||
)
|
||||
} catch (e: SecurityException) {
|
||||
Log.e(TAG, "AudioRecord SecurityException: ${e.message}")
|
||||
return
|
||||
}
|
||||
|
||||
if (recorder.state != AudioRecord.STATE_INITIALIZED) {
|
||||
Log.e(TAG, "AudioRecord failed to initialize")
|
||||
recorder.release()
|
||||
return
|
||||
}
|
||||
|
||||
// Attach hardware AEC if available and enabled in settings
|
||||
var aec: AcousticEchoCanceler? = null
|
||||
var ns: NoiseSuppressor? = null
|
||||
if (aecEnabled) {
|
||||
if (AcousticEchoCanceler.isAvailable()) {
|
||||
try {
|
||||
aec = AcousticEchoCanceler.create(recorder.audioSessionId)
|
||||
aec?.enabled = true
|
||||
Log.i(TAG, "AEC enabled (session=${recorder.audioSessionId})")
|
||||
} catch (e: Exception) {
|
||||
Log.w(TAG, "AEC init failed: ${e.message}")
|
||||
}
|
||||
} else {
|
||||
Log.w(TAG, "AEC not available on this device")
|
||||
}
|
||||
|
||||
// Attach hardware noise suppressor if available
|
||||
if (NoiseSuppressor.isAvailable()) {
|
||||
try {
|
||||
ns = NoiseSuppressor.create(recorder.audioSessionId)
|
||||
ns?.enabled = true
|
||||
Log.i(TAG, "NoiseSuppressor enabled")
|
||||
} catch (e: Exception) {
|
||||
Log.w(TAG, "NoiseSuppressor init failed: ${e.message}")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Log.i(TAG, "AEC disabled by user setting")
|
||||
}
|
||||
|
||||
recorder.startRecording()
|
||||
Log.i(TAG, "capture started: ${SAMPLE_RATE}Hz mono, buf=$bufSize, aec=${aec?.enabled}, ns=${ns?.enabled}")
|
||||
|
||||
val pcm = ShortArray(FRAME_SAMPLES)
|
||||
// Debug: PCM file + RMS CSV
|
||||
var pcmOut: BufferedOutputStream? = null
|
||||
var rmsCsv: OutputStreamWriter? = null
|
||||
val byteConv = ByteBuffer.allocate(FRAME_SAMPLES * 2).order(ByteOrder.LITTLE_ENDIAN)
|
||||
var frameIdx = 0L
|
||||
if (debugRecording) {
|
||||
try {
|
||||
pcmOut = BufferedOutputStream(FileOutputStream(File(debugDir, "capture.pcm")), 65536)
|
||||
rmsCsv = OutputStreamWriter(FileOutputStream(File(debugDir, "capture_rms.csv")))
|
||||
rmsCsv.write("frame,time_ms,rms\n")
|
||||
} catch (e: Exception) {
|
||||
Log.w(TAG, "debug recording init failed: ${e.message}")
|
||||
}
|
||||
}
|
||||
try {
|
||||
while (running) {
|
||||
val read = recorder.read(pcm, 0, FRAME_SAMPLES)
|
||||
if (read > 0) {
|
||||
applyGain(pcm, read, captureGainDb)
|
||||
engine.writeAudio(pcm)
|
||||
|
||||
// Debug: write raw PCM + RMS
|
||||
if (pcmOut != null) {
|
||||
byteConv.clear()
|
||||
for (i in 0 until read) byteConv.putShort(pcm[i])
|
||||
pcmOut.write(byteConv.array(), 0, read * 2)
|
||||
}
|
||||
if (rmsCsv != null) {
|
||||
val rms = computeRms(pcm, read)
|
||||
val timeMs = frameIdx * FRAME_SAMPLES * 1000L / SAMPLE_RATE
|
||||
rmsCsv.write("$frameIdx,$timeMs,$rms\n")
|
||||
}
|
||||
frameIdx++
|
||||
} else if (read < 0) {
|
||||
Log.e(TAG, "AudioRecord.read error: $read")
|
||||
break
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
pcmOut?.close()
|
||||
rmsCsv?.close()
|
||||
recorder.stop()
|
||||
aec?.release()
|
||||
ns?.release()
|
||||
recorder.release()
|
||||
Log.i(TAG, "capture stopped (frames=$frameIdx)")
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Playout loop: pulls decoded PCM frames from the engine and writes them to
 * an [AudioTrack] until [running] is cleared.
 *
 * When the engine has no full frame available, one frame of silence is
 * written instead (keeps the track fed) and the thread sleeps briefly
 * before polling again.
 *
 * When [debugRecording] is set, raw little-endian PCM and a per-frame RMS
 * CSV are written to [debugDir] for later inclusion in a debug report;
 * silence frames are logged to the CSV with RMS 0.
 *
 * Fix: cleanup in `finally` is now failure-tolerant — previously an
 * exception from closing a debug stream (or from `track.stop()`) skipped
 * `track.release()` and leaked the AudioTrack.
 */
private fun runPlayout(engine: WzpEngine) {
    val minBuf = AudioTrack.getMinBufferSize(SAMPLE_RATE, CHANNEL_OUT, ENCODING)
    // At least 4 frames of headroom (2 bytes/sample) on top of the platform minimum.
    val bufSize = maxOf(minBuf, FRAME_SAMPLES * 2 * 4)

    val track = AudioTrack.Builder()
        .setAudioAttributes(
            AudioAttributes.Builder()
                .setUsage(AudioAttributes.USAGE_VOICE_COMMUNICATION)
                .setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
                .build()
        )
        .setAudioFormat(
            AudioFormat.Builder()
                .setSampleRate(SAMPLE_RATE)
                .setChannelMask(CHANNEL_OUT)
                .setEncoding(ENCODING)
                .build()
        )
        .setBufferSizeInBytes(bufSize)
        .setTransferMode(AudioTrack.MODE_STREAM)
        .build()

    if (track.state != AudioTrack.STATE_INITIALIZED) {
        Log.e(TAG, "AudioTrack failed to initialize")
        track.release()
        return
    }

    track.play()
    Log.i(TAG, "playout started: ${SAMPLE_RATE}Hz mono, buf=$bufSize")

    val pcm = ShortArray(FRAME_SAMPLES)
    val silence = ShortArray(FRAME_SAMPLES)
    // Debug: PCM file + RMS CSV for playout
    var pcmOut: BufferedOutputStream? = null
    var rmsCsv: OutputStreamWriter? = null
    // Reused short->byte conversion buffer (one frame, little-endian).
    val byteConv = ByteBuffer.allocate(FRAME_SAMPLES * 2).order(ByteOrder.LITTLE_ENDIAN)
    var frameIdx = 0L
    if (debugRecording) {
        try {
            pcmOut = BufferedOutputStream(FileOutputStream(File(debugDir, "playout.pcm")), 65536)
            rmsCsv = OutputStreamWriter(FileOutputStream(File(debugDir, "playout_rms.csv")))
            rmsCsv.write("frame,time_ms,rms\n")
        } catch (e: Exception) {
            // Debug capture is best-effort; the call continues without it.
            Log.w(TAG, "debug playout recording init failed: ${e.message}")
        }
    }
    try {
        while (running) {
            val read = engine.readAudio(pcm)
            if (read >= FRAME_SAMPLES) {
                applyGain(pcm, read, playoutGainDb)
                track.write(pcm, 0, read)

                // Debug: write raw PCM + RMS
                if (pcmOut != null) {
                    byteConv.clear()
                    for (i in 0 until read) byteConv.putShort(pcm[i])
                    pcmOut.write(byteConv.array(), 0, read * 2)
                }
                if (rmsCsv != null) {
                    val rms = computeRms(pcm, read)
                    val timeMs = frameIdx * FRAME_SAMPLES * 1000L / SAMPLE_RATE
                    rmsCsv.write("$frameIdx,$timeMs,$rms\n")
                }
                frameIdx++
            } else {
                // Underrun: feed silence so the track keeps running, then back off.
                track.write(silence, 0, FRAME_SAMPLES)
                // Log silence frames to RMS as 0
                if (rmsCsv != null) {
                    val timeMs = frameIdx * FRAME_SAMPLES * 1000L / SAMPLE_RATE
                    rmsCsv.write("$frameIdx,$timeMs,0\n")
                }
                frameIdx++
                Thread.sleep(5)
            }
        }
    } finally {
        // Release defensively: a failure while closing one resource must not
        // prevent the track from being stopped and released.
        runCatching { pcmOut?.close() }
        runCatching { rmsCsv?.close() }
        runCatching { track.stop() }
        track.release()
        Log.i(TAG, "playout stopped (frames=$frameIdx)")
    }
}
|
||||
}
|
||||
142
android/app/src/main/java/com/wzp/audio/AudioRouteManager.kt
Normal file
142
android/app/src/main/java/com/wzp/audio/AudioRouteManager.kt
Normal file
@@ -0,0 +1,142 @@
|
||||
package com.wzp.audio
|
||||
|
||||
import android.content.Context
|
||||
import android.media.AudioDeviceCallback
|
||||
import android.media.AudioDeviceInfo
|
||||
import android.media.AudioManager
|
||||
import android.os.Handler
|
||||
import android.os.Looper
|
||||
|
||||
/**
|
||||
* Manages audio routing between earpiece, speaker, and Bluetooth devices.
|
||||
*
|
||||
* Wraps [AudioManager] operations and listens for device connection changes
|
||||
* via [AudioDeviceCallback] (API 23+).
|
||||
*
|
||||
* Usage:
|
||||
* 1. Call [register] when the call starts
|
||||
* 2. Use [setSpeaker] and [setBluetoothSco] to switch routes
|
||||
* 3. Call [unregister] when the call ends
|
||||
*/
|
||||
class AudioRouteManager(context: Context) {

    private val audioManager = context.getSystemService(Context.AUDIO_SERVICE) as AudioManager
    // Device callbacks are delivered on the main thread.
    private val mainHandler = Handler(Looper.getMainLooper())

    /** Listener for audio route changes. */
    var onRouteChanged: ((AudioRoute) -> Unit)? = null

    /** Current active route. */
    var currentRoute: AudioRoute = AudioRoute.EARPIECE
        private set

    // -- Device callback (API 23+) -------------------------------------------

    private val deviceCallback = object : AudioDeviceCallback() {
        override fun onAudioDevicesAdded(addedDevices: Array<out AudioDeviceInfo>) {
            for (device in addedDevices) {
                if (device.type == AudioDeviceInfo.TYPE_BLUETOOTH_SCO) {
                    // A Bluetooth headset was connected — optionally auto-switch.
                    // NOTE: currentRoute is deliberately NOT changed here; the
                    // route only switches once the listener calls setBluetoothSco.
                    onRouteChanged?.invoke(AudioRoute.BLUETOOTH)
                }
            }
        }

        override fun onAudioDevicesRemoved(removedDevices: Array<out AudioDeviceInfo>) {
            for (device in removedDevices) {
                if (device.type == AudioDeviceInfo.TYPE_BLUETOOTH_SCO) {
                    // Bluetooth disconnected — fall back to earpiece or speaker
                    val fallback = if (audioManager.isSpeakerphoneOn) {
                        AudioRoute.SPEAKER
                    } else {
                        AudioRoute.EARPIECE
                    }
                    currentRoute = fallback
                    onRouteChanged?.invoke(fallback)
                }
            }
        }
    }

    // -- Public API -----------------------------------------------------------

    /** Register the device callback. Call when a call starts. */
    fun register() {
        audioManager.registerAudioDeviceCallback(deviceCallback, mainHandler)
    }

    /** Unregister the device callback and release Bluetooth SCO. Call when the call ends. */
    fun unregister() {
        audioManager.unregisterAudioDeviceCallback(deviceCallback)
        stopBluetoothSco()
    }

    /**
     * Enable or disable the loudspeaker.
     *
     * When enabling speaker, Bluetooth SCO is disconnected.
     */
    @Suppress("DEPRECATION")
    fun setSpeaker(enabled: Boolean) {
        if (enabled) {
            stopBluetoothSco()
        }
        audioManager.isSpeakerphoneOn = enabled
        currentRoute = if (enabled) AudioRoute.SPEAKER else AudioRoute.EARPIECE
        onRouteChanged?.invoke(currentRoute)
    }

    /**
     * Enable or disable Bluetooth SCO (Synchronous Connection Oriented) audio.
     *
     * When enabling Bluetooth, the speaker is turned off. When disabling,
     * the route falls back to SPEAKER if the speakerphone is still on,
     * otherwise EARPIECE (consistent with the device-removed fallback).
     */
    @Suppress("DEPRECATION")
    fun setBluetoothSco(enabled: Boolean) {
        if (enabled) {
            audioManager.isSpeakerphoneOn = false
            audioManager.startBluetoothSco()
            audioManager.isBluetoothScoOn = true
            currentRoute = AudioRoute.BLUETOOTH
        } else {
            stopBluetoothSco()
            // Fix: previously this hardcoded EARPIECE even when the
            // speakerphone was still on.
            currentRoute = if (audioManager.isSpeakerphoneOn) {
                AudioRoute.SPEAKER
            } else {
                AudioRoute.EARPIECE
            }
        }
        onRouteChanged?.invoke(currentRoute)
    }

    /** Check whether a Bluetooth SCO device is currently connected. */
    fun isBluetoothAvailable(): Boolean {
        val devices = audioManager.getDevices(AudioManager.GET_DEVICES_OUTPUTS)
        return devices.any { it.type == AudioDeviceInfo.TYPE_BLUETOOTH_SCO }
    }

    /** List available output audio routes. */
    fun availableRoutes(): List<AudioRoute> {
        val routes = mutableListOf(AudioRoute.EARPIECE, AudioRoute.SPEAKER)
        if (isBluetoothAvailable()) {
            routes.add(AudioRoute.BLUETOOTH)
        }
        return routes
    }

    // -- Internal -------------------------------------------------------------

    /** Tear down Bluetooth SCO audio if it is currently active (no-op otherwise). */
    @Suppress("DEPRECATION")
    private fun stopBluetoothSco() {
        if (audioManager.isBluetoothScoOn) {
            audioManager.isBluetoothScoOn = false
            audioManager.stopBluetoothSco()
        }
    }
}
|
||||
|
||||
/** Audio output route. */
|
||||
enum class AudioRoute {
    /** Phone earpiece (default for calls); also the fallback route. */
    EARPIECE,
    /** Built-in loudspeaker (speakerphone mode). */
    SPEAKER,
    /** Bluetooth SCO headset/headphones. */
    BLUETOOTH
}
|
||||
141
android/app/src/main/java/com/wzp/data/SettingsRepository.kt
Normal file
141
android/app/src/main/java/com/wzp/data/SettingsRepository.kt
Normal file
@@ -0,0 +1,141 @@
|
||||
package com.wzp.data
|
||||
|
||||
import android.content.Context
|
||||
import android.content.SharedPreferences
|
||||
import com.wzp.ui.call.ServerEntry
|
||||
import org.json.JSONArray
|
||||
import org.json.JSONObject
|
||||
import java.security.SecureRandom
|
||||
|
||||
/**
|
||||
* Persists user settings via SharedPreferences.
|
||||
*
|
||||
* Stores: servers, default server index, room name, alias, gain values,
|
||||
* IPv6 preference, and the identity seed (hex-encoded 32 bytes).
|
||||
*/
|
||||
class SettingsRepository(context: Context) {

    // Application-scoped prefs file; survives activity recreation.
    private val prefs: SharedPreferences =
        context.applicationContext.getSharedPreferences("wzp_settings", Context.MODE_PRIVATE)

    companion object {
        private const val KEY_SERVERS = "servers_json"
        private const val KEY_SELECTED_SERVER = "selected_server"
        private const val KEY_ROOM = "room_name"
        private const val KEY_ALIAS = "alias"
        private const val KEY_PLAYOUT_GAIN = "playout_gain_db"
        private const val KEY_CAPTURE_GAIN = "capture_gain_db"
        private const val KEY_PREFER_IPV6 = "prefer_ipv6"
        private const val KEY_IDENTITY_SEED = "identity_seed_hex"
        private const val KEY_AEC_ENABLED = "aec_enabled"
    }

    // --- Servers ---

    /** Persist the server list as a JSON array of {address, label} objects. */
    fun saveServers(servers: List<ServerEntry>) {
        val array = JSONArray()
        for (entry in servers) {
            val obj = JSONObject()
            obj.put("address", entry.address)
            obj.put("label", entry.label)
            array.put(obj)
        }
        prefs.edit().putString(KEY_SERVERS, array.toString()).apply()
    }

    /** Load the saved server list, or null when nothing valid is stored. */
    fun loadServers(): List<ServerEntry>? {
        val raw = prefs.getString(KEY_SERVERS, null) ?: return null
        return try {
            val parsed = JSONArray(raw)
            val result = ArrayList<ServerEntry>(parsed.length())
            for (i in 0 until parsed.length()) {
                val item = parsed.getJSONObject(i)
                result.add(ServerEntry(item.getString("address"), item.getString("label")))
            }
            result
        } catch (_: Exception) {
            // Corrupt JSON is treated the same as "never saved".
            null
        }
    }

    /** Persist the index of the currently selected server. */
    fun saveSelectedServer(index: Int) {
        prefs.edit().putInt(KEY_SELECTED_SERVER, index).apply()
    }

    /** Index of the selected server (defaults to the first entry). */
    fun loadSelectedServer(): Int = prefs.getInt(KEY_SELECTED_SERVER, 0)

    // --- Room ---

    /** Persist the room name. */
    fun saveRoom(name: String) {
        prefs.edit().putString(KEY_ROOM, name).apply()
    }

    /** Load the room name (default "android"). */
    fun loadRoom(): String = prefs.getString(KEY_ROOM, "android") ?: "android"

    // --- Alias ---

    /** Persist the display alias. */
    fun saveAlias(alias: String) {
        prefs.edit().putString(KEY_ALIAS, alias).apply()
    }

    /**
     * Load alias, generating a random name on first launch.
     */
    fun getOrCreateAlias(): String {
        val stored = prefs.getString(KEY_ALIAS, null)
        if (!stored.isNullOrEmpty()) {
            return stored
        }
        val generated = generateRandomName()
        prefs.edit().putString(KEY_ALIAS, generated).apply()
        return generated
    }

    /** Build a random "Adjective Noun" display name. */
    private fun generateRandomName(): String {
        val adjectives = listOf(
            "Swift", "Silent", "Brave", "Calm", "Dark", "Fierce", "Ghost",
            "Iron", "Lucky", "Noble", "Quick", "Sharp", "Storm", "Wild",
            "Cold", "Bright", "Lone", "Red", "Grey", "Frosty", "Dusty",
            "Rusty", "Neon", "Void", "Solar", "Lunar", "Cyber", "Pixel",
            "Sonic", "Hyper", "Turbo", "Nano", "Mega", "Ultra", "Zinc"
        )
        val nouns = listOf(
            "Wolf", "Hawk", "Fox", "Bear", "Lynx", "Crow", "Viper",
            "Cobra", "Tiger", "Eagle", "Shark", "Raven", "Falcon", "Otter",
            "Mantis", "Panda", "Jackal", "Badger", "Heron", "Bison",
            "Condor", "Coyote", "Gecko", "Hornet", "Marten", "Osprey",
            "Parrot", "Puma", "Raptor", "Stork", "Toucan", "Walrus"
        )
        return "${adjectives.random()} ${nouns.random()}"
    }

    // --- Gain ---

    /** Persist the playout gain in dB. */
    fun savePlayoutGain(db: Float) {
        prefs.edit().putFloat(KEY_PLAYOUT_GAIN, db).apply()
    }

    /** Playout gain in dB (default 0 = unity). */
    fun loadPlayoutGain(): Float = prefs.getFloat(KEY_PLAYOUT_GAIN, 0f)

    /** Persist the capture gain in dB. */
    fun saveCaptureGain(db: Float) {
        prefs.edit().putFloat(KEY_CAPTURE_GAIN, db).apply()
    }

    /** Capture gain in dB (default 0 = unity). */
    fun loadCaptureGain(): Float = prefs.getFloat(KEY_CAPTURE_GAIN, 0f)

    // --- IPv6 ---

    /** Persist the IPv6 preference flag. */
    fun savePreferIPv6(prefer: Boolean) {
        prefs.edit().putBoolean(KEY_PREFER_IPV6, prefer).apply()
    }

    /** Whether IPv6 is preferred (default false). */
    fun loadPreferIPv6(): Boolean = prefs.getBoolean(KEY_PREFER_IPV6, false)

    // --- AEC ---

    /** Persist the acoustic-echo-cancellation toggle. */
    fun saveAecEnabled(enabled: Boolean) {
        prefs.edit().putBoolean(KEY_AEC_ENABLED, enabled).apply()
    }

    /** Whether AEC is enabled (default true). */
    fun loadAecEnabled(): Boolean = prefs.getBoolean(KEY_AEC_ENABLED, true)

    // --- Identity seed ---

    /**
     * Get or generate the identity seed. On first call, generates a random
     * 32-byte seed and persists it. Subsequent calls return the same seed.
     */
    fun getOrCreateSeedHex(): String {
        val stored = prefs.getString(KEY_IDENTITY_SEED, null)
        if (!stored.isNullOrEmpty()) {
            return stored
        }
        val bytes = ByteArray(32)
        SecureRandom().nextBytes(bytes)
        val hex = bytes.joinToString("") { "%02x".format(it) }
        prefs.edit().putString(KEY_IDENTITY_SEED, hex).apply()
        return hex
    }

    /** Load the stored seed, or "" when none exists. */
    fun loadSeedHex(): String = prefs.getString(KEY_IDENTITY_SEED, "") ?: ""

    /** Overwrite the stored seed (hex-encoded 32 bytes). */
    fun saveSeedHex(hex: String) {
        prefs.edit().putString(KEY_IDENTITY_SEED, hex).apply()
    }
}
|
||||
198
android/app/src/main/java/com/wzp/debug/DebugReporter.kt
Normal file
198
android/app/src/main/java/com/wzp/debug/DebugReporter.kt
Normal file
@@ -0,0 +1,198 @@
|
||||
package com.wzp.debug
|
||||
|
||||
import android.content.Context
|
||||
import android.util.Log
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.withContext
|
||||
import java.io.BufferedOutputStream
|
||||
import java.io.ByteArrayOutputStream
|
||||
import java.io.File
|
||||
import java.io.FileInputStream
|
||||
import java.io.FileOutputStream
|
||||
import java.nio.ByteBuffer
|
||||
import java.nio.ByteOrder
|
||||
import java.text.SimpleDateFormat
|
||||
import java.util.Date
|
||||
import java.util.Locale
|
||||
import java.util.zip.ZipEntry
|
||||
import java.util.zip.ZipOutputStream
|
||||
|
||||
/**
|
||||
* Collects call debug data (audio recordings, logs, histograms, stats)
|
||||
* into a zip file for email sharing.
|
||||
*/
|
||||
class DebugReporter(private val context: Context) {

    companion object {
        private const val TAG = "DebugReporter"
        // Used only to build WAV headers; must match the audio pipeline rate.
        private const val SAMPLE_RATE = 48000
    }

    /**
     * Build a zip with all debug data.
     * Returns the zip File on success, or null on failure.
     *
     * Runs on the IO dispatcher. After a successful zip, the raw debug files
     * in cacheDir/wzp_debug are deleted (the zip itself is kept).
     */
    suspend fun collectZip(
        callDurationSecs: Double,
        finalStatsJson: String,
        aecEnabled: Boolean,
        alias: String,
        server: String,
        room: String
    ): File? = withContext(Dispatchers.IO) {
        try {
            val debugDir = File(context.cacheDir, "wzp_debug")
            val timestamp = SimpleDateFormat("yyyyMMdd_HHmmss", Locale.US).format(Date())
            val zipFile = File(context.cacheDir, "wzp_debug_${timestamp}.zip")

            // .use closes the zip (flushing the central directory) even on error.
            ZipOutputStream(BufferedOutputStream(FileOutputStream(zipFile))).use { zos ->
                // 1. Call metadata
                val meta = buildString {
                    appendLine("=== WZ Phone Debug Report ===")
                    appendLine("Timestamp: $timestamp")
                    appendLine("Alias: $alias")
                    appendLine("Server: $server")
                    appendLine("Room: $room")
                    appendLine("Duration: ${"%.1f".format(callDurationSecs)}s")
                    appendLine("AEC: ${if (aecEnabled) "ON" else "OFF"}")
                    appendLine("Device: ${android.os.Build.MANUFACTURER} ${android.os.Build.MODEL}")
                    appendLine("Android: ${android.os.Build.VERSION.RELEASE} (API ${android.os.Build.VERSION.SDK_INT})")
                    appendLine()
                    appendLine("=== Final Stats ===")
                    appendLine(finalStatsJson)
                }
                addTextEntry(zos, "meta.txt", meta)

                // 2. Logcat — WZP-related tags
                val logcat = collectLogcat()
                addTextEntry(zos, "logcat.txt", logcat)

                // 3. Capture audio (mic) → WAV
                val captureRaw = File(debugDir, "capture.pcm")
                if (captureRaw.exists() && captureRaw.length() > 0) {
                    addWavEntry(zos, "capture.wav", captureRaw)
                    Log.i(TAG, "capture.pcm: ${captureRaw.length()} bytes -> WAV")
                }

                // 4. Playout audio (speaker) → WAV
                val playoutRaw = File(debugDir, "playout.pcm")
                if (playoutRaw.exists() && playoutRaw.length() > 0) {
                    addWavEntry(zos, "playout.wav", playoutRaw)
                    Log.i(TAG, "playout.pcm: ${playoutRaw.length()} bytes -> WAV")
                }

                // 5. RMS histogram CSV
                val captureHist = File(debugDir, "capture_rms.csv")
                if (captureHist.exists()) addFileEntry(zos, "capture_rms.csv", captureHist)
                val playoutHist = File(debugDir, "playout_rms.csv")
                if (playoutHist.exists()) addFileEntry(zos, "playout_rms.csv", playoutHist)
            }

            Log.i(TAG, "zip created: ${zipFile.length()} bytes (${zipFile.length() / 1024}KB)")

            // Clean up raw debug files (keep zip)
            debugDir.listFiles()?.forEach { it.delete() }

            zipFile
        } catch (e: Exception) {
            // Best-effort: a failed report must never crash the app.
            Log.e(TAG, "debug report failed", e)
            null
        }
    }

    /** Clean up any leftover debug files from a previous session. */
    fun prepareForCall() {
        val debugDir = File(context.cacheDir, "wzp_debug")
        if (debugDir.exists()) {
            debugDir.listFiles()?.forEach { it.delete() }
        }
        debugDir.mkdirs()
        // Also clean up old zip files
        context.cacheDir.listFiles()?.filter { it.name.startsWith("wzp_debug_") }?.forEach { it.delete() }
    }

    /**
     * Dump the last 5000 logcat lines and keep only lines matching
     * app/audio/transport-related tags and keywords.
     *
     * Returns an error message string (not an exception) on failure.
     */
    private fun collectLogcat(): String {
        return try {
            val process = Runtime.getRuntime().exec(
                arrayOf(
                    "logcat", "-d",
                    "-t", "5000",
                    "--format", "threadtime"
                )
            )
            val output = process.inputStream.bufferedReader().readText()
            process.waitFor()
            output.lines()
                .filter { line ->
                    line.contains("wzp", ignoreCase = true) ||
                    line.contains("WzpEngine") ||
                    line.contains("AudioPipeline") ||
                    line.contains("WzpCall") ||
                    line.contains("CallService") ||
                    line.contains("AudioTrack") ||
                    line.contains("AudioRecord") ||
                    line.contains("AcousticEchoCanceler") ||
                    line.contains("NoiseSuppressor") ||
                    line.contains("FATAL") ||
                    line.contains("ANR") ||
                    line.contains("AudioFlinger") ||
                    line.contains("DebugReporter") ||
                    line.contains("QUIC") ||
                    line.contains("quinn") ||
                    line.contains("send task") ||
                    line.contains("recv task") ||
                    line.contains("send stats") ||
                    line.contains("recv stats") ||
                    line.contains("send_media") ||
                    line.contains("FEC block") ||
                    line.contains("recv gap") ||
                    line.contains("frames_dropped") ||
                    line.contains("opus")
                }
                .joinToString("\n")
        } catch (e: Exception) {
            "Failed to collect logcat: ${e.message}"
        }
    }

    /**
     * Add a zip entry containing [pcmFile] converted to a WAV file
     * (16-bit PCM, mono, [SAMPLE_RATE] Hz).
     */
    private fun addWavEntry(zos: ZipOutputStream, name: String, pcmFile: File) {
        // NOTE(review): toInt() would overflow for files > 2 GiB — unlikely for
        // call-length recordings, but worth confirming upstream size limits.
        val dataSize = pcmFile.length().toInt()
        val byteRate = SAMPLE_RATE * 1 * 16 / 8   // rate * channels * bits / 8
        val blockAlign = 1 * 16 / 8               // channels * bits / 8

        zos.putNextEntry(ZipEntry(name))

        // Write WAV header (44 bytes)
        val header = ByteBuffer.allocate(44).order(ByteOrder.LITTLE_ENDIAN)
        header.put("RIFF".toByteArray())
        header.putInt(36 + dataSize)          // RIFF chunk size = header rest + data
        header.put("WAVE".toByteArray())
        header.put("fmt ".toByteArray())
        header.putInt(16)                     // fmt chunk size
        header.putShort(1) // PCM
        header.putShort(1) // mono
        header.putInt(SAMPLE_RATE)
        header.putInt(byteRate)
        header.putShort(blockAlign.toShort())
        header.putShort(16) // bits per sample
        header.put("data".toByteArray())
        header.putInt(dataSize)
        zos.write(header.array())

        // Stream PCM data directly (avoids loading entire file into memory)
        FileInputStream(pcmFile).use { it.copyTo(zos) }
        zos.closeEntry()
    }

    /** Add a zip entry containing [content] as UTF-8 text. */
    private fun addTextEntry(zos: ZipOutputStream, name: String, content: String) {
        zos.putNextEntry(ZipEntry(name))
        zos.write(content.toByteArray())
        zos.closeEntry()
    }

    /** Add a zip entry streaming [file]'s bytes verbatim. */
    private fun addFileEntry(zos: ZipOutputStream, name: String, file: File) {
        zos.putNextEntry(ZipEntry(name))
        FileInputStream(file).use { it.copyTo(zos) }
        zos.closeEntry()
    }
}
|
||||
97
android/app/src/main/java/com/wzp/engine/CallStats.kt
Normal file
97
android/app/src/main/java/com/wzp/engine/CallStats.kt
Normal file
@@ -0,0 +1,97 @@
|
||||
package com.wzp.engine
|
||||
|
||||
import org.json.JSONArray
|
||||
import org.json.JSONObject
|
||||
|
||||
/**
|
||||
* Snapshot of call statistics, mirroring the Rust `CallStats` struct.
|
||||
*
|
||||
* Constructed from the JSON string returned by [WzpEngine.getStats].
|
||||
*/
|
||||
data class CallStats(
    /** Current call state ordinal (see [CallStateConstants]). */
    val state: Int = 0,
    /** Call duration in seconds. */
    val durationSecs: Double = 0.0,
    /** Quality tier: 0 = Good, 1 = Degraded, 2 = Catastrophic. */
    val qualityTier: Int = 0,
    /** Observed packet loss percentage (0..100). */
    val lossPct: Float = 0f,
    /** Smoothed round-trip time in milliseconds. */
    val rttMs: Int = 0,
    /** Jitter in milliseconds. */
    val jitterMs: Int = 0,
    /** Current jitter buffer depth in packets. */
    val jitterBufferDepth: Int = 0,
    /** Total frames encoded since call start. */
    val framesEncoded: Long = 0,
    /** Total frames decoded since call start. */
    val framesDecoded: Long = 0,
    /** Number of playout underruns (buffer empty when audio was needed). */
    val underruns: Long = 0,
    /** Frames recovered by FEC. */
    val fecRecovered: Long = 0,
    /** Current mic audio level (RMS, 0-32767). */
    val audioLevel: Int = 0,
    /** Number of participants in the room. */
    val roomParticipantCount: Int = 0,
    /** Participants in the room (fingerprint + optional alias). */
    val roomParticipants: List<RoomMember> = emptyList(),
) {
    /** Human-readable quality label for [qualityTier]. */
    val qualityLabel: String
        get() = when (qualityTier) {
            0 -> "Good"
            1 -> "Degraded"
            2 -> "Catastrophic"
            else -> "Unknown"
        }

    companion object {
        /** Convert the optional "room_participants" JSON array into [RoomMember]s. */
        private fun parseParticipants(arr: JSONArray?): List<RoomMember> {
            arr ?: return emptyList()
            val members = ArrayList<RoomMember>(arr.length())
            for (i in 0 until arr.length()) {
                val entry = arr.getJSONObject(i)
                // JSON null alias is mapped to a Kotlin null.
                val alias = if (entry.isNull("alias")) null else entry.optString("alias", null)
                members.add(RoomMember(fingerprint = entry.optString("fingerprint", ""), alias = alias))
            }
            return members
        }

        /**
         * Deserialise from the JSON string produced by the native engine.
         * Malformed input yields an all-defaults [CallStats] rather than throwing.
         */
        fun fromJson(json: String): CallStats = try {
            val root = JSONObject(json)
            CallStats(
                state = root.optInt("state", 0),
                durationSecs = root.optDouble("duration_secs", 0.0),
                qualityTier = root.optInt("quality_tier", 0),
                lossPct = root.optDouble("loss_pct", 0.0).toFloat(),
                rttMs = root.optInt("rtt_ms", 0),
                jitterMs = root.optInt("jitter_ms", 0),
                jitterBufferDepth = root.optInt("jitter_buffer_depth", 0),
                framesEncoded = root.optLong("frames_encoded", 0),
                framesDecoded = root.optLong("frames_decoded", 0),
                underruns = root.optLong("underruns", 0),
                fecRecovered = root.optLong("fec_recovered", 0),
                audioLevel = root.optInt("audio_level", 0),
                roomParticipantCount = root.optInt("room_participant_count", 0),
                roomParticipants = parseParticipants(root.optJSONArray("room_participants"))
            )
        } catch (_: Exception) {
            CallStats()
        }
    }
}
|
||||
|
||||
data class RoomMember(
    val fingerprint: String,
    val alias: String? = null
) {
    /** Short display name: alias if set, otherwise first 8 chars of fingerprint. */
    val displayName: String
        get() {
            // A non-blank alias wins outright.
            val chosen = alias
            if (chosen != null && chosen.isNotBlank()) {
                return chosen
            }
            // Fall back to a fingerprint prefix; "unknown" when there is none.
            val prefix = fingerprint.take(8)
            return if (prefix.isEmpty()) "unknown" else prefix
        }
}
|
||||
32
android/app/src/main/java/com/wzp/engine/WzpCallback.kt
Normal file
32
android/app/src/main/java/com/wzp/engine/WzpCallback.kt
Normal file
@@ -0,0 +1,32 @@
|
||||
package com.wzp.engine
|
||||
|
||||
/**
|
||||
* Callback interface for VoIP engine events.
|
||||
*
|
||||
* All callbacks are invoked on the main/UI thread.
|
||||
*/
|
||||
interface WzpCallback {

    /**
     * Called when the call state changes.
     *
     * @param state one of [CallStateConstants]: IDLE(0), CONNECTING(1), ACTIVE(2),
     * RECONNECTING(3), CLOSED(4)
     */
    fun onCallStateChanged(state: Int)

    /**
     * Called when the network quality tier changes.
     *
     * @param tier 0 = Good, 1 = Degraded, 2 = Catastrophic
     * (matches [CallStats.qualityTier])
     */
    fun onQualityTierChanged(tier: Int)

    /**
     * Called when an error occurs in the native engine.
     *
     * Also invoked by [WzpEngine.startCall] when the native start returns a
     * non-zero code.
     *
     * @param code numeric error code (negative)
     * @param message human-readable description
     */
    fun onError(code: Int, message: String)
}
|
||||
149
android/app/src/main/java/com/wzp/engine/WzpEngine.kt
Normal file
149
android/app/src/main/java/com/wzp/engine/WzpEngine.kt
Normal file
@@ -0,0 +1,149 @@
|
||||
package com.wzp.engine
|
||||
|
||||
/**
|
||||
* Native VoIP engine wrapper. Delegates all work to libwzp_android.so via JNI.
|
||||
*
|
||||
* Lifecycle:
|
||||
* 1. Construct with a [WzpCallback]
|
||||
* 2. Call [init] to create the native engine
|
||||
* 3. Call [startCall] to begin a VoIP session
|
||||
* 4. Use [setMute], [setSpeaker], [getStats], [forceProfile] during the call
|
||||
* 5. Call [stopCall] to end the session
|
||||
* 6. Call [destroy] when the engine is no longer needed
|
||||
*
|
||||
* Thread safety: all methods must be called from the same thread (typically main).
|
||||
*/
|
||||
class WzpEngine(private val callback: WzpCallback) {

    /** Opaque pointer to the native EngineHandle. 0 means not initialised. */
    private var nativeHandle: Long = 0L

    /** Whether the engine has been initialised. */
    val isInitialized: Boolean get() = nativeHandle != 0L

    /**
     * Create the native engine. Must be called before any other method.
     *
     * @throws IllegalStateException if already initialised or native creation fails
     */
    fun init() {
        check(nativeHandle == 0L) { "Engine already initialized" }
        nativeHandle = nativeInit()
        check(nativeHandle != 0L) { "Native engine creation failed" }
    }

    /**
     * Start a call.
     *
     * Fires [WzpCallback.onCallStateChanged] with CONNECTING on success, or
     * [WzpCallback.onError] with the returned code on failure.
     *
     * @param relayAddr relay server address (host:port)
     * @param room room identifier (used as QUIC SNI)
     * @param seedHex 64-char hex-encoded 32-byte identity seed (empty = random)
     * @param token authentication token (empty = no auth)
     * @param alias display name sent to relay for room participant list
     * @return 0 on success, negative error code on failure
     * @throws IllegalStateException if the engine is not initialised
     */
    fun startCall(relayAddr: String, room: String, seedHex: String = "", token: String = "", alias: String = ""): Int {
        check(nativeHandle != 0L) { "Engine not initialized" }
        val result = nativeStartCall(nativeHandle, relayAddr, room, seedHex, token, alias)
        if (result == 0) {
            callback.onCallStateChanged(CallStateConstants.CONNECTING)
        } else {
            callback.onError(result, "Failed to start call")
        }
        return result
    }

    /**
     * Stop the active call. Safe to call when no call is active.
     *
     * NOTE: fires onCallStateChanged(CLOSED) whenever the engine is
     * initialised, even if no call was in progress.
     */
    fun stopCall() {
        if (nativeHandle != 0L) {
            nativeStopCall(nativeHandle)
            callback.onCallStateChanged(CallStateConstants.CLOSED)
        }
    }

    /** Mute or unmute the microphone. No-op before [init]. */
    fun setMute(muted: Boolean) {
        if (nativeHandle != 0L) nativeSetMute(nativeHandle, muted)
    }

    /** Enable or disable loudspeaker mode. No-op before [init]. */
    fun setSpeaker(speaker: Boolean) {
        if (nativeHandle != 0L) nativeSetSpeaker(nativeHandle, speaker)
    }

    /**
     * Get current call statistics as a JSON string.
     *
     * Never throws: any native-side failure is mapped to `"{}"`.
     *
     * @return JSON-serialised [CallStats], or `"{}"` if the engine is not initialised.
     */
    fun getStats(): String {
        if (nativeHandle == 0L) return "{}"
        return try {
            nativeGetStats(nativeHandle) ?: "{}"
        } catch (_: Exception) {
            "{}"
        }
    }

    /**
     * Force a quality profile, overriding adaptive selection.
     *
     * @param profile 0 = GOOD, 1 = DEGRADED, 2 = CATASTROPHIC
     */
    fun forceProfile(profile: Int) {
        if (nativeHandle != 0L) nativeForceProfile(nativeHandle, profile)
    }

    /** Destroy the native engine and free all resources. The instance must not be reused. */
    fun destroy() {
        if (nativeHandle != 0L) {
            nativeDestroy(nativeHandle)
            // Zero the handle so every later call becomes a harmless no-op.
            nativeHandle = 0L
        }
    }

    /**
     * Write captured PCM samples into the engine's capture ring buffer.
     * Called from the AudioRecord capture thread.
     *
     * @return number of samples accepted (0 when the engine is not initialised)
     */
    fun writeAudio(pcm: ShortArray): Int {
        if (nativeHandle == 0L) return 0
        return nativeWriteAudio(nativeHandle, pcm)
    }

    /**
     * Read decoded PCM samples from the engine's playout ring buffer.
     * Called from the AudioTrack playout thread.
     *
     * @return number of samples produced (0 when the engine is not initialised)
     */
    fun readAudio(pcm: ShortArray): Int {
        if (nativeHandle == 0L) return 0
        return nativeReadAudio(nativeHandle, pcm)
    }

    // -- JNI native methods --------------------------------------------------
    // Implemented in libwzp_android.so; signatures must not change.

    private external fun nativeInit(): Long
    private external fun nativeStartCall(
        handle: Long, relay: String, room: String, seed: String, token: String, alias: String
    ): Int
    private external fun nativeStopCall(handle: Long)
    private external fun nativeSetMute(handle: Long, muted: Boolean)
    private external fun nativeSetSpeaker(handle: Long, speaker: Boolean)
    private external fun nativeGetStats(handle: Long): String?
    private external fun nativeForceProfile(handle: Long, profile: Int)
    private external fun nativeWriteAudio(handle: Long, pcm: ShortArray): Int
    private external fun nativeReadAudio(handle: Long, pcm: ShortArray): Int
    private external fun nativeDestroy(handle: Long)

    companion object {
        init {
            // Loaded once per process, before any instance can call a native method.
            System.loadLibrary("wzp_android")
        }
    }
}
|
||||
|
||||
/** Integer constants matching the Rust [CallState] enum ordinals. */
|
||||
object CallStateConstants {
    /** No call in progress. */
    const val IDLE = 0
    /** Call starting; connection to the relay in progress. */
    const val CONNECTING = 1
    /** Call connected and media flowing. */
    const val ACTIVE = 2
    /** Connection lost; the engine is attempting to re-establish. */
    const val RECONNECTING = 3
    /** Call ended. */
    const val CLOSED = 4
}
|
||||
172
android/app/src/main/java/com/wzp/service/CallService.kt
Normal file
172
android/app/src/main/java/com/wzp/service/CallService.kt
Normal file
@@ -0,0 +1,172 @@
|
||||
package com.wzp.service
|
||||
|
||||
import android.app.Notification
import android.app.PendingIntent
import android.app.Service
import android.content.Context
import android.content.Intent
import android.media.AudioManager
import android.net.wifi.WifiManager
import android.os.IBinder
import android.os.PowerManager
import androidx.core.app.NotificationCompat
import androidx.core.content.ContextCompat
import com.wzp.WzpApplication
import com.wzp.ui.call.CallActivity
|
||||
|
||||
/**
 * Foreground service that keeps the VoIP call alive when the app is backgrounded.
 *
 * Responsibilities:
 * - Shows a persistent notification during the call
 * - Acquires a partial wake lock so the CPU stays on
 * - Acquires a Wi-Fi lock to prevent Wi-Fi from going to sleep
 * - Sets [AudioManager] mode to [AudioManager.MODE_IN_COMMUNICATION]
 * - Releases all resources when the call ends
 */
class CallService : Service() {

    private var wakeLock: PowerManager.WakeLock? = null
    private var wifiLock: WifiManager.WifiLock? = null
    // Audio mode in effect before the call; restored in onDestroy.
    private var previousAudioMode: Int = AudioManager.MODE_NORMAL

    // -- Lifecycle ------------------------------------------------------------

    override fun onCreate() {
        super.onCreate()
        // Acquire everything up front so the resources are held for the whole
        // service lifetime, including system-initiated (STICKY) restarts.
        acquireWakeLock()
        acquireWifiLock()
        setAudioMode()
    }

    override fun onStartCommand(intent: Intent?, flags: Int, startId: Int): Int {
        when (intent?.action) {
            ACTION_STOP -> {
                // "End Call" tapped in the notification: let the owner react,
                // then tear the service down.
                onStopFromNotification?.invoke()
                stopSelf()
                return START_NOT_STICKY
            }
        }

        startForeground(NOTIFICATION_ID, buildNotification())
        // STICKY: if the system kills the process mid-call, the service is
        // recreated (with a null intent) so notification and locks return.
        return START_STICKY
    }

    override fun onDestroy() {
        // Release in reverse order of acquisition.
        restoreAudioMode()
        releaseWifiLock()
        releaseWakeLock()
        super.onDestroy()
    }

    /** Not a bound service. */
    override fun onBind(intent: Intent?): IBinder? = null

    // -- Notification ---------------------------------------------------------

    /** Builds the ongoing "call in progress" notification with an End Call action. */
    private fun buildNotification(): Notification {
        // Tapping the notification returns to the call screen
        val contentIntent = PendingIntent.getActivity(
            this,
            0,
            Intent(this, CallActivity::class.java).apply {
                flags = Intent.FLAG_ACTIVITY_SINGLE_TOP
            },
            PendingIntent.FLAG_IMMUTABLE or PendingIntent.FLAG_UPDATE_CURRENT
        )

        // "End call" action button
        val stopIntent = PendingIntent.getService(
            this,
            1,
            Intent(this, CallService::class.java).apply { action = ACTION_STOP },
            PendingIntent.FLAG_IMMUTABLE or PendingIntent.FLAG_UPDATE_CURRENT
        )

        return NotificationCompat.Builder(this, WzpApplication.CHANNEL_ID)
            .setContentTitle("WZ Phone")
            .setContentText("Call in progress")
            .setSmallIcon(android.R.drawable.ic_menu_call)
            .setOngoing(true)
            .setContentIntent(contentIntent)
            .addAction(android.R.drawable.ic_menu_close_clear_cancel, "End Call", stopIntent)
            .setCategory(NotificationCompat.CATEGORY_CALL)
            .setPriority(NotificationCompat.PRIORITY_LOW)
            .build()
    }

    // -- Wake lock ------------------------------------------------------------

    private fun acquireWakeLock() {
        val pm = getSystemService(Context.POWER_SERVICE) as PowerManager
        wakeLock = pm.newWakeLock(
            PowerManager.PARTIAL_WAKE_LOCK,
            "wzp:call_wake_lock"
        ).apply {
            // Timed acquire: even if release is missed, the lock self-expires
            // after MAX_CALL_DURATION_MS instead of draining the battery.
            acquire(MAX_CALL_DURATION_MS)
        }
    }

    private fun releaseWakeLock() {
        wakeLock?.let {
            // isHeld guard: the timed acquire may already have expired.
            if (it.isHeld) it.release()
        }
        wakeLock = null
    }

    // -- Wi-Fi lock -----------------------------------------------------------

    @Suppress("DEPRECATION")
    private fun acquireWifiLock() {
        val wm = applicationContext.getSystemService(Context.WIFI_SERVICE) as WifiManager
        wifiLock = wm.createWifiLock(
            WifiManager.WIFI_MODE_FULL_HIGH_PERF,
            "wzp:call_wifi_lock"
        ).apply {
            acquire()
        }
    }

    private fun releaseWifiLock() {
        wifiLock?.let {
            if (it.isHeld) it.release()
        }
        wifiLock = null
    }

    // -- Audio mode -----------------------------------------------------------

    private fun setAudioMode() {
        val am = getSystemService(Context.AUDIO_SERVICE) as AudioManager
        // Remember the current mode so onDestroy can restore it exactly.
        previousAudioMode = am.mode
        am.mode = AudioManager.MODE_IN_COMMUNICATION
    }

    private fun restoreAudioMode() {
        val am = getSystemService(Context.AUDIO_SERVICE) as AudioManager
        am.mode = previousAudioMode
    }

    // -- Static helpers -------------------------------------------------------

    companion object {
        private const val NOTIFICATION_ID = 1001
        private const val ACTION_STOP = "com.wzp.service.STOP"
        private const val MAX_CALL_DURATION_MS = 4L * 60 * 60 * 1000 // 4 hours

        /** Called when the user taps "End Call" in the notification. */
        var onStopFromNotification: (() -> Unit)? = null

        /** Start the foreground call service. */
        fun start(context: Context) {
            val intent = Intent(context, CallService::class.java)
            // ContextCompat dispatches to startForegroundService on API 26+
            // and plain startService below — avoids a NoSuchMethodError/crash
            // on pre-O devices that the direct call would cause.
            ContextCompat.startForegroundService(context, intent)
        }

        /** Stop the foreground call service. */
        fun stop(context: Context) {
            val intent = Intent(context, CallService::class.java).apply {
                action = ACTION_STOP
            }
            context.startService(intent)
        }
    }
}
|
||||
149
android/app/src/main/java/com/wzp/ui/call/CallActivity.kt
Normal file
149
android/app/src/main/java/com/wzp/ui/call/CallActivity.kt
Normal file
@@ -0,0 +1,149 @@
|
||||
package com.wzp.ui.call
|
||||
|
||||
import android.Manifest
|
||||
import android.content.Intent
|
||||
import android.content.pm.PackageManager
|
||||
import android.os.Bundle
|
||||
import android.util.Log
|
||||
import android.widget.Toast
|
||||
import androidx.activity.ComponentActivity
|
||||
import androidx.activity.compose.setContent
|
||||
import androidx.activity.result.contract.ActivityResultContracts
|
||||
import androidx.activity.viewModels
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.darkColorScheme
|
||||
import androidx.compose.material3.dynamicDarkColorScheme
|
||||
import androidx.compose.material3.dynamicLightColorScheme
|
||||
import androidx.compose.material3.lightColorScheme
|
||||
import androidx.compose.foundation.isSystemInDarkTheme
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.core.content.ContextCompat
|
||||
import androidx.core.content.FileProvider
|
||||
import androidx.lifecycle.Lifecycle
|
||||
import androidx.lifecycle.lifecycleScope
|
||||
import androidx.lifecycle.repeatOnLifecycle
|
||||
import com.wzp.ui.settings.SettingsScreen
|
||||
import kotlinx.coroutines.launch
|
||||
|
||||
/**
 * Main activity hosting the in-call Compose UI.
 *
 * Call lifecycle (wake lock, Wi-Fi lock, audio mode, notification)
 * is managed by [com.wzp.service.CallService] foreground service.
 */
class CallActivity : ComponentActivity() {

    companion object {
        private const val TAG = "CallActivity"
    }

    private val viewModel: CallViewModel by viewModels()

    // Registered as a field so it exists before the activity is RESUMED, as
    // the ActivityResult API requires. A denied result only informs the user;
    // the call UI is shown regardless.
    private val audioPermissionLauncher = registerForActivityResult(
        ActivityResultContracts.RequestPermission()
    ) { granted ->
        if (!granted) {
            Toast.makeText(this, "Microphone permission is required for calls", Toast.LENGTH_LONG).show()
        }
    }

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)

        // Hand a context to the ViewModel so it can lazily build its
        // audio/settings/debug helpers (it keeps only applicationContext).
        viewModel.setContext(this)

        setContent {
            WzpTheme {
                // Minimal two-screen navigation: settings overlay vs. in-call.
                var showSettings by remember { mutableStateOf(false) }
                if (showSettings) {
                    SettingsScreen(
                        viewModel = viewModel,
                        onBack = { showSettings = false }
                    )
                } else {
                    InCallScreen(
                        viewModel = viewModel,
                        onHangUp = { viewModel.stopCall() },
                        onOpenSettings = { showSettings = true }
                    )
                }
            }
        }

        // Request the microphone permission up front if not already granted.
        if (ContextCompat.checkSelfPermission(this, Manifest.permission.RECORD_AUDIO)
            != PackageManager.PERMISSION_GRANTED
        ) {
            audioPermissionLauncher.launch(Manifest.permission.RECORD_AUDIO)
        }

        // Watch for debug zip ready → launch email intent
        lifecycleScope.launch {
            // Collect only while STARTED so the chooser never launches from
            // the background.
            repeatOnLifecycle(Lifecycle.State.STARTED) {
                viewModel.debugZipReady.collect { zipFile ->
                    if (zipFile != null && zipFile.exists()) {
                        Log.i(TAG, "debug zip ready: ${zipFile.absolutePath} (${zipFile.length()} bytes)")
                        launchEmailIntent(zipFile)
                        // Clears the flow so the intent fires once per report.
                        viewModel.onDebugReportSent()
                    }
                }
            }
        }
    }

    // Shares [zipFile] through a FileProvider URI via an ACTION_SEND chooser
    // restricted to email apps ("message/rfc822").
    private fun launchEmailIntent(zipFile: java.io.File) {
        try {
            // Authority must match the <provider> entry in the manifest —
            // TODO confirm it is declared as "<package>.fileprovider".
            val authority = "${applicationContext.packageName}.fileprovider"
            Log.i(TAG, "FileProvider authority: $authority, file: ${zipFile.absolutePath}")
            val uri = FileProvider.getUriForFile(this, authority, zipFile)
            Log.i(TAG, "FileProvider URI: $uri")

            val intent = Intent(Intent.ACTION_SEND).apply {
                type = "message/rfc822"
                putExtra(Intent.EXTRA_EMAIL, arrayOf("manwefarm@gmail.com"))
                putExtra(Intent.EXTRA_SUBJECT, "WZ Phone Debug Report - ${zipFile.name}")
                putExtra(
                    Intent.EXTRA_TEXT,
                    "Debug report attached.\n\nContains: call recordings (WAV), RMS histograms (CSV), logcat, stats."
                )
                putExtra(Intent.EXTRA_STREAM, uri)
                // Grant the receiving mail app read access to the zip.
                addFlags(Intent.FLAG_GRANT_READ_URI_PERMISSION)
            }
            startActivity(Intent.createChooser(intent, "Send debug report"))
            Log.i(TAG, "email intent launched")
        } catch (e: Exception) {
            Log.e(TAG, "email intent failed", e)
            Toast.makeText(this, "Failed to launch email: ${e.message}", Toast.LENGTH_LONG).show()
        }
    }

    override fun onDestroy() {
        super.onDestroy()
        // End the call only when the activity is going away for good —
        // isFinishing is false on configuration changes (e.g. rotation).
        if (isFinishing) {
            viewModel.stopCall()
        }
    }
}
|
||||
|
||||
/**
 * App-wide Material 3 theme: dynamic (Material You) colors on Android 12+,
 * plain dark/light palettes on older releases.
 */
@Composable
fun WzpTheme(content: @Composable () -> Unit) {
    val isDark = isSystemInDarkTheme()
    val ctx = LocalContext.current

    val scheme =
        if (android.os.Build.VERSION.SDK_INT >= android.os.Build.VERSION_CODES.S) {
            // Dynamic color is only available from Android S (API 31).
            if (isDark) dynamicDarkColorScheme(ctx) else dynamicLightColorScheme(ctx)
        } else if (isDark) {
            darkColorScheme()
        } else {
            lightColorScheme()
        }

    MaterialTheme(
        colorScheme = scheme,
        content = content
    )
}
|
||||
445
android/app/src/main/java/com/wzp/ui/call/CallViewModel.kt
Normal file
445
android/app/src/main/java/com/wzp/ui/call/CallViewModel.kt
Normal file
@@ -0,0 +1,445 @@
|
||||
package com.wzp.ui.call
|
||||
|
||||
import android.content.Context
|
||||
import android.util.Log
|
||||
import androidx.lifecycle.ViewModel
|
||||
import androidx.lifecycle.viewModelScope
|
||||
import com.wzp.audio.AudioPipeline
|
||||
import com.wzp.audio.AudioRouteManager
|
||||
import com.wzp.data.SettingsRepository
|
||||
import com.wzp.debug.DebugReporter
|
||||
import com.wzp.engine.CallStats
|
||||
import com.wzp.service.CallService
|
||||
import com.wzp.engine.WzpCallback
|
||||
import com.wzp.engine.WzpEngine
|
||||
import kotlinx.coroutines.Job
|
||||
import kotlinx.coroutines.delay
|
||||
import kotlinx.coroutines.flow.MutableStateFlow
|
||||
import kotlinx.coroutines.flow.StateFlow
|
||||
import kotlinx.coroutines.flow.asStateFlow
|
||||
import kotlinx.coroutines.isActive
|
||||
import kotlinx.coroutines.launch
|
||||
import java.io.File
|
||||
import java.net.Inet4Address
|
||||
import java.net.Inet6Address
|
||||
import java.net.InetAddress
|
||||
|
||||
data class ServerEntry(val address: String, val label: String)
|
||||
|
||||
/**
 * Holds all call/UI state and drives the native engine lifecycle.
 *
 * Owns the [WzpEngine], the audio pipeline/routing helpers, a 500 ms stats
 * polling loop, and the persisted settings. Survives configuration changes;
 * the hosting Activity injects a context via [setContext].
 */
class CallViewModel : ViewModel(), WzpCallback {

    // Native engine wrapper; non-null only between startCall() and teardown().
    private var engine: WzpEngine? = null
    private var engineInitialized = false
    private var audioPipeline: AudioPipeline? = null
    private var audioRouteManager: AudioRouteManager? = null
    // True once startAudio() has run for the current call.
    private var audioStarted = false
    private var appContext: Context? = null
    private var settings: SettingsRepository? = null
    private var debugReporter: DebugReporter? = null
    // Snapshots from the latest stats poll, consumed by sendDebugReport().
    private var lastStatsJson: String = "{}"
    private var lastCallDuration: Double = 0.0
    private var lastCallServer: String = ""

    // Call state as an Int matching CallStateConstants (0 = IDLE).
    private val _callState = MutableStateFlow(0)
    val callState: StateFlow<Int> get() = _callState.asStateFlow()

    private val _isMuted = MutableStateFlow(false)
    val isMuted: StateFlow<Boolean> = _isMuted.asStateFlow()

    private val _isSpeaker = MutableStateFlow(false)
    val isSpeaker: StateFlow<Boolean> = _isSpeaker.asStateFlow()

    private val _stats = MutableStateFlow(CallStats())
    val stats: StateFlow<CallStats> = _stats.asStateFlow()

    private val _qualityTier = MutableStateFlow(0)
    val qualityTier: StateFlow<Int> = _qualityTier.asStateFlow()

    private val _errorMessage = MutableStateFlow<String?>(null)
    val errorMessage: StateFlow<String?> = _errorMessage.asStateFlow()

    private val _roomName = MutableStateFlow(DEFAULT_ROOM)
    val roomName: StateFlow<String> = _roomName.asStateFlow()

    // Index into [servers] of the currently selected relay.
    private val _selectedServer = MutableStateFlow(0)
    val selectedServer: StateFlow<Int> = _selectedServer.asStateFlow()

    private val _servers = MutableStateFlow(DEFAULT_SERVERS.toList())
    val servers: StateFlow<List<ServerEntry>> = _servers.asStateFlow()

    private val _preferIPv6 = MutableStateFlow(false)
    val preferIPv6: StateFlow<Boolean> = _preferIPv6.asStateFlow()

    // Gains in decibels, applied to the audio pipeline and persisted.
    private val _playoutGainDb = MutableStateFlow(0f)
    val playoutGainDb: StateFlow<Float> = _playoutGainDb.asStateFlow()

    private val _captureGainDb = MutableStateFlow(0f)
    val captureGainDb: StateFlow<Float> = _captureGainDb.asStateFlow()

    private val _alias = MutableStateFlow("")
    val alias: StateFlow<String> = _alias.asStateFlow()

    // Identity seed as a hex string; generated/persisted by SettingsRepository.
    private val _seedHex = MutableStateFlow("")
    val seedHex: StateFlow<String> = _seedHex.asStateFlow()

    private val _aecEnabled = MutableStateFlow(true)
    val aecEnabled: StateFlow<Boolean> = _aecEnabled.asStateFlow()

    /** True when a call just ended and debug report can be sent. */
    private val _debugReportAvailable = MutableStateFlow(false)
    val debugReportAvailable: StateFlow<Boolean> = _debugReportAvailable.asStateFlow()

    /** Status: null=idle, "Preparing..."=in progress, "ready"=zip ready, "Error:..."=failed */
    private val _debugReportStatus = MutableStateFlow<String?>(null)
    val debugReportStatus: StateFlow<String?> = _debugReportStatus.asStateFlow()

    /** The zip file ready to be emailed. Set by sendDebugReport, consumed by Activity. */
    private val _debugZipReady = MutableStateFlow<File?>(null)
    val debugZipReady: StateFlow<File?> = _debugZipReady.asStateFlow()

    // Handle to the 500 ms stats polling coroutine; null when not polling.
    private var statsJob: Job? = null

    companion object {
        private const val TAG = "WzpCall"
        val DEFAULT_SERVERS = listOf(
            ServerEntry("172.16.81.175:4433", "LAN (172.16.81.175)"),
            ServerEntry("193.180.213.68:4433", "Pangolin (IP)"),
        )
        const val DEFAULT_ROOM = "android"
    }

    /**
     * Injects an application context and lazily builds context-dependent
     * helpers. Safe to call repeatedly (e.g. on every Activity recreation) —
     * each helper is created at most once.
     */
    fun setContext(context: Context) {
        val appCtx = context.applicationContext
        appContext = appCtx
        if (audioPipeline == null) {
            audioPipeline = AudioPipeline(appCtx)
        }
        if (audioRouteManager == null) {
            audioRouteManager = AudioRouteManager(appCtx)
        }
        if (debugReporter == null) {
            debugReporter = DebugReporter(appCtx)
        }
        if (settings == null) {
            settings = SettingsRepository(appCtx)
            loadSettings()
        }
    }

    // Restores persisted preferences into the state flows.
    private fun loadSettings() {
        val s = settings ?: return
        s.loadServers()?.let { saved ->
            if (saved.isNotEmpty()) _servers.value = saved
        }
        // Clamp in case the saved index points past the current list.
        _selectedServer.value = s.loadSelectedServer().coerceIn(0, _servers.value.lastIndex)
        _roomName.value = s.loadRoom()
        _alias.value = s.getOrCreateAlias()
        _preferIPv6.value = s.loadPreferIPv6()
        _playoutGainDb.value = s.loadPlayoutGain()
        _captureGainDb.value = s.loadCaptureGain()
        _seedHex.value = s.getOrCreateSeedHex()
        _aecEnabled.value = s.loadAecEnabled()
    }

    /** Selects and persists a relay server; out-of-range indices are ignored. */
    fun selectServer(index: Int) {
        if (index in _servers.value.indices) {
            _selectedServer.value = index
            settings?.saveSelectedServer(index)
        }
    }

    fun setPreferIPv6(prefer: Boolean) {
        _preferIPv6.value = prefer
        settings?.savePreferIPv6(prefer)
    }

    /** Appends a user-defined server and persists the new list. */
    fun addServer(hostPort: String, label: String) {
        val current = _servers.value.toMutableList()
        current.add(ServerEntry(hostPort, label))
        _servers.value = current
        settings?.saveServers(current)
    }

    fun removeServer(index: Int) {
        if (index < DEFAULT_SERVERS.size) return // don't remove built-in servers
        val current = _servers.value.toMutableList()
        if (index in current.indices) {
            current.removeAt(index)
            _servers.value = current
            // Fall back to the first server if the selection became invalid.
            if (_selectedServer.value >= current.size) {
                _selectedServer.value = 0
            }
            settings?.saveServers(current)
            settings?.saveSelectedServer(_selectedServer.value)
        }
    }

    /** Batch-apply servers and selection from Settings draft state. */
    fun applyServers(servers: List<ServerEntry>, selected: Int) {
        _servers.value = servers
        _selectedServer.value = selected.coerceIn(0, servers.lastIndex)
        settings?.saveServers(servers)
        settings?.saveSelectedServer(_selectedServer.value)
    }

    fun setRoomName(name: String) {
        _roomName.value = name
        settings?.saveRoom(name)
    }

    // Gain setters push the value to the live pipeline (if any) and persist it.
    fun setPlayoutGainDb(db: Float) {
        _playoutGainDb.value = db
        audioPipeline?.playoutGainDb = db
        settings?.savePlayoutGain(db)
    }

    fun setCaptureGainDb(db: Float) {
        _captureGainDb.value = db
        audioPipeline?.captureGainDb = db
        settings?.saveCaptureGain(db)
    }

    fun setAlias(alias: String) {
        _alias.value = alias
        settings?.saveAlias(alias)
    }

    // Replaces the identity seed; takes effect on the next call (the seed is
    // read from _seedHex when startCall launches the engine).
    fun restoreSeed(hex: String) {
        _seedHex.value = hex
        settings?.saveSeedHex(hex)
    }

    fun setAecEnabled(enabled: Boolean) {
        _aecEnabled.value = enabled
        settings?.saveAecEnabled(enabled)
    }

    /**
     * Resolve DNS hostname to IP address on the Kotlin/Android side,
     * since Rust's DNS resolution may not work on Android.
     * Returns "ip:port" string.
     */
    private fun resolveToIp(hostPort: String): String {
        val parts = hostPort.split(":")
        // Anything not exactly "host:port" (including bracketed IPv6 with
        // port) is passed through untouched.
        if (parts.size != 2) return hostPort
        val host = parts[0]
        val port = parts[1]

        // Already an IP address — return as-is
        if (host.matches(Regex("""\d+\.\d+\.\d+\.\d+"""))) return hostPort
        if (host.contains(":")) return hostPort // IPv6 literal

        return try {
            val addresses = InetAddress.getAllByName(host)
            val preferV6 = _preferIPv6.value
            // Honour the IPv4/IPv6 preference, falling back to the other family.
            val picked = if (preferV6) {
                addresses.firstOrNull { it is Inet6Address } ?: addresses.firstOrNull { it is Inet4Address }
            } else {
                addresses.firstOrNull { it is Inet4Address } ?: addresses.firstOrNull { it is Inet6Address }
            }
            if (picked != null) {
                val ip = picked.hostAddress ?: host
                // IPv6 addresses need brackets in "host:port" notation.
                val formatted = if (picked is Inet6Address) "[$ip]:$port" else "$ip:$port"
                formatted
            } else {
                hostPort
            }
        } catch (_: Exception) {
            hostPort // resolution failed — pass through and let Rust try
        }
    }

    /** Tear down engine and audio. Pass stopService=true to also stop the foreground service. */
    private fun teardown(stopService: Boolean = true) {
        Log.i(TAG, "teardown: stopping audio, stopService=$stopService")
        val hadCall = audioStarted
        // Detach the notification callback first so a late tap can't re-enter.
        CallService.onStopFromNotification = null
        stopAudio()
        stopStatsPolling()
        Log.i(TAG, "teardown: stopping engine")
        // Best-effort shutdown: a failing native call must not abort teardown.
        try { engine?.stopCall() } catch (e: Exception) { Log.w(TAG, "stopCall err: $e") }
        try { engine?.destroy() } catch (e: Exception) { Log.w(TAG, "destroy err: $e") }
        engine = null
        engineInitialized = false
        _callState.value = 0
        // Offer a debug report only if audio actually ran during this call.
        if (hadCall) {
            _debugReportAvailable.value = true
        }
        if (stopService) {
            try { appContext?.let { CallService.stop(it) } } catch (_: Exception) {}
        }
        Log.i(TAG, "teardown: done")
    }

    /**
     * Starts a call to the selected server. Engine creation and service start
     * happen synchronously; DNS resolution and the (blocking) engine connect
     * run on Dispatchers.IO. Errors surface via [errorMessage].
     */
    fun startCall() {
        val serverEntry = _servers.value[_selectedServer.value]
        val room = _roomName.value
        Log.i(TAG, "startCall: server=${serverEntry.address} room=$room")
        _debugReportAvailable.value = false
        _debugReportStatus.value = null
        lastCallServer = serverEntry.address
        debugReporter?.prepareForCall()
        try {
            // Teardown previous call but don't stop the service (we're about to restart it)
            teardown(stopService = false)

            Log.i(TAG, "startCall: creating engine")
            engine = WzpEngine(this)
            engine!!.init()
            engineInitialized = true
            _callState.value = 1
            _errorMessage.value = null
            try { appContext?.let { CallService.start(it) } } catch (e: Exception) {
                Log.w(TAG, "service start err: $e")
            }
            startStatsPolling()

            viewModelScope.launch(kotlinx.coroutines.Dispatchers.IO) {
                try {
                    val relay = resolveToIp(serverEntry.address)
                    val seed = _seedHex.value
                    val name = _alias.value
                    Log.i(TAG, "startCall: resolved=$relay, alias=$name, calling engine.startCall")
                    // -1 if engine became null (torn down) before we got here.
                    val result = engine?.startCall(relay, room, seedHex = seed, alias = name) ?: -1
                    Log.i(TAG, "startCall: engine returned $result")
                    // Only wire up notification callback after engine is running
                    CallService.onStopFromNotification = { stopCall() }
                    if (result != 0) {
                        _callState.value = 0
                        _errorMessage.value = "Failed to start call (code $result)"
                        appContext?.let { CallService.stop(it) }
                    }
                } catch (e: Exception) {
                    Log.e(TAG, "startCall IO error", e)
                    _callState.value = 0
                    _errorMessage.value = "Engine error: ${e.message}"
                    appContext?.let { CallService.stop(it) }
                }
            }
        } catch (e: Exception) {
            Log.e(TAG, "startCall error", e)
            _callState.value = 0
            _errorMessage.value = "Engine error: ${e.message}"
            appContext?.let { CallService.stop(it) }
        }
    }

    fun stopCall() {
        Log.i(TAG, "stopCall")
        teardown()
    }

    fun toggleMute() {
        val newMuted = !_isMuted.value
        _isMuted.value = newMuted
        // Best-effort: engine may be gone mid-teardown.
        try { engine?.setMute(newMuted) } catch (_: Exception) {}
    }

    // Speaker routing is handled on the Android side (AudioRouteManager),
    // not in the native engine.
    fun toggleSpeaker() {
        val newSpeaker = !_isSpeaker.value
        _isSpeaker.value = newSpeaker
        audioRouteManager?.setSpeaker(newSpeaker)
    }

    fun clearError() { _errorMessage.value = null }

    /** Builds the debug zip on an IO thread and publishes it via [debugZipReady]. */
    fun sendDebugReport() {
        val reporter = debugReporter ?: return
        _debugReportStatus.value = "Preparing debug report..."
        viewModelScope.launch(kotlinx.coroutines.Dispatchers.IO) {
            val zipFile = reporter.collectZip(
                callDurationSecs = lastCallDuration,
                finalStatsJson = lastStatsJson,
                aecEnabled = _aecEnabled.value,
                alias = _alias.value,
                server = lastCallServer,
                room = _roomName.value
            )
            if (zipFile != null) {
                _debugZipReady.value = zipFile
                _debugReportStatus.value = "ready"
            } else {
                _debugReportStatus.value = "Error: failed to create zip"
            }
            _debugReportAvailable.value = false
        }
    }

    /** Called by Activity after email intent is launched. */
    fun onDebugReportSent() {
        _debugZipReady.value = null
        _debugReportStatus.value = null
    }

    fun dismissDebugReport() {
        _debugReportAvailable.value = false
        _debugReportStatus.value = null
        _debugZipReady.value = null
    }

    // WzpCallback implementation. NOTE(review): likely invoked from an engine
    // thread, not the main thread — StateFlow.value is safe either way, but
    // confirm against WzpEngine's callback dispatch.
    override fun onCallStateChanged(state: Int) { _callState.value = state }
    override fun onQualityTierChanged(tier: Int) { _qualityTier.value = tier }
    override fun onError(code: Int, message: String) { _errorMessage.value = "Error $code: $message" }

    // Builds and starts the audio pipeline once the engine reports ACTIVE.
    private fun startAudio() {
        if (audioStarted) return
        val e = engine ?: return
        val ctx = appContext ?: return
        // Create a fresh pipeline each call to avoid stale threads
        audioPipeline = AudioPipeline(ctx).also {
            it.playoutGainDb = _playoutGainDb.value
            it.captureGainDb = _captureGainDb.value
            it.aecEnabled = _aecEnabled.value
            it.start(e)
        }
        audioRouteManager?.register()
        audioStarted = true
    }

    private fun stopAudio() {
        if (!audioStarted) return
        audioPipeline?.stop()
        audioPipeline = null
        audioRouteManager?.unregister()
        // Reset routing to the default so the next call starts predictably.
        audioRouteManager?.setSpeaker(false)
        _isSpeaker.value = false
        audioStarted = false
    }

    // Polls engine stats every 500 ms; mirrors the engine state into
    // _callState and kicks off audio once the call reports ACTIVE (state 2).
    private fun startStatsPolling() {
        statsJob?.cancel()
        statsJob = viewModelScope.launch {
            while (isActive) {
                try {
                    val json = engine?.getStats() ?: "{}"
                    if (json.isNotEmpty()) {
                        Log.d(TAG, "raw: $json")
                        lastStatsJson = json
                        val s = CallStats.fromJson(json)
                        lastCallDuration = s.durationSecs
                        _stats.value = s
                        // state 0 means "no info" here, so don't clobber.
                        if (s.state != 0) {
                            _callState.value = s.state
                        }
                        if (s.state == 2 && !audioStarted) {
                            startAudio()
                        }
                    }
                } catch (_: Exception) {}
                delay(500L)
            }
        }
    }

    private fun stopStatsPolling() {
        statsJob?.cancel()
        statsJob = null
    }

    override fun onCleared() {
        super.onCleared()
        Log.i(TAG, "onCleared")
        // ViewModel is going away for good — end any active call.
        teardown()
    }
}
|
||||
688
android/app/src/main/java/com/wzp/ui/call/InCallScreen.kt
Normal file
688
android/app/src/main/java/com/wzp/ui/call/InCallScreen.kt
Normal file
@@ -0,0 +1,688 @@
|
||||
package com.wzp.ui.call
|
||||
|
||||
import androidx.compose.foundation.background
|
||||
import androidx.compose.foundation.layout.Arrangement
|
||||
import androidx.compose.foundation.layout.Box
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.ExperimentalLayoutApi
|
||||
import androidx.compose.foundation.layout.FlowRow
|
||||
import androidx.compose.foundation.layout.Row
|
||||
import androidx.compose.foundation.layout.Spacer
|
||||
import androidx.compose.foundation.layout.fillMaxSize
|
||||
import androidx.compose.foundation.layout.fillMaxWidth
|
||||
import androidx.compose.foundation.layout.height
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.foundation.layout.size
|
||||
import androidx.compose.foundation.layout.width
|
||||
import androidx.compose.foundation.rememberScrollState
|
||||
import androidx.compose.foundation.shape.CircleShape
|
||||
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||
import androidx.compose.foundation.verticalScroll
|
||||
import androidx.compose.material3.AlertDialog
|
||||
import androidx.compose.material3.Button
|
||||
import androidx.compose.material3.ButtonDefaults
|
||||
import androidx.compose.material3.FilledIconButton
|
||||
import androidx.compose.material3.FilledTonalIconButton
|
||||
import androidx.compose.material3.IconButtonDefaults
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.OutlinedButton
|
||||
import androidx.compose.material3.OutlinedTextField
|
||||
import androidx.compose.material3.Slider
|
||||
import androidx.compose.material3.Surface
|
||||
import androidx.compose.material3.Switch
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.material3.TextButton
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.collectAsState
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.draw.clip
|
||||
import androidx.compose.ui.graphics.Color
|
||||
import androidx.compose.ui.text.font.FontWeight
|
||||
import androidx.compose.ui.text.style.TextAlign
|
||||
import androidx.compose.ui.unit.dp
|
||||
import androidx.compose.ui.unit.sp
|
||||
import com.wzp.engine.CallStats
|
||||
import kotlin.math.roundToInt
|
||||
|
||||
@OptIn(ExperimentalLayoutApi::class)
|
||||
@Composable
|
||||
fun InCallScreen(
|
||||
viewModel: CallViewModel,
|
||||
onHangUp: () -> Unit,
|
||||
onOpenSettings: () -> Unit = {}
|
||||
) {
|
||||
val callState by viewModel.callState.collectAsState()
|
||||
val isMuted by viewModel.isMuted.collectAsState()
|
||||
val isSpeaker by viewModel.isSpeaker.collectAsState()
|
||||
val stats by viewModel.stats.collectAsState()
|
||||
val qualityTier by viewModel.qualityTier.collectAsState()
|
||||
val errorMessage by viewModel.errorMessage.collectAsState()
|
||||
val roomName by viewModel.roomName.collectAsState()
|
||||
val selectedServer by viewModel.selectedServer.collectAsState()
|
||||
val servers by viewModel.servers.collectAsState()
|
||||
val preferIPv6 by viewModel.preferIPv6.collectAsState()
|
||||
val playoutGainDb by viewModel.playoutGainDb.collectAsState()
|
||||
val captureGainDb by viewModel.captureGainDb.collectAsState()
|
||||
val debugReportAvailable by viewModel.debugReportAvailable.collectAsState()
|
||||
val debugReportStatus by viewModel.debugReportStatus.collectAsState()
|
||||
|
||||
var showAddServerDialog by remember { mutableStateOf(false) }
|
||||
|
||||
Surface(
|
||||
modifier = Modifier.fillMaxSize(),
|
||||
color = MaterialTheme.colorScheme.background
|
||||
) {
|
||||
Column(
|
||||
modifier = Modifier
|
||||
.fillMaxSize()
|
||||
.padding(24.dp)
|
||||
.verticalScroll(rememberScrollState()),
|
||||
horizontalAlignment = Alignment.CenterHorizontally
|
||||
) {
|
||||
// Settings button (top-right)
|
||||
if (callState == 0) {
|
||||
Row(modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.End) {
|
||||
TextButton(onClick = onOpenSettings) {
|
||||
Text("Settings")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Spacer(modifier = Modifier.height(if (callState == 0) 16.dp else 48.dp))
|
||||
|
||||
Text(
|
||||
text = "WZ Phone",
|
||||
style = MaterialTheme.typography.headlineMedium.copy(
|
||||
fontWeight = FontWeight.Bold
|
||||
),
|
||||
color = MaterialTheme.colorScheme.primary
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.height(8.dp))
|
||||
|
||||
CallStateLabel(callState)
|
||||
|
||||
if (callState == 0) {
|
||||
Spacer(modifier = Modifier.height(32.dp))
|
||||
|
||||
// Server selector
|
||||
Text(
|
||||
text = "Server",
|
||||
style = MaterialTheme.typography.labelSmall,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
Spacer(modifier = Modifier.height(4.dp))
|
||||
FlowRow(
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
horizontalArrangement = Arrangement.Center
|
||||
) {
|
||||
servers.forEachIndexed { idx, entry ->
|
||||
val isSelected = selectedServer == idx
|
||||
FilledTonalIconButton(
|
||||
onClick = { viewModel.selectServer(idx) },
|
||||
modifier = Modifier
|
||||
.padding(2.dp)
|
||||
.height(36.dp)
|
||||
.width(140.dp),
|
||||
shape = RoundedCornerShape(8.dp),
|
||||
colors = if (isSelected) {
|
||||
IconButtonDefaults.filledTonalIconButtonColors(
|
||||
containerColor = MaterialTheme.colorScheme.primaryContainer,
|
||||
contentColor = MaterialTheme.colorScheme.onPrimaryContainer
|
||||
)
|
||||
} else {
|
||||
IconButtonDefaults.filledTonalIconButtonColors()
|
||||
}
|
||||
) {
|
||||
Text(
|
||||
text = entry.label,
|
||||
style = MaterialTheme.typography.labelSmall,
|
||||
maxLines = 1
|
||||
)
|
||||
}
|
||||
}
|
||||
// + Add button
|
||||
OutlinedButton(
|
||||
onClick = { showAddServerDialog = true },
|
||||
modifier = Modifier
|
||||
.padding(2.dp)
|
||||
.height(36.dp),
|
||||
shape = RoundedCornerShape(8.dp)
|
||||
) {
|
||||
Text("+", style = MaterialTheme.typography.labelMedium)
|
||||
}
|
||||
}
|
||||
|
||||
// IPv4/IPv6 preference
|
||||
Spacer(modifier = Modifier.height(8.dp))
|
||||
Row(
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
horizontalArrangement = Arrangement.Center
|
||||
) {
|
||||
Text(
|
||||
text = "IPv4",
|
||||
style = MaterialTheme.typography.labelSmall,
|
||||
color = if (!preferIPv6) MaterialTheme.colorScheme.primary
|
||||
else MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
Switch(
|
||||
checked = preferIPv6,
|
||||
onCheckedChange = { viewModel.setPreferIPv6(it) },
|
||||
modifier = Modifier.padding(horizontal = 8.dp)
|
||||
)
|
||||
Text(
|
||||
text = "IPv6",
|
||||
style = MaterialTheme.typography.labelSmall,
|
||||
color = if (preferIPv6) MaterialTheme.colorScheme.primary
|
||||
else MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
}
|
||||
|
||||
// Selected server address
|
||||
Spacer(modifier = Modifier.height(4.dp))
|
||||
Text(
|
||||
text = servers.getOrNull(selectedServer)?.address ?: "",
|
||||
style = MaterialTheme.typography.bodySmall,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.height(8.dp))
|
||||
OutlinedTextField(
|
||||
value = roomName,
|
||||
onValueChange = { viewModel.setRoomName(it) },
|
||||
label = { Text("Room") },
|
||||
singleLine = true,
|
||||
modifier = Modifier.fillMaxWidth(0.6f)
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.height(24.dp))
|
||||
|
||||
Button(
|
||||
onClick = { viewModel.startCall() },
|
||||
modifier = Modifier
|
||||
.size(120.dp)
|
||||
.clip(CircleShape),
|
||||
shape = CircleShape,
|
||||
colors = ButtonDefaults.buttonColors(
|
||||
containerColor = Color(0xFF4CAF50)
|
||||
)
|
||||
) {
|
||||
Text(
|
||||
text = "CALL",
|
||||
style = MaterialTheme.typography.titleLarge.copy(
|
||||
fontWeight = FontWeight.Bold
|
||||
),
|
||||
color = Color.White
|
||||
)
|
||||
}
|
||||
|
||||
errorMessage?.let { err ->
|
||||
Spacer(modifier = Modifier.height(16.dp))
|
||||
Text(
|
||||
text = err,
|
||||
style = MaterialTheme.typography.bodySmall,
|
||||
color = MaterialTheme.colorScheme.error
|
||||
)
|
||||
}
|
||||
|
||||
// Debug report card — shown after call ends
|
||||
if (debugReportAvailable || debugReportStatus != null) {
|
||||
Spacer(modifier = Modifier.height(24.dp))
|
||||
DebugReportCard(
|
||||
available = debugReportAvailable,
|
||||
status = debugReportStatus,
|
||||
onSend = { viewModel.sendDebugReport() },
|
||||
onDismiss = { viewModel.dismissDebugReport() }
|
||||
)
|
||||
}
|
||||
} else {
|
||||
// In-call UI
|
||||
Spacer(modifier = Modifier.height(16.dp))
|
||||
|
||||
DurationDisplay(stats.durationSecs)
|
||||
|
||||
Spacer(modifier = Modifier.height(24.dp))
|
||||
|
||||
QualityIndicator(qualityTier, stats.qualityLabel)
|
||||
|
||||
if (stats.roomParticipantCount > 0) {
|
||||
// Dedup by fingerprint — same key = same person, even if
|
||||
// relay hasn't cleaned up stale entries yet.
|
||||
val unique = stats.roomParticipants
|
||||
.distinctBy { it.fingerprint.ifEmpty { it.displayName } }
|
||||
Spacer(modifier = Modifier.height(8.dp))
|
||||
Text(
|
||||
text = "${unique.size} in room",
|
||||
style = MaterialTheme.typography.bodySmall,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
unique.forEach { member ->
|
||||
Text(
|
||||
text = member.displayName,
|
||||
style = MaterialTheme.typography.labelSmall,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
Spacer(modifier = Modifier.height(32.dp))
|
||||
|
||||
AudioLevelBar(stats.audioLevel)
|
||||
|
||||
Spacer(modifier = Modifier.height(16.dp))
|
||||
|
||||
// Gain sliders
|
||||
GainSlider(
|
||||
label = "Voice Volume",
|
||||
gainDb = playoutGainDb,
|
||||
onGainChange = { viewModel.setPlayoutGainDb(it) }
|
||||
)
|
||||
Spacer(modifier = Modifier.height(4.dp))
|
||||
GainSlider(
|
||||
label = "Mic Gain",
|
||||
gainDb = captureGainDb,
|
||||
onGainChange = { viewModel.setCaptureGainDb(it) }
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.height(32.dp))
|
||||
|
||||
ControlRow(
|
||||
isMuted = isMuted,
|
||||
isSpeaker = isSpeaker,
|
||||
onToggleMute = viewModel::toggleMute,
|
||||
onToggleSpeaker = viewModel::toggleSpeaker,
|
||||
onHangUp = {
|
||||
viewModel.stopCall()
|
||||
}
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.height(32.dp))
|
||||
|
||||
StatsOverlay(stats)
|
||||
|
||||
Spacer(modifier = Modifier.height(16.dp))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (showAddServerDialog) {
|
||||
AddServerDialog(
|
||||
onDismiss = { showAddServerDialog = false },
|
||||
onAdd = { host, port, label ->
|
||||
viewModel.addServer("$host:$port", label)
|
||||
showAddServerDialog = false
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Modal dialog for adding a custom relay server.
 *
 * Collects host, port and an optional label, and invokes [onAdd] with the
 * trimmed values. Fix: the port is now validated as an integer in 1..65535
 * before the entry is accepted — previously any text (e.g. "abc") was
 * passed through and produced a broken "host:abc" address.
 *
 * @param onDismiss invoked when the dialog is cancelled.
 * @param onAdd invoked with (host, port, label); label falls back to host.
 */
@Composable
private fun AddServerDialog(
    onDismiss: () -> Unit,
    onAdd: (host: String, port: String, label: String) -> Unit
) {
    var host by remember { mutableStateOf("") }
    // Pre-filled with the relay's default port.
    var port by remember { mutableStateOf("4433") }
    var label by remember { mutableStateOf("") }

    AlertDialog(
        onDismissRequest = onDismiss,
        title = { Text("Add Server") },
        text = {
            Column {
                OutlinedTextField(
                    value = host,
                    onValueChange = { host = it },
                    label = { Text("Host (IP or domain)") },
                    singleLine = true,
                    modifier = Modifier.fillMaxWidth()
                )
                Spacer(modifier = Modifier.height(8.dp))
                OutlinedTextField(
                    value = port,
                    onValueChange = { port = it },
                    label = { Text("Port") },
                    singleLine = true,
                    modifier = Modifier.fillMaxWidth()
                )
                Spacer(modifier = Modifier.height(8.dp))
                OutlinedTextField(
                    value = label,
                    onValueChange = { label = it },
                    label = { Text("Label (optional)") },
                    singleLine = true,
                    modifier = Modifier.fillMaxWidth()
                )
            }
        },
        confirmButton = {
            TextButton(
                onClick = {
                    // Fix: reject non-numeric or out-of-range ports instead of
                    // silently building an unusable server address.
                    val portNumber = port.trim().toIntOrNull()
                    if (host.isNotBlank() && portNumber != null && portNumber in 1..65535) {
                        val displayLabel = label.trim().ifBlank { host.trim() }
                        onAdd(host.trim(), port.trim(), displayLabel)
                    }
                }
            ) { Text("Add") }
        },
        dismissButton = {
            TextButton(onClick = onDismiss) { Text("Cancel") }
        }
    )
}
|
||||
|
||||
/**
 * One-line textual status for the current call, colour-coded by state.
 *
 * State codes: 0 idle, 1 connecting, 2 active, 3 reconnecting, 4 ended.
 * Active renders green, transient states amber, everything else neutral.
 */
@Composable
private fun CallStateLabel(state: Int) {
    val (label, color) = when (state) {
        0 -> "Ready to connect" to MaterialTheme.colorScheme.onSurfaceVariant
        1 -> "Connecting..." to Color(0xFFFFC107)
        2 -> "Active" to Color(0xFF4CAF50)
        3 -> "Reconnecting..." to Color(0xFFFFC107)
        4 -> "Call Ended" to MaterialTheme.colorScheme.onSurfaceVariant
        else -> "Unknown" to MaterialTheme.colorScheme.onSurfaceVariant
    }
    Text(
        text = label,
        color = color,
        style = MaterialTheme.typography.titleMedium
    )
}
|
||||
|
||||
/**
 * Large MM:SS elapsed-time readout for the active call.
 *
 * @param durationSecs elapsed call time in seconds; rounded to the nearest
 *        whole second before display. Minutes can exceed 59 (e.g. "90:12").
 */
@Composable
private fun DurationDisplay(durationSecs: Double) {
    val totalSeconds = durationSecs.roundToInt()
    val minutes = totalSeconds / 60
    val seconds = totalSeconds % 60
    // Fix: Kotlin's String.format uses the default locale, which can render
    // non-ASCII digits (e.g. Eastern Arabic numerals) on some devices.
    // Build the timer text locale-independently instead.
    val timerText = minutes.toString().padStart(2, '0') +
        ":" +
        seconds.toString().padStart(2, '0')
    Text(
        text = timerText,
        style = MaterialTheme.typography.displayLarge.copy(
            fontWeight = FontWeight.Light,
            letterSpacing = 4.sp
        ),
        color = MaterialTheme.colorScheme.onBackground
    )
}
|
||||
|
||||
/**
 * Small coloured dot plus text label summarising link quality.
 *
 * Tier mapping: 0 = good (green), 1 = degraded (amber), 2 = poor (red),
 * anything else = unknown (grey).
 */
@Composable
private fun QualityIndicator(tier: Int, label: String) {
    val indicatorColor = when (tier) {
        0 -> Color(0xFF4CAF50)
        1 -> Color(0xFFFFC107)
        2 -> Color(0xFFF44336)
        else -> Color.Gray
    }
    Row(
        horizontalArrangement = Arrangement.Center,
        verticalAlignment = Alignment.CenterVertically
    ) {
        // The status dot itself.
        Box(
            modifier = Modifier
                .size(12.dp)
                .clip(CircleShape)
                .background(indicatorColor)
        )
        Spacer(modifier = Modifier.width(8.dp))
        Text(
            text = label,
            color = MaterialTheme.colorScheme.onSurfaceVariant,
            style = MaterialTheme.typography.bodyMedium
        )
    }
}
|
||||
|
||||
/**
 * Horizontal meter visualising the current audio level.
 *
 * @param audioLevel raw level sample; 8000 maps to a full bar. Positive
 *        values are clamped to at least 2% so any activity stays visible;
 *        zero or negative renders an empty bar.
 */
@Composable
private fun AudioLevelBar(audioLevel: Int) {
    val fillFraction = when {
        audioLevel <= 0 -> 0f
        else -> (audioLevel / 8000f).coerceIn(0.02f, 1f)
    }
    Column(horizontalAlignment = Alignment.CenterHorizontally) {
        Text(
            text = "Audio Level",
            color = MaterialTheme.colorScheme.onSurfaceVariant,
            style = MaterialTheme.typography.labelSmall
        )
        Spacer(modifier = Modifier.height(4.dp))
        // Track (background) with the filled portion layered inside it.
        Box(
            modifier = Modifier
                .fillMaxWidth(0.6f)
                .height(6.dp)
                .clip(RoundedCornerShape(3.dp))
                .background(MaterialTheme.colorScheme.surfaceVariant)
        ) {
            Box(
                modifier = Modifier
                    .fillMaxWidth(fillFraction)
                    .height(6.dp)
                    .background(MaterialTheme.colorScheme.primary)
            )
        }
    }
}
|
||||
|
||||
/**
 * Labelled slider for a dB gain value in [-20, +20], snapped to whole dB
 * before being reported to the caller.
 *
 * @param label human-readable name shown above the slider.
 * @param gainDb current gain in decibels.
 * @param onGainChange invoked with the new (rounded) gain as the user drags.
 */
@Composable
private fun GainSlider(label: String, gainDb: Float, onGainChange: (Float) -> Unit) {
    Column(
        horizontalAlignment = Alignment.CenterHorizontally,
        modifier = Modifier.fillMaxWidth(0.8f)
    ) {
        // Show an explicit "+" for non-negative gains, e.g. "+3 dB".
        val prefix = if (gainDb < 0) "" else "+"
        Text(
            text = "$label: $prefix${"%.0f".format(gainDb)} dB",
            color = MaterialTheme.colorScheme.onSurfaceVariant,
            style = MaterialTheme.typography.labelSmall
        )
        Spacer(modifier = Modifier.height(4.dp))
        Slider(
            value = gainDb,
            // Snap to integer dB steps before reporting.
            onValueChange = { raw -> onGainChange(Math.round(raw).toFloat()) },
            valueRange = -20f..20f,
            steps = 0,
            modifier = Modifier.fillMaxWidth()
        )
    }
}
|
||||
|
||||
/**
 * Bottom control bar for an active call: mute toggle, central hang-up
 * button, and speaker-route toggle, evenly spaced across the width.
 *
 * @param isMuted whether the microphone is currently muted.
 * @param isSpeaker whether the loudspeaker route is active.
 * @param onToggleMute invoked when the mic button is tapped.
 * @param onToggleSpeaker invoked when the speaker button is tapped.
 * @param onHangUp invoked when the end-call button is tapped.
 */
@Composable
private fun ControlRow(
    isMuted: Boolean,
    isSpeaker: Boolean,
    onToggleMute: () -> Unit,
    onToggleSpeaker: () -> Unit,
    onHangUp: () -> Unit
) {
    Row(
        modifier = Modifier.fillMaxWidth(),
        horizontalArrangement = Arrangement.SpaceEvenly,
        verticalAlignment = Alignment.CenterVertically
    ) {
        // Mute toggle — error colours while muted.
        CallToggleButton(
            active = isMuted,
            activeContainer = MaterialTheme.colorScheme.errorContainer,
            activeContent = MaterialTheme.colorScheme.onErrorContainer,
            caption = if (isMuted) "MIC\nOFF" else "MIC",
            onClick = onToggleMute
        )

        // Central hang-up button: larger and always red.
        FilledIconButton(
            onClick = onHangUp,
            modifier = Modifier.size(72.dp),
            shape = CircleShape,
            colors = IconButtonDefaults.filledIconButtonColors(
                containerColor = Color(0xFFF44336),
                contentColor = Color.White
            )
        ) {
            Text(
                text = "END",
                style = MaterialTheme.typography.titleMedium.copy(
                    fontWeight = FontWeight.Bold
                )
            )
        }

        // Speaker toggle — primary colours while active.
        CallToggleButton(
            active = isSpeaker,
            activeContainer = MaterialTheme.colorScheme.primaryContainer,
            activeContent = MaterialTheme.colorScheme.onPrimaryContainer,
            caption = if (isSpeaker) "SPK\nON" else "SPK",
            onClick = onToggleSpeaker
        )
    }
}

/** 56dp tonal toggle button whose colours switch when [active] is set. */
@Composable
private fun CallToggleButton(
    active: Boolean,
    activeContainer: Color,
    activeContent: Color,
    caption: String,
    onClick: () -> Unit
) {
    FilledTonalIconButton(
        onClick = onClick,
        modifier = Modifier.size(56.dp),
        colors = if (active) {
            IconButtonDefaults.filledTonalIconButtonColors(
                containerColor = activeContainer,
                contentColor = activeContent
            )
        } else {
            IconButtonDefaults.filledTonalIconButtonColors()
        }
    ) {
        Text(
            text = caption,
            textAlign = TextAlign.Center,
            style = MaterialTheme.typography.labelSmall,
            lineHeight = 12.sp
        )
    }
}
|
||||
|
||||
/**
 * Translucent card showing live transport statistics for the call.
 *
 * Two rows of stats: loss / RTT / jitter, then frames sent / received /
 * recovered via FEC.
 */
@Composable
private fun StatsOverlay(stats: CallStats) {
    Surface(
        modifier = Modifier.fillMaxWidth(),
        color = MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.5f),
        shape = RoundedCornerShape(8.dp)
    ) {
        Column(
            modifier = Modifier.padding(12.dp),
            horizontalAlignment = Alignment.CenterHorizontally
        ) {
            Text(
                text = "Stats",
                color = MaterialTheme.colorScheme.onSurfaceVariant,
                style = MaterialTheme.typography.labelSmall
            )
            Spacer(modifier = Modifier.height(4.dp))
            StatRow(
                "Loss" to "%.1f%%".format(stats.lossPct),
                "RTT" to "${stats.rttMs}ms",
                "Jitter" to "${stats.jitterMs}ms"
            )
            Spacer(modifier = Modifier.height(4.dp))
            StatRow(
                "Sent" to "${stats.framesEncoded}",
                "Recv" to "${stats.framesDecoded}",
                "FEC" to "${stats.fecRecovered}"
            )
        }
    }
}

/** Evenly-spaced row of label/value stat items. */
@Composable
private fun StatRow(vararg items: Pair<String, String>) {
    Row(
        modifier = Modifier.fillMaxWidth(),
        horizontalArrangement = Arrangement.SpaceEvenly
    ) {
        items.forEach { (label, value) -> StatItem(label, value) }
    }
}
|
||||
|
||||
/** Vertical value-over-caption pair used inside the stats overlay. */
@Composable
private fun StatItem(label: String, value: String) {
    Column(horizontalAlignment = Alignment.CenterHorizontally) {
        // Value on top, slightly emphasised …
        Text(
            text = value,
            color = MaterialTheme.colorScheme.onSurface,
            style = MaterialTheme.typography.bodySmall.copy(fontWeight = FontWeight.Medium)
        )
        // … caption underneath in a muted colour.
        Text(
            text = label,
            color = MaterialTheme.colorScheme.onSurfaceVariant,
            style = MaterialTheme.typography.labelSmall
        )
    }
}
|
||||
|
||||
/**
 * Post-call card offering to e-mail a debug report (recordings, logs, stats).
 *
 * Rendering depends on [status] (a free-form progress string from the
 * view-model) and [available], checked in this order:
 *  1. status starting with "Error": error text with Retry / Dismiss actions;
 *  2. any other non-null status except "ready": shown as plain progress text;
 *  3. otherwise, when [available]: "Email Report" / "Skip" actions.
 *
 * NOTE(review): the branches key off magic strings ("Error…", "ready") —
 * presumably matching what the view-model emits; verify against its source.
 */
@Composable
private fun DebugReportCard(
    available: Boolean,
    status: String?,
    onSend: () -> Unit,
    onDismiss: () -> Unit
) {
    Surface(
        modifier = Modifier.fillMaxWidth(),
        color = MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.7f),
        shape = RoundedCornerShape(12.dp)
    ) {
        Column(
            modifier = Modifier.padding(16.dp),
            horizontalAlignment = Alignment.CenterHorizontally
        ) {
            Text(
                text = "Debug Report",
                color = MaterialTheme.colorScheme.onSurface,
                style = MaterialTheme.typography.titleSmall.copy(fontWeight = FontWeight.Bold)
            )
            Spacer(modifier = Modifier.height(4.dp))
            Text(
                text = "Email call recordings, logs & stats for analysis",
                color = MaterialTheme.colorScheme.onSurfaceVariant,
                textAlign = TextAlign.Center,
                style = MaterialTheme.typography.bodySmall
            )

            Spacer(modifier = Modifier.height(12.dp))

            when {
                // Failure: show the error and let the user retry or give up.
                status?.startsWith("Error") == true -> {
                    Text(
                        text = status,
                        color = MaterialTheme.colorScheme.error,
                        style = MaterialTheme.typography.bodySmall
                    )
                    Spacer(modifier = Modifier.height(8.dp))
                    Row(horizontalArrangement = Arrangement.spacedBy(8.dp)) {
                        OutlinedButton(onClick = onSend) { Text("Retry") }
                        TextButton(onClick = onDismiss) { Text("Dismiss") }
                    }
                }
                // In progress (e.g. preparing the zip): show the status text.
                status != null && status != "ready" -> {
                    Text(
                        text = status,
                        color = MaterialTheme.colorScheme.onSurfaceVariant,
                        style = MaterialTheme.typography.bodySmall
                    )
                }
                // Report ready to send.
                available -> {
                    Row(horizontalArrangement = Arrangement.spacedBy(8.dp)) {
                        Button(onClick = onSend) {
                            Text("Email Report")
                        }
                        TextButton(onClick = onDismiss) {
                            Text("Skip")
                        }
                    }
                }
            }
        }
    }
}
|
||||
510
android/app/src/main/java/com/wzp/ui/settings/SettingsScreen.kt
Normal file
510
android/app/src/main/java/com/wzp/ui/settings/SettingsScreen.kt
Normal file
@@ -0,0 +1,510 @@
|
||||
package com.wzp.ui.settings
|
||||
|
||||
import android.content.ClipData
|
||||
import android.content.ClipboardManager
|
||||
import android.content.Context
|
||||
import android.widget.Toast
|
||||
import androidx.compose.foundation.layout.Arrangement
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.ExperimentalLayoutApi
|
||||
import androidx.compose.foundation.layout.FlowRow
|
||||
import androidx.compose.foundation.layout.Row
|
||||
import androidx.compose.foundation.layout.Spacer
|
||||
import androidx.compose.foundation.layout.fillMaxSize
|
||||
import androidx.compose.foundation.layout.fillMaxWidth
|
||||
import androidx.compose.foundation.layout.height
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.foundation.layout.width
|
||||
import androidx.compose.foundation.rememberScrollState
|
||||
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||
import androidx.compose.foundation.verticalScroll
|
||||
import androidx.compose.material3.AlertDialog
|
||||
import androidx.compose.material3.Button
|
||||
import androidx.compose.material3.ButtonDefaults
|
||||
import androidx.compose.material3.Divider
|
||||
import androidx.compose.material3.FilledTonalButton
|
||||
import androidx.compose.material3.FilledTonalIconButton
|
||||
import androidx.compose.material3.IconButtonDefaults
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.OutlinedButton
|
||||
import androidx.compose.material3.OutlinedTextField
|
||||
import androidx.compose.material3.Slider
|
||||
import androidx.compose.material3.Surface
|
||||
import androidx.compose.material3.Switch
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.material3.TextButton
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.collectAsState
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableFloatStateOf
|
||||
import androidx.compose.runtime.mutableIntStateOf
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.runtime.toMutableStateList
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.graphics.Color
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.text.font.FontFamily
|
||||
import androidx.compose.ui.text.font.FontWeight
|
||||
import androidx.compose.ui.unit.dp
|
||||
import com.wzp.ui.call.CallViewModel
|
||||
import com.wzp.ui.call.ServerEntry
|
||||
|
||||
/**
 * Full-screen settings editor.
 *
 * Works on a local "draft" copy of the view-model state: nothing is written
 * back until the user taps Save, which applies every field and then
 * navigates back. The Save button is enabled only while the draft differs
 * from the committed values.
 *
 * Fix: removing a custom server now keeps the selection pointing at the
 * same server — previously, deleting an entry below the selected index
 * shifted all indices left and the selection silently moved to a different
 * server (it was only reset when it fell out of range entirely).
 *
 * @param viewModel shared call view-model holding the persisted settings.
 * @param onBack navigate back without saving (also called after a save).
 */
@OptIn(ExperimentalLayoutApi::class)
@Composable
fun SettingsScreen(
    viewModel: CallViewModel,
    onBack: () -> Unit
) {
    val context = LocalContext.current

    // Committed values — what the view-model currently holds.
    val currentAlias by viewModel.alias.collectAsState()
    val currentSeedHex by viewModel.seedHex.collectAsState()
    val currentServers by viewModel.servers.collectAsState()
    val currentSelectedServer by viewModel.selectedServer.collectAsState()
    val currentRoomName by viewModel.roomName.collectAsState()
    val currentPreferIPv6 by viewModel.preferIPv6.collectAsState()
    val currentPlayoutGain by viewModel.playoutGainDb.collectAsState()
    val currentCaptureGain by viewModel.captureGainDb.collectAsState()
    val currentAecEnabled by viewModel.aecEnabled.collectAsState()

    // Draft state — snapshotted once on first composition; later view-model
    // emissions intentionally do NOT overwrite in-progress edits.
    var draftAlias by remember { mutableStateOf(currentAlias) }
    var draftSeedHex by remember { mutableStateOf(currentSeedHex) }
    val draftServers = remember { currentServers.toMutableStateList() }
    var draftSelectedServer by remember { mutableIntStateOf(currentSelectedServer) }
    var draftRoomName by remember { mutableStateOf(currentRoomName) }
    var draftPreferIPv6 by remember { mutableStateOf(currentPreferIPv6) }
    var draftPlayoutGain by remember { mutableFloatStateOf(currentPlayoutGain) }
    var draftCaptureGain by remember { mutableFloatStateOf(currentCaptureGain) }
    var draftAecEnabled by remember { mutableStateOf(currentAecEnabled) }

    // True when any draft field differs from its committed counterpart.
    val hasChanges = draftAlias != currentAlias ||
        draftSeedHex != currentSeedHex ||
        draftServers.toList() != currentServers ||
        draftSelectedServer != currentSelectedServer ||
        draftRoomName != currentRoomName ||
        draftPreferIPv6 != currentPreferIPv6 ||
        draftPlayoutGain != currentPlayoutGain ||
        draftCaptureGain != currentCaptureGain ||
        draftAecEnabled != currentAecEnabled

    var showAddServerDialog by remember { mutableStateOf(false) }
    var showRestoreKeyDialog by remember { mutableStateOf(false) }

    Surface(
        modifier = Modifier.fillMaxSize(),
        color = MaterialTheme.colorScheme.background
    ) {
        Column(
            modifier = Modifier
                .fillMaxSize()
                .padding(24.dp)
                .verticalScroll(rememberScrollState())
        ) {
            // Header: back button, centred title, Save action.
            Row(
                modifier = Modifier.fillMaxWidth(),
                verticalAlignment = Alignment.CenterVertically
            ) {
                TextButton(onClick = onBack) {
                    Text("< Back")
                }
                Spacer(modifier = Modifier.weight(1f))
                Text(
                    text = "Settings",
                    style = MaterialTheme.typography.headlineSmall.copy(
                        fontWeight = FontWeight.Bold
                    ),
                    color = MaterialTheme.colorScheme.primary
                )
                Spacer(modifier = Modifier.weight(1f))
                // Save — commits every draft field, then navigates back.
                Button(
                    onClick = {
                        viewModel.setAlias(draftAlias)
                        // Restoring a seed replaces the identity; only do it
                        // when the user actually staged a different key.
                        if (draftSeedHex != currentSeedHex) viewModel.restoreSeed(draftSeedHex)
                        viewModel.applyServers(draftServers.toList(), draftSelectedServer)
                        viewModel.setRoomName(draftRoomName)
                        viewModel.setPreferIPv6(draftPreferIPv6)
                        viewModel.setPlayoutGainDb(draftPlayoutGain)
                        viewModel.setCaptureGainDb(draftCaptureGain)
                        viewModel.setAecEnabled(draftAecEnabled)
                        Toast.makeText(context, "Settings saved", Toast.LENGTH_SHORT).show()
                        onBack()
                    },
                    enabled = hasChanges
                ) {
                    Text("Save")
                }
            }

            Spacer(modifier = Modifier.height(24.dp))

            // --- Identity ---
            SectionHeader("Identity")

            OutlinedTextField(
                value = draftAlias,
                onValueChange = { draftAlias = it },
                label = { Text("Display Name") },
                singleLine = true,
                modifier = Modifier.fillMaxWidth()
            )

            Spacer(modifier = Modifier.height(16.dp))

            // Fingerprint: first 16 hex chars of the seed, grouped in fours.
            val fingerprint = if (draftSeedHex.length >= 16) draftSeedHex.take(16).uppercase() else "Not generated"
            Text(
                text = "Fingerprint",
                style = MaterialTheme.typography.labelSmall,
                color = MaterialTheme.colorScheme.onSurfaceVariant
            )
            Text(
                text = fingerprint.chunked(4).joinToString(" "),
                style = MaterialTheme.typography.bodyMedium.copy(
                    fontFamily = FontFamily.Monospace
                ),
                color = MaterialTheme.colorScheme.onSurface
            )

            Spacer(modifier = Modifier.height(12.dp))

            // Key backup (copy to clipboard) and restore (staged into draft).
            Row(horizontalArrangement = Arrangement.spacedBy(8.dp)) {
                FilledTonalButton(onClick = {
                    val clipboard = context.getSystemService(Context.CLIPBOARD_SERVICE) as ClipboardManager
                    clipboard.setPrimaryClip(ClipData.newPlainText("WZP Key", draftSeedHex))
                    Toast.makeText(context, "Key copied to clipboard", Toast.LENGTH_SHORT).show()
                }) {
                    Text("Copy Key")
                }
                OutlinedButton(onClick = { showRestoreKeyDialog = true }) {
                    Text("Restore Key")
                }
            }

            Spacer(modifier = Modifier.height(24.dp))
            Divider()
            Spacer(modifier = Modifier.height(16.dp))

            // --- Audio ---
            SectionHeader("Audio Defaults")

            GainSlider(
                label = "Voice Volume",
                gainDb = draftPlayoutGain,
                // Snap drafts to whole dB; the slider itself is continuous.
                onGainChange = { draftPlayoutGain = Math.round(it).toFloat() }
            )
            Spacer(modifier = Modifier.height(4.dp))
            GainSlider(
                label = "Mic Gain",
                gainDb = draftCaptureGain,
                onGainChange = { draftCaptureGain = Math.round(it).toFloat() }
            )

            Spacer(modifier = Modifier.height(12.dp))

            Row(
                verticalAlignment = Alignment.CenterVertically,
                modifier = Modifier.fillMaxWidth()
            ) {
                Column(modifier = Modifier.weight(1f)) {
                    Text(
                        text = "Echo Cancellation (AEC)",
                        style = MaterialTheme.typography.bodyMedium
                    )
                    Text(
                        text = "Disable if audio sounds distorted",
                        style = MaterialTheme.typography.bodySmall,
                        color = MaterialTheme.colorScheme.onSurfaceVariant
                    )
                }
                Switch(
                    checked = draftAecEnabled,
                    onCheckedChange = { draftAecEnabled = it }
                )
            }

            Spacer(modifier = Modifier.height(24.dp))
            Divider()
            Spacer(modifier = Modifier.height(16.dp))

            // --- Servers ---
            SectionHeader("Servers")

            FlowRow(
                modifier = Modifier.fillMaxWidth(),
                horizontalArrangement = Arrangement.Start,
                verticalArrangement = Arrangement.spacedBy(4.dp)
            ) {
                draftServers.forEachIndexed { idx, entry ->
                    val isSelected = draftSelectedServer == idx
                    Row(verticalAlignment = Alignment.CenterVertically) {
                        FilledTonalIconButton(
                            onClick = { draftSelectedServer = idx },
                            modifier = Modifier
                                .padding(end = 2.dp)
                                .height(36.dp)
                                .width(140.dp),
                            shape = RoundedCornerShape(8.dp),
                            colors = if (isSelected) {
                                IconButtonDefaults.filledTonalIconButtonColors(
                                    containerColor = MaterialTheme.colorScheme.primaryContainer,
                                    contentColor = MaterialTheme.colorScheme.onPrimaryContainer
                                )
                            } else {
                                IconButtonDefaults.filledTonalIconButtonColors()
                            }
                        ) {
                            Text(
                                text = entry.label,
                                style = MaterialTheme.typography.labelSmall,
                                maxLines = 1
                            )
                        }
                        // Remove button only for non-default servers
                        // (the first two entries are built-in).
                        if (idx >= 2) {
                            TextButton(
                                onClick = {
                                    draftServers.removeAt(idx)
                                    // Fix: keep the selection on the same server
                                    // after indices shift left; previously it was
                                    // only reset when it fell out of range, so a
                                    // different server could be silently selected.
                                    when {
                                        draftSelectedServer == idx -> draftSelectedServer = 0
                                        draftSelectedServer > idx -> draftSelectedServer -= 1
                                    }
                                },
                                modifier = Modifier.height(36.dp)
                            ) {
                                Text("X", color = MaterialTheme.colorScheme.error)
                            }
                        }
                    }
                }
            }

            Spacer(modifier = Modifier.height(8.dp))
            OutlinedButton(
                onClick = { showAddServerDialog = true },
                shape = RoundedCornerShape(8.dp)
            ) {
                Text("+ Add Server")
            }

            // Address of the currently selected (default) server.
            Spacer(modifier = Modifier.height(8.dp))
            Text(
                text = "Default: ${draftServers.getOrNull(draftSelectedServer)?.address ?: "none"}",
                style = MaterialTheme.typography.bodySmall,
                color = MaterialTheme.colorScheme.onSurfaceVariant
            )

            Spacer(modifier = Modifier.height(24.dp))
            Divider()
            Spacer(modifier = Modifier.height(16.dp))

            // --- Network ---
            SectionHeader("Network")

            Row(
                verticalAlignment = Alignment.CenterVertically,
                modifier = Modifier.fillMaxWidth()
            ) {
                Text(
                    text = "Prefer IPv6",
                    style = MaterialTheme.typography.bodyMedium,
                    modifier = Modifier.weight(1f)
                )
                Switch(
                    checked = draftPreferIPv6,
                    onCheckedChange = { draftPreferIPv6 = it }
                )
            }

            Spacer(modifier = Modifier.height(24.dp))
            Divider()
            Spacer(modifier = Modifier.height(16.dp))

            // --- Room ---
            SectionHeader("Room")

            OutlinedTextField(
                value = draftRoomName,
                onValueChange = { draftRoomName = it },
                label = { Text("Default Room") },
                singleLine = true,
                modifier = Modifier.fillMaxWidth()
            )

            Spacer(modifier = Modifier.height(32.dp))
        }
    }

    if (showAddServerDialog) {
        AddServerDialog(
            onDismiss = { showAddServerDialog = false },
            onAdd = { host, port, label ->
                draftServers.add(ServerEntry("$host:$port", label))
                showAddServerDialog = false
            }
        )
    }

    if (showRestoreKeyDialog) {
        RestoreKeyDialog(
            onDismiss = { showRestoreKeyDialog = false },
            onRestore = { hex ->
                // Staged only — restoreSeed() runs when the user taps Save.
                draftSeedHex = hex
                showRestoreKeyDialog = false
                Toast.makeText(context, "Key staged — press Save to apply", Toast.LENGTH_SHORT).show()
            }
        )
    }
}
|
||||
|
||||
/** Bold, primary-coloured section title followed by an 8dp gap. */
@Composable
private fun SectionHeader(title: String) {
    Text(
        text = title,
        color = MaterialTheme.colorScheme.primary,
        style = MaterialTheme.typography.titleMedium.copy(fontWeight = FontWeight.Bold)
    )
    Spacer(modifier = Modifier.height(8.dp))
}
|
||||
|
||||
/**
 * Labelled gain slider over [-20, +20] dB, full width.
 *
 * Unlike the call-screen variant, this one reports the raw (unrounded)
 * slider position via [onGainChange]; the caller decides any snapping.
 */
@Composable
private fun GainSlider(label: String, gainDb: Float, onGainChange: (Float) -> Unit) {
    Column(
        horizontalAlignment = Alignment.CenterHorizontally,
        modifier = Modifier.fillMaxWidth()
    ) {
        // Explicit "+" for non-negative gains, e.g. "+3 dB".
        val prefix = if (gainDb < 0) "" else "+"
        Text(
            text = "$label: $prefix${"%.0f".format(gainDb)} dB",
            color = MaterialTheme.colorScheme.onSurfaceVariant,
            style = MaterialTheme.typography.labelSmall
        )
        Slider(
            value = gainDb,
            onValueChange = onGainChange,
            valueRange = -20f..20f,
            steps = 0,
            modifier = Modifier.fillMaxWidth()
        )
    }
}
|
||||
|
||||
/**
 * Modal dialog for adding a custom relay server.
 *
 * Collects host, port and an optional label, and invokes [onAdd] with the
 * trimmed values. Fix: the port is now validated as an integer in 1..65535
 * before the entry is accepted — previously any text (e.g. "abc") was
 * passed through and produced a broken "host:abc" address.
 *
 * @param onDismiss invoked when the dialog is cancelled.
 * @param onAdd invoked with (host, port, label); label falls back to host.
 */
@Composable
private fun AddServerDialog(
    onDismiss: () -> Unit,
    onAdd: (host: String, port: String, label: String) -> Unit
) {
    var host by remember { mutableStateOf("") }
    // Pre-filled with the relay's default port.
    var port by remember { mutableStateOf("4433") }
    var label by remember { mutableStateOf("") }

    AlertDialog(
        onDismissRequest = onDismiss,
        title = { Text("Add Server") },
        text = {
            Column {
                OutlinedTextField(
                    value = host,
                    onValueChange = { host = it },
                    label = { Text("Host (IP or domain)") },
                    singleLine = true,
                    modifier = Modifier.fillMaxWidth()
                )
                Spacer(modifier = Modifier.height(8.dp))
                OutlinedTextField(
                    value = port,
                    onValueChange = { port = it },
                    label = { Text("Port") },
                    singleLine = true,
                    modifier = Modifier.fillMaxWidth()
                )
                Spacer(modifier = Modifier.height(8.dp))
                OutlinedTextField(
                    value = label,
                    onValueChange = { label = it },
                    label = { Text("Label (optional)") },
                    singleLine = true,
                    modifier = Modifier.fillMaxWidth()
                )
            }
        },
        confirmButton = {
            TextButton(
                onClick = {
                    // Fix: reject non-numeric or out-of-range ports instead of
                    // silently building an unusable server address.
                    val portNumber = port.trim().toIntOrNull()
                    if (host.isNotBlank() && portNumber != null && portNumber in 1..65535) {
                        val displayLabel = label.trim().ifBlank { host.trim() }
                        onAdd(host.trim(), port.trim(), displayLabel)
                    }
                }
            ) { Text("Add") }
        },
        dismissButton = {
            TextButton(onClick = onDismiss) { Text("Cancel") }
        }
    )
}
|
||||
|
||||
// Dialog for restoring an identity key from a 64-character lowercase hex
// string. Input is normalized (trimmed, lowercased) as the user types;
// on "Restore", remaining whitespace is stripped and the result must be
// exactly 64 hex characters before onRestore is invoked.
@Composable
private fun RestoreKeyDialog(
    onDismiss: () -> Unit,
    onRestore: (hex: String) -> Unit
) {
    var keyInput by remember { mutableStateOf("") }
    var error by remember { mutableStateOf<String?>(null) }

    AlertDialog(
        onDismissRequest = onDismiss,
        title = { Text("Restore Identity Key") },
        text = {
            Column {
                Text(
                    text = "Paste your 64-character hex key below. This will replace your current identity.",
                    style = MaterialTheme.typography.bodySmall,
                    color = MaterialTheme.colorScheme.onSurfaceVariant
                )
                Spacer(modifier = Modifier.height(8.dp))
                OutlinedTextField(
                    value = keyInput,
                    onValueChange = { raw ->
                        keyInput = raw.trim().lowercase()
                        error = null
                    },
                    label = { Text("Identity Key (hex)") },
                    singleLine = true,
                    modifier = Modifier.fillMaxWidth(),
                    isError = error != null
                )
                error?.let { message ->
                    Text(
                        text = message,
                        style = MaterialTheme.typography.bodySmall,
                        color = MaterialTheme.colorScheme.error
                    )
                }
            }
        },
        confirmButton = {
            TextButton(
                onClick = {
                    val cleaned = keyInput.replace("\\s".toRegex(), "")
                    // Full-match: exactly 64 chars, each one of [0-9a-f].
                    if (Regex("[0-9a-f]{64}").matches(cleaned)) {
                        onRestore(cleaned)
                    } else {
                        error = "Key must be exactly 64 hex characters"
                    }
                }
            ) { Text("Restore") }
        },
        dismissButton = {
            TextButton(onClick = onDismiss) { Text("Cancel") }
        }
    )
}
|
||||
4
android/app/src/main/res/xml/file_paths.xml
Normal file
4
android/app/src/main/res/xml/file_paths.xml
Normal file
@@ -0,0 +1,4 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<paths>
|
||||
<cache-path name="debug" path="." />
|
||||
</paths>
|
||||
4
android/build.gradle.kts
Normal file
4
android/build.gradle.kts
Normal file
@@ -0,0 +1,4 @@
|
||||
plugins {
|
||||
id("com.android.application") version "8.2.0" apply false
|
||||
id("org.jetbrains.kotlin.android") version "1.9.22" apply false
|
||||
}
|
||||
4
android/gradle.properties
Normal file
4
android/gradle.properties
Normal file
@@ -0,0 +1,4 @@
|
||||
org.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8
|
||||
android.useAndroidX=true
|
||||
kotlin.code.style=official
|
||||
android.nonTransitiveRClass=true
|
||||
BIN
android/gradle/wrapper/gradle-wrapper.jar
vendored
Normal file
BIN
android/gradle/wrapper/gradle-wrapper.jar
vendored
Normal file
Binary file not shown.
6
android/gradle/wrapper/gradle-wrapper.properties
vendored
Normal file
6
android/gradle/wrapper/gradle-wrapper.properties
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
distributionBase=GRADLE_USER_HOME
|
||||
distributionPath=wrapper/dists
|
||||
distributionUrl=https\://services.gradle.org/distributions/gradle-8.5-bin.zip
|
||||
networkTimeout=10000
|
||||
zipStoreBase=GRADLE_USER_HOME
|
||||
zipStorePath=wrapper/dists
|
||||
5
android/gradlew
vendored
Executable file
5
android/gradlew
vendored
Executable file
@@ -0,0 +1,5 @@
|
||||
#!/bin/sh
|
||||
# Gradle wrapper script
|
||||
APP_HOME=$(cd "$(dirname "$0")" && pwd)
|
||||
CLASSPATH="$APP_HOME/gradle/wrapper/gradle-wrapper.jar"
|
||||
exec java -classpath "$CLASSPATH" org.gradle.wrapper.GradleWrapperMain "$@"
|
||||
18
android/settings.gradle.kts
Normal file
18
android/settings.gradle.kts
Normal file
@@ -0,0 +1,18 @@
|
||||
pluginManagement {
|
||||
repositories {
|
||||
google()
|
||||
mavenCentral()
|
||||
gradlePluginPortal()
|
||||
}
|
||||
}
|
||||
|
||||
dependencyResolutionManagement {
|
||||
repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
|
||||
repositories {
|
||||
google()
|
||||
mavenCentral()
|
||||
}
|
||||
}
|
||||
|
||||
rootProject.name = "WZPhone"
|
||||
include(":app")
|
||||
34
crates/wzp-android/Cargo.toml
Normal file
34
crates/wzp-android/Cargo.toml
Normal file
@@ -0,0 +1,34 @@
|
||||
[package]
|
||||
name = "wzp-android"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
rust-version.workspace = true
|
||||
description = "WarzonePhone Android native VoIP engine — Oboe audio, JNI bridge, call pipeline"
|
||||
|
||||
[lib]
|
||||
crate-type = ["cdylib", "rlib"]
|
||||
|
||||
[dependencies]
|
||||
wzp-proto = { workspace = true }
|
||||
wzp-codec = { workspace = true }
|
||||
wzp-fec = { workspace = true }
|
||||
wzp-crypto = { workspace = true }
|
||||
wzp-transport = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
bytes = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = "1"
|
||||
thiserror = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
anyhow = "1"
|
||||
libc = "0.2"
|
||||
jni = { version = "0.21", default-features = false }
|
||||
rand = { workspace = true }
|
||||
rustls = { version = "0.23", default-features = false, features = ["ring"] }
|
||||
tracing-android = "0.2"
|
||||
|
||||
[build-dependencies]
|
||||
cc = "1"
|
||||
154
crates/wzp-android/build.rs
Normal file
154
crates/wzp-android/build.rs
Normal file
@@ -0,0 +1,154 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
fn main() {
|
||||
let target = std::env::var("TARGET").unwrap_or_default();
|
||||
|
||||
if target.contains("android") {
|
||||
// Override broken static getauxval from compiler-rt that crashes
|
||||
// in shared libraries. Must be compiled first to take link priority.
|
||||
cc::Build::new()
|
||||
.file("cpp/getauxval_fix.c")
|
||||
.compile("getauxval_fix");
|
||||
|
||||
let oboe_dir = fetch_oboe();
|
||||
match oboe_dir {
|
||||
Some(oboe_path) => {
|
||||
println!("cargo:warning=Building with Oboe from {:?}", oboe_path);
|
||||
|
||||
let mut build = cc::Build::new();
|
||||
build
|
||||
.cpp(true)
|
||||
.std("c++17")
|
||||
// Use shared libc++ — avoids pulling in static libc stubs
|
||||
// that crash in shared libraries (getauxval, pthread_create, etc.)
|
||||
.cpp_link_stdlib(Some("c++_shared"))
|
||||
.include("cpp")
|
||||
.include(oboe_path.join("include"))
|
||||
.include(oboe_path.join("src"))
|
||||
.define("WZP_HAS_OBOE", None)
|
||||
.file("cpp/oboe_bridge.cpp");
|
||||
|
||||
// Compile all Oboe source files
|
||||
let src_dir = oboe_path.join("src");
|
||||
add_cpp_files_recursive(&mut build, &src_dir);
|
||||
|
||||
build.compile("oboe_bridge");
|
||||
}
|
||||
None => {
|
||||
println!("cargo:warning=Oboe not found, building with stub");
|
||||
cc::Build::new()
|
||||
.cpp(true)
|
||||
.std("c++17")
|
||||
.cpp_link_stdlib(Some("c++_shared"))
|
||||
.file("cpp/oboe_stub.cpp")
|
||||
.include("cpp")
|
||||
.compile("oboe_bridge");
|
||||
}
|
||||
}
|
||||
|
||||
// Dynamic C++ runtime — libc++_shared.so must be in jniLibs alongside
|
||||
// libwzp_android.so. We copy it there from the NDK sysroot.
|
||||
//
|
||||
// WHY NOT STATIC: libc++_static.a + libc++abi.a transitively pull in
|
||||
// object files from libc.a (static libc) which contain broken stubs for
|
||||
// getauxval, __init_tcb, pthread_create, etc. These stubs only work in
|
||||
// statically-linked executables. In shared libraries loaded by dlopen(),
|
||||
// they SIGSEGV because the static libc init hasn't run.
|
||||
// Google's official recommendation: use libc++_shared.so for native libs.
|
||||
if let Ok(ndk) = std::env::var("ANDROID_NDK_HOME") {
|
||||
let arch = if target.contains("aarch64") {
|
||||
"aarch64-linux-android"
|
||||
} else if target.contains("armv7") {
|
||||
"arm-linux-androideabi"
|
||||
} else if target.contains("x86_64") {
|
||||
"x86_64-linux-android"
|
||||
} else {
|
||||
"aarch64-linux-android"
|
||||
};
|
||||
let lib_dir = format!(
|
||||
"{ndk}/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/lib/{arch}"
|
||||
);
|
||||
println!("cargo:rustc-link-search=native={lib_dir}");
|
||||
|
||||
// Copy libc++_shared.so to the jniLibs directory
|
||||
let shared_so = format!("{lib_dir}/libc++_shared.so");
|
||||
if std::path::Path::new(&shared_so).exists() {
|
||||
let jni_abi = if target.contains("aarch64") {
|
||||
"arm64-v8a"
|
||||
} else if target.contains("armv7") {
|
||||
"armeabi-v7a"
|
||||
} else {
|
||||
"arm64-v8a"
|
||||
};
|
||||
// Try to copy to the Gradle jniLibs directory
|
||||
let manifest = std::env::var("CARGO_MANIFEST_DIR").unwrap_or_default();
|
||||
let jni_dir = format!(
|
||||
"{manifest}/../../android/app/src/main/jniLibs/{jni_abi}"
|
||||
);
|
||||
if let Ok(_) = std::fs::create_dir_all(&jni_dir) {
|
||||
let _ = std::fs::copy(&shared_so, format!("{jni_dir}/libc++_shared.so"));
|
||||
println!("cargo:warning=Copied libc++_shared.so to {jni_dir}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Oboe needs liblog and libOpenSLES from Android
|
||||
println!("cargo:rustc-link-lib=log");
|
||||
println!("cargo:rustc-link-lib=OpenSLES");
|
||||
} else {
|
||||
// Non-Android: always use stub
|
||||
cc::Build::new()
|
||||
.cpp(true)
|
||||
.std("c++17")
|
||||
.file("cpp/oboe_stub.cpp")
|
||||
.include("cpp")
|
||||
.compile("oboe_bridge");
|
||||
}
|
||||
}
|
||||
|
||||
/// Recursively add all .cpp files from a directory to a cc::Build.
|
||||
fn add_cpp_files_recursive(build: &mut cc::Build, dir: &std::path::Path) {
|
||||
if !dir.is_dir() {
|
||||
return;
|
||||
}
|
||||
for entry in std::fs::read_dir(dir).unwrap() {
|
||||
let entry = entry.unwrap();
|
||||
let path = entry.path();
|
||||
if path.is_dir() {
|
||||
add_cpp_files_recursive(build, &path);
|
||||
} else if path.extension().map_or(false, |e| e == "cpp") {
|
||||
build.file(&path);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to find or fetch Oboe headers + source.
|
||||
/// Try to find or fetch the Oboe headers + source tree.
///
/// Returns the checkout directory in `OUT_DIR` when
/// `include/oboe/Oboe.h` is present — either left over from a previous
/// build or produced by a fresh shallow `git clone` of the 1.8.1 tag.
/// Returns `None` when the clone fails (e.g. offline) or the tree is
/// incomplete.
fn fetch_oboe() -> Option<PathBuf> {
    let out_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap());
    let oboe_dir = out_dir.join("oboe");
    let marker = oboe_dir.join("include").join("oboe").join("Oboe.h");

    // Already fetched on a previous build?
    if marker.exists() {
        return Some(oboe_dir);
    }

    let cloned = std::process::Command::new("git")
        .args([
            "clone",
            "--depth=1",
            "--branch=1.8.1",
            "https://github.com/google/oboe.git",
            oboe_dir.to_str().unwrap(),
        ])
        .status()
        .map_or(false, |s| s.success());

    if cloned && marker.exists() {
        Some(oboe_dir)
    } else {
        None
    }
}
|
||||
21
crates/wzp-android/cpp/getauxval_fix.c
Normal file
21
crates/wzp-android/cpp/getauxval_fix.c
Normal file
@@ -0,0 +1,21 @@
|
||||
// Override the broken static getauxval from compiler-rt/CRT.
// The static version reads from __libc_auxv which is NULL in shared libs
// loaded via dlopen, causing SIGSEGV in init_have_lse_atomics at load time.
// This version calls the real bionic getauxval via dlsym.
#ifdef __ANDROID__
#include <dlfcn.h>
#include <stdint.h>

typedef unsigned long (*getauxval_fn)(unsigned long);

// Fix over previous revision: use the RTLD_DEFAULT macro from <dlfcn.h>
// instead of a hard-coded (void*)-1L. Bionic defines RTLD_DEFAULT
// differently for 32-bit (LP32) and 64-bit (LP64) ABIs, so the literal -1
// is wrong on armeabi-v7a builds.
unsigned long getauxval(unsigned long type) {
    // Resolved once and cached; subsequent calls reuse the pointer.
    static getauxval_fn real_getauxval = (getauxval_fn)0;
    if (!real_getauxval) {
        // NOTE(review): dlsym on the default scope is assumed to resolve to
        // bionic's getauxval (libc loads before us), not back to this
        // override — verify symbol binding if recursion is ever observed.
        real_getauxval = (getauxval_fn)dlsym(RTLD_DEFAULT, "getauxval");
        if (!real_getauxval) {
            // libc symbol unavailable: report "value not found" for any type.
            return 0;
        }
    }
    return real_getauxval(type);
}
#endif
|
||||
278
crates/wzp-android/cpp/oboe_bridge.cpp
Normal file
278
crates/wzp-android/cpp/oboe_bridge.cpp
Normal file
@@ -0,0 +1,278 @@
|
||||
// Full Oboe implementation for Android
|
||||
// This file is compiled only when targeting Android
|
||||
|
||||
#include "oboe_bridge.h"
|
||||
|
||||
#ifdef __ANDROID__
|
||||
#include <oboe/Oboe.h>
|
||||
#include <android/log.h>
|
||||
#include <cstring>
|
||||
#include <atomic>
|
||||
|
||||
#define LOG_TAG "wzp-oboe"
|
||||
#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, __VA_ARGS__)
|
||||
#define LOGW(...) __android_log_print(ANDROID_LOG_WARN, LOG_TAG, __VA_ARGS__)
|
||||
#define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__)
|
||||
|
||||
// ---------------------------------------------------------------------------
// Ring buffer helpers (SPSC, lock-free)
//
// The buffers and index atomics are owned by the Rust side; WzpOboeRings
// only carries pointers to them. The producer advances write_idx, the
// consumer advances read_idx, and one slot is always kept empty so a full
// ring is distinguishable from an empty one.
// ---------------------------------------------------------------------------

// Samples currently readable. The acquire load of write_idx pairs with the
// producer's release store in ring_write, so the consumer also observes the
// sample data written before the index was published.
static inline int32_t ring_available_read(const wzp_atomic_int* write_idx,
                                          const wzp_atomic_int* read_idx,
                                          int32_t capacity) {
    int32_t w = std::atomic_load_explicit(write_idx, std::memory_order_acquire);
    int32_t r = std::atomic_load_explicit(read_idx, std::memory_order_relaxed);
    int32_t avail = w - r;
    if (avail < 0) avail += capacity;  // write index has wrapped around
    return avail;
}

// Free slots remaining; capacity - 1 is the usable maximum (one slot is
// sacrificed to disambiguate full from empty).
static inline int32_t ring_available_write(const wzp_atomic_int* write_idx,
                                           const wzp_atomic_int* read_idx,
                                           int32_t capacity) {
    return capacity - 1 - ring_available_read(write_idx, read_idx, capacity);
}

// Producer side: copy `count` samples in, then publish the new write index
// with a release store. Caller must ensure count <= ring_available_write().
// (read_idx is unused here — kept for signature symmetry with ring_read.)
static inline void ring_write(int16_t* buf, int32_t capacity,
                              wzp_atomic_int* write_idx, const wzp_atomic_int* read_idx,
                              const int16_t* src, int32_t count) {
    int32_t w = std::atomic_load_explicit(write_idx, std::memory_order_relaxed);
    for (int32_t i = 0; i < count; i++) {
        buf[w] = src[i];
        w++;
        if (w >= capacity) w = 0;  // wrap
    }
    std::atomic_store_explicit(write_idx, w, std::memory_order_release);
}

// Consumer side: copy `count` samples out, then release the consumed slots
// by publishing the new read index. Caller must ensure
// count <= ring_available_read().
static inline void ring_read(int16_t* buf, int32_t capacity,
                             const wzp_atomic_int* write_idx, wzp_atomic_int* read_idx,
                             int16_t* dst, int32_t count) {
    int32_t r = std::atomic_load_explicit(read_idx, std::memory_order_relaxed);
    for (int32_t i = 0; i < count; i++) {
        dst[i] = buf[r];
        r++;
        if (r >= capacity) r = 0;  // wrap
    }
    std::atomic_store_explicit(read_idx, r, std::memory_order_release);
}
|
||||
|
||||
// ---------------------------------------------------------------------------
// Global state
// ---------------------------------------------------------------------------

// Open Oboe streams; reset to null by wzp_oboe_stop().
static std::shared_ptr<oboe::AudioStream> g_capture_stream;
static std::shared_ptr<oboe::AudioStream> g_playout_stream;
// Ring-buffer descriptor supplied by the caller of wzp_oboe_start();
// only dereferenced while g_running is true.
static const WzpOboeRings* g_rings = nullptr;
// True between a successful start and stop; both audio callbacks check it.
static std::atomic<bool> g_running{false};
// Latest latency estimates (ms), written from the audio callbacks and
// read by the wzp_oboe_*_latency_ms() getters.
static std::atomic<float> g_capture_latency_ms{0.0f};
static std::atomic<float> g_playout_latency_ms{0.0f};
|
||||
|
||||
// ---------------------------------------------------------------------------
// Capture callback
// ---------------------------------------------------------------------------

// Runs on Oboe's high-priority recording thread: moves freshly captured
// PCM into the capture ring for the Rust side to drain. Must stay
// real-time safe — the body takes no locks and allocates nothing.
class CaptureCallback : public oboe::AudioStreamDataCallback {
public:
    oboe::DataCallbackResult onAudioReady(
        oboe::AudioStream* stream,
        void* audioData,
        int32_t numFrames) override {
        // Engine stopped (or rings detached) — tell Oboe to stop the stream.
        if (!g_running.load(std::memory_order_relaxed) || !g_rings) {
            return oboe::DataCallbackResult::Stop;
        }

        // NOTE(review): numFrames is used as a sample count — assumes mono
        // (channel_count == 1, as configured by the Rust caller); confirm
        // before enabling multi-channel capture.
        const int16_t* src = static_cast<const int16_t*>(audioData);
        int32_t avail = ring_available_write(g_rings->capture_write_idx,
                                             g_rings->capture_read_idx,
                                             g_rings->capture_capacity);
        // If the consumer lags, the excess of this burst is dropped.
        int32_t to_write = (numFrames < avail) ? numFrames : avail;
        if (to_write > 0) {
            ring_write(g_rings->capture_buf, g_rings->capture_capacity,
                       g_rings->capture_write_idx, g_rings->capture_read_idx,
                       src, to_write);
        }

        // Update latency estimate
        auto result = stream->calculateLatencyMillis();
        if (result) {
            g_capture_latency_ms.store(static_cast<float>(result.value()),
                                       std::memory_order_relaxed);
        }

        return oboe::DataCallbackResult::Continue;
    }
};
|
||||
|
||||
// ---------------------------------------------------------------------------
// Playout callback
// ---------------------------------------------------------------------------

// Runs on Oboe's high-priority playback thread: drains decoded PCM from
// the playout ring into the device buffer, padding with silence on
// underrun. Real-time safe — no locks, no allocation.
class PlayoutCallback : public oboe::AudioStreamDataCallback {
public:
    oboe::DataCallbackResult onAudioReady(
        oboe::AudioStream* stream,
        void* audioData,
        int32_t numFrames) override {
        if (!g_running.load(std::memory_order_relaxed) || !g_rings) {
            // Emit silence for this last burst, then stop the stream.
            // NOTE(review): byte count assumes mono I16 frames — confirm
            // if channel_count ever becomes > 1.
            memset(audioData, 0, numFrames * sizeof(int16_t));
            return oboe::DataCallbackResult::Stop;
        }

        int16_t* dst = static_cast<int16_t*>(audioData);
        int32_t avail = ring_available_read(g_rings->playout_write_idx,
                                            g_rings->playout_read_idx,
                                            g_rings->playout_capacity);
        int32_t to_read = (numFrames < avail) ? numFrames : avail;

        if (to_read > 0) {
            ring_read(g_rings->playout_buf, g_rings->playout_capacity,
                      g_rings->playout_write_idx, g_rings->playout_read_idx,
                      dst, to_read);
        }
        // Fill remainder with silence on underrun
        if (to_read < numFrames) {
            memset(dst + to_read, 0, (numFrames - to_read) * sizeof(int16_t));
        }

        // Update latency estimate
        auto result = stream->calculateLatencyMillis();
        if (result) {
            g_playout_latency_ms.store(static_cast<float>(result.value()),
                                       std::memory_order_relaxed);
        }

        return oboe::DataCallbackResult::Continue;
    }
};

// Stateless singleton callback instances, reused across start/stop cycles.
static CaptureCallback g_capture_cb;
static PlayoutCallback g_playout_cb;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Public C API
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// Opens and starts the exclusive low-latency I16 capture and playout
// streams described by `config`, wiring both audio callbacks to the
// SPSC rings in `rings`.
//
// OWNERSHIP: `rings` is cached in g_rings and dereferenced from the audio
// threads — the caller must keep the rings (and the buffers/atomics they
// point to) alive until wzp_oboe_stop() has returned.
//
// Returns 0 on success, or a distinct negative code per failure point:
//   -1 already running, -2 capture open failed, -3 playout open failed,
//   -4 capture start failed, -5 playout start failed.
// Every failure path rolls back whatever was opened/started before it.
int wzp_oboe_start(const WzpOboeConfig* config, const WzpOboeRings* rings) {
    if (g_running.load(std::memory_order_relaxed)) {
        LOGW("wzp_oboe_start: already running");
        return -1;
    }

    g_rings = rings;

    // Build capture stream
    oboe::AudioStreamBuilder captureBuilder;
    captureBuilder.setDirection(oboe::Direction::Input)
        ->setPerformanceMode(oboe::PerformanceMode::LowLatency)
        ->setSharingMode(oboe::SharingMode::Exclusive)
        ->setFormat(oboe::AudioFormat::I16)
        ->setChannelCount(config->channel_count)
        ->setSampleRate(config->sample_rate)
        ->setFramesPerDataCallback(config->frames_per_burst)
        // VoiceCommunication: the input preset intended for VoIP capture.
        ->setInputPreset(oboe::InputPreset::VoiceCommunication)
        ->setDataCallback(&g_capture_cb);

    oboe::Result result = captureBuilder.openStream(g_capture_stream);
    if (result != oboe::Result::OK) {
        LOGE("Failed to open capture stream: %s", oboe::convertToText(result));
        return -2;
    }

    // Build playout stream
    oboe::AudioStreamBuilder playoutBuilder;
    playoutBuilder.setDirection(oboe::Direction::Output)
        ->setPerformanceMode(oboe::PerformanceMode::LowLatency)
        ->setSharingMode(oboe::SharingMode::Exclusive)
        ->setFormat(oboe::AudioFormat::I16)
        ->setChannelCount(config->channel_count)
        ->setSampleRate(config->sample_rate)
        ->setFramesPerDataCallback(config->frames_per_burst)
        ->setUsage(oboe::Usage::VoiceCommunication)
        ->setDataCallback(&g_playout_cb);

    result = playoutBuilder.openStream(g_playout_stream);
    if (result != oboe::Result::OK) {
        LOGE("Failed to open playout stream: %s", oboe::convertToText(result));
        // Roll back the already-open capture stream.
        g_capture_stream->close();
        g_capture_stream.reset();
        return -3;
    }

    // Must be set BEFORE requestStart(): the callbacks check g_running and
    // would immediately return Stop otherwise.
    g_running.store(true, std::memory_order_release);

    // Start both streams
    result = g_capture_stream->requestStart();
    if (result != oboe::Result::OK) {
        LOGE("Failed to start capture: %s", oboe::convertToText(result));
        g_running.store(false, std::memory_order_release);
        g_capture_stream->close();
        g_playout_stream->close();
        g_capture_stream.reset();
        g_playout_stream.reset();
        return -4;
    }

    result = g_playout_stream->requestStart();
    if (result != oboe::Result::OK) {
        LOGE("Failed to start playout: %s", oboe::convertToText(result));
        g_running.store(false, std::memory_order_release);
        // Capture already started — stop it before closing.
        g_capture_stream->requestStop();
        g_capture_stream->close();
        g_playout_stream->close();
        g_capture_stream.reset();
        g_playout_stream.reset();
        return -5;
    }

    LOGI("Oboe started: sr=%d burst=%d ch=%d",
         config->sample_rate, config->frames_per_burst, config->channel_count);
    return 0;
}
|
||||
|
||||
// Stops and closes both streams and detaches the ring pointers.
// Safe to call when not running (stream checks are null-guarded).
// g_running is cleared first so in-flight callbacks return Stop instead of
// touching the rings being detached.
void wzp_oboe_stop(void) {
    g_running.store(false, std::memory_order_release);

    if (g_capture_stream) {
        g_capture_stream->requestStop();
        g_capture_stream->close();
        g_capture_stream.reset();
    }
    if (g_playout_stream) {
        g_playout_stream->requestStop();
        g_playout_stream->close();
        g_playout_stream.reset();
    }

    g_rings = nullptr;
    LOGI("Oboe stopped");
}

// Most recent capture-path latency estimate in milliseconds
// (updated from the capture callback; 0.0 before the first estimate).
float wzp_oboe_capture_latency_ms(void) {
    return g_capture_latency_ms.load(std::memory_order_relaxed);
}

// Most recent playout-path latency estimate in milliseconds.
float wzp_oboe_playout_latency_ms(void) {
    return g_playout_latency_ms.load(std::memory_order_relaxed);
}

// 1 while the streams are running (between start and stop), else 0.
int wzp_oboe_is_running(void) {
    return g_running.load(std::memory_order_relaxed) ? 1 : 0;
}
|
||||
|
||||
#else
// Non-Android fallback — should not be reached; oboe_stub.cpp is used instead.
// Provide empty implementations just in case.

// -99 is a sentinel meaning "compiled without Android audio support".
int wzp_oboe_start(const WzpOboeConfig* config, const WzpOboeRings* rings) {
    (void)config; (void)rings;
    return -99;
}

void wzp_oboe_stop(void) {}
float wzp_oboe_capture_latency_ms(void) { return 0.0f; }
float wzp_oboe_playout_latency_ms(void) { return 0.0f; }
int wzp_oboe_is_running(void) { return 0; }

#endif // __ANDROID__
|
||||
43
crates/wzp-android/cpp/oboe_bridge.h
Normal file
43
crates/wzp-android/cpp/oboe_bridge.h
Normal file
@@ -0,0 +1,43 @@
|
||||
#ifndef WZP_OBOE_BRIDGE_H
#define WZP_OBOE_BRIDGE_H

// C interface between the Rust engine (src/audio_android.rs) and the
// Oboe-backed implementation (cpp/oboe_bridge.cpp / cpp/oboe_stub.cpp).
// The Rust #[repr(C)] structs WzpOboeConfig / WzpOboeRings mirror the
// typedefs below — field order and types must stay in sync.

#include <stdint.h>

#ifdef __cplusplus
#include <atomic>
// NOTE(review): assumes std::atomic<int32_t> and C11 atomic_int are
// layout-compatible so both languages can share the same index words —
// true on the targeted toolchains, but worth verifying per platform.
typedef std::atomic<int32_t> wzp_atomic_int;
extern "C" {
#else
#include <stdatomic.h>
typedef atomic_int wzp_atomic_int;
#endif

// Stream parameters requested from Oboe.
typedef struct {
    int32_t sample_rate;
    int32_t frames_per_burst;
    int32_t channel_count;
} WzpOboeConfig;

// Shared SPSC ring buffers. All pointers are owned by the caller (the Rust
// side) and must remain valid from wzp_oboe_start() until wzp_oboe_stop()
// has returned.
typedef struct {
    // Mic ring: written by the C++ capture callback, drained by Rust.
    int16_t* capture_buf;
    int32_t capture_capacity;
    wzp_atomic_int* capture_write_idx;
    wzp_atomic_int* capture_read_idx;

    // Speaker ring: filled by Rust, drained by the C++ playout callback.
    int16_t* playout_buf;
    int32_t playout_capacity;
    wzp_atomic_int* playout_write_idx;
    wzp_atomic_int* playout_read_idx;
} WzpOboeRings;

// Opens and starts capture + playout. Returns 0 on success, a negative
// code on failure (see oboe_bridge.cpp for the per-step code meanings).
int wzp_oboe_start(const WzpOboeConfig* config, const WzpOboeRings* rings);
// Stops and closes both streams; safe to call when not running.
void wzp_oboe_stop(void);
// Most recent latency estimates in milliseconds (0.0 before first update).
float wzp_oboe_capture_latency_ms(void);
float wzp_oboe_playout_latency_ms(void);
// 1 while streams are running, else 0.
int wzp_oboe_is_running(void);

#ifdef __cplusplus
}
#endif

#endif // WZP_OBOE_BRIDGE_H
|
||||
27
crates/wzp-android/cpp/oboe_stub.cpp
Normal file
27
crates/wzp-android/cpp/oboe_stub.cpp
Normal file
@@ -0,0 +1,27 @@
|
||||
// Stub implementation for non-Android host builds (testing, cargo check, etc.)
//
// Linked in place of oboe_bridge.cpp by build.rs when the target is not
// Android (or when the Oboe sources could not be fetched). Every entry
// point logs to stderr and reports "no audio running".

#include "oboe_bridge.h"
#include <stdio.h>

// Pretends to start successfully so host-side code paths can proceed.
int wzp_oboe_start(const WzpOboeConfig* config, const WzpOboeRings* rings) {
    (void)config;
    (void)rings;
    fprintf(stderr, "wzp_oboe_start: stub (not on Android)\n");
    return 0;
}

void wzp_oboe_stop(void) {
    fprintf(stderr, "wzp_oboe_stop: stub (not on Android)\n");
}

// No real streams on the host — latency is reported as zero.
float wzp_oboe_capture_latency_ms(void) {
    return 0.0f;
}

float wzp_oboe_playout_latency_ms(void) {
    return 0.0f;
}

int wzp_oboe_is_running(void) {
    return 0;
}
|
||||
424
crates/wzp-android/src/audio_android.rs
Normal file
424
crates/wzp-android/src/audio_android.rs
Normal file
@@ -0,0 +1,424 @@
|
||||
//! Lock-free SPSC ring buffer audio backend for Android (Oboe).
|
||||
//!
|
||||
//! The ring buffers are shared between Rust and C++: the Oboe callbacks
|
||||
//! (running on a high-priority audio thread) read/write directly into
|
||||
//! the buffers via atomic indices, while the Rust codec thread on the
|
||||
//! other side does the same.
|
||||
|
||||
use std::sync::atomic::{AtomicI32, Ordering};
|
||||
|
||||
use tracing::info;
|
||||
#[allow(unused_imports)]
|
||||
use tracing::warn;
|
||||
|
||||
/// Number of samples per 20 ms frame at 48 kHz mono.
|
||||
pub const FRAME_SAMPLES: usize = 960;
|
||||
|
||||
/// Default ring buffer capacity: 8 frames = 160 ms at 48 kHz.
|
||||
const RING_CAPACITY: usize = 7680;
|
||||
|
||||
// ---------------------------------------------------------------------------
// FFI declarations matching oboe_bridge.h
// ---------------------------------------------------------------------------

// #[repr(C)] mirror of `WzpOboeConfig` in cpp/oboe_bridge.h — field order,
// types, and sizes must match the C side exactly.
#[repr(C)]
#[allow(non_snake_case)]
struct WzpOboeConfig {
    sample_rate: i32,
    frames_per_burst: i32,
    channel_count: i32,
}

// #[repr(C)] mirror of `WzpOboeRings`: raw pointers into the two
// Rust-owned SPSC ring buffers, handed across the FFI boundary to the
// C++ Oboe callbacks.
#[repr(C)]
#[allow(non_snake_case)]
struct WzpOboeRings {
    capture_buf: *mut i16,
    capture_capacity: i32,
    capture_write_idx: *mut AtomicI32,
    capture_read_idx: *mut AtomicI32,

    playout_buf: *mut i16,
    playout_capacity: i32,
    playout_write_idx: *mut AtomicI32,
    playout_read_idx: *mut AtomicI32,
}

// SAFETY: the struct only carries pointers; the pointed-to data is
// synchronized via the atomic indices (see RingBuffer in this file).
unsafe impl Send for WzpOboeRings {}
unsafe impl Sync for WzpOboeRings {}

// `unsafe extern` block (Rust 2024 style). Implementations live in
// cpp/oboe_bridge.cpp (real) or cpp/oboe_stub.cpp (non-Android hosts).
unsafe extern "C" {
    fn wzp_oboe_start(config: *const WzpOboeConfig, rings: *const WzpOboeRings) -> i32;
    fn wzp_oboe_stop();
    fn wzp_oboe_capture_latency_ms() -> f32;
    fn wzp_oboe_playout_latency_ms() -> f32;
    fn wzp_oboe_is_running() -> i32;
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// SPSC Ring Buffer
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Single-producer single-consumer lock-free ring buffer.
///
/// The producer calls `write()` and the consumer calls `read()`.
/// Atomics use acquire/release ordering to ensure correct visibility
/// across the Oboe audio thread and the Rust codec thread.
///
/// Soundness fix: the sample storage now lives in an `UnsafeCell` because
/// both `write()` and the C++ Oboe callback mutate it through a shared
/// reference. The previous layout (`buf: Vec<i16>`) wrote through a `*mut`
/// derived from `&Vec<i16>`, which is undefined behavior under Rust's
/// aliasing rules. Public method signatures are unchanged.
pub struct RingBuffer {
    // Sample storage; mutation through &self is legal only via UnsafeCell.
    buf: UnsafeCell<Box<[i16]>>,
    capacity: usize,
    write_idx: AtomicI32,
    read_idx: AtomicI32,
}

impl RingBuffer {
    /// Create a new ring buffer with the given capacity (in samples).
    ///
    /// The actual usable capacity is `capacity - 1` to distinguish
    /// full from empty.
    pub fn new(capacity: usize) -> Self {
        Self {
            buf: UnsafeCell::new(vec![0i16; capacity].into_boxed_slice()),
            capacity,
            write_idx: AtomicI32::new(0),
            read_idx: AtomicI32::new(0),
        }
    }

    /// Number of samples available to read.
    pub fn available_read(&self) -> usize {
        // Acquire pairs with the producer's release store in `write()`
        // so the data written before the index is visible here.
        let w = self.write_idx.load(Ordering::Acquire);
        let r = self.read_idx.load(Ordering::Relaxed);
        let avail = w - r;
        if avail < 0 {
            // Write index has wrapped past the end of the buffer.
            (avail + self.capacity as i32) as usize
        } else {
            avail as usize
        }
    }

    /// Number of samples that can be written before the buffer is full.
    pub fn available_write(&self) -> usize {
        self.capacity - 1 - self.available_read()
    }

    /// Write samples into the ring buffer (producer side).
    ///
    /// Returns the number of samples actually written (may be less than
    /// `data.len()` if the buffer is nearly full).
    pub fn write(&self, data: &[i16]) -> usize {
        let avail = self.available_write();
        let count = data.len().min(avail);
        if count == 0 {
            return 0;
        }

        let mut w = self.write_idx.load(Ordering::Relaxed) as usize;
        let cap = self.capacity;
        let buf_ptr = self.buf_ptr();

        for &sample in &data[..count] {
            // SAFETY: w is always in [0, capacity), we are the sole
            // producer, and the pointer comes from the UnsafeCell.
            unsafe {
                *buf_ptr.add(w) = sample;
            }
            w += 1;
            if w >= cap {
                w = 0;
            }
        }

        // Release publishes the samples before the new index.
        self.write_idx.store(w as i32, Ordering::Release);
        count
    }

    /// Read samples from the ring buffer (consumer side).
    ///
    /// Returns the number of samples actually read (may be less than
    /// `out.len()` if the buffer doesn't have enough data).
    pub fn read(&self, out: &mut [i16]) -> usize {
        let avail = self.available_read();
        let count = out.len().min(avail);
        if count == 0 {
            return 0;
        }

        let mut r = self.read_idx.load(Ordering::Relaxed) as usize;
        let cap = self.capacity;
        let buf_ptr = self.buf_ptr() as *const i16;

        for slot in &mut out[..count] {
            // SAFETY: r is always in [0, capacity) and we are the sole consumer.
            unsafe {
                *slot = *buf_ptr.add(r);
            }
            r += 1;
            if r >= cap {
                r = 0;
            }
        }

        // Release frees the consumed slots for the producer.
        self.read_idx.store(r as i32, Ordering::Release);
        count
    }

    /// Get a raw pointer to the buffer data (for FFI).
    fn buf_ptr(&self) -> *mut i16 {
        // SAFETY: the pointer is derived from the UnsafeCell, so writing
        // through it is permitted; synchronization is via the index atomics.
        unsafe { (*self.buf.get()).as_mut_ptr() }
    }

    /// Get a raw pointer to the write index atomic (for FFI).
    fn write_idx_ptr(&self) -> *mut AtomicI32 {
        &self.write_idx as *const AtomicI32 as *mut AtomicI32
    }

    /// Get a raw pointer to the read index atomic (for FFI).
    fn read_idx_ptr(&self) -> *mut AtomicI32 {
        &self.read_idx as *const AtomicI32 as *mut AtomicI32
    }
}

// SAFETY: The ring buffer is designed for SPSC use where producer and consumer
// are on different threads. The atomic indices provide the synchronization;
// UnsafeCell makes the shared-reference mutation legal.
unsafe impl Send for RingBuffer {}
unsafe impl Sync for RingBuffer {}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Oboe Backend
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Oboe-based audio backend for Android.
|
||||
///
|
||||
/// Owns two SPSC ring buffers (capture and playout) that are shared with
|
||||
/// the C++ Oboe callbacks via raw pointers. The Oboe callbacks run on
|
||||
/// high-priority audio threads managed by the Android audio system.
|
||||
pub struct OboeBackend {
|
||||
capture_ring: RingBuffer,
|
||||
playout_ring: RingBuffer,
|
||||
started: bool,
|
||||
}
|
||||
|
||||
impl OboeBackend {
|
||||
/// Create a new backend with default ring buffer sizes (160 ms each).
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
capture_ring: RingBuffer::new(RING_CAPACITY),
|
||||
playout_ring: RingBuffer::new(RING_CAPACITY),
|
||||
started: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Start Oboe audio streams.
|
||||
///
|
||||
/// This sets up the ring buffer pointers and calls into the C++ layer
|
||||
/// to open and start the capture and playout Oboe streams.
|
||||
pub fn start(&mut self) -> Result<(), anyhow::Error> {
|
||||
if self.started {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let config = WzpOboeConfig {
|
||||
sample_rate: 48_000,
|
||||
frames_per_burst: FRAME_SAMPLES as i32,
|
||||
channel_count: 1,
|
||||
};
|
||||
|
||||
let rings = WzpOboeRings {
|
||||
capture_buf: self.capture_ring.buf_ptr(),
|
||||
capture_capacity: self.capture_ring.capacity as i32,
|
||||
capture_write_idx: self.capture_ring.write_idx_ptr(),
|
||||
capture_read_idx: self.capture_ring.read_idx_ptr(),
|
||||
|
||||
playout_buf: self.playout_ring.buf_ptr(),
|
||||
playout_capacity: self.playout_ring.capacity as i32,
|
||||
playout_write_idx: self.playout_ring.write_idx_ptr(),
|
||||
playout_read_idx: self.playout_ring.read_idx_ptr(),
|
||||
};
|
||||
|
||||
let ret = unsafe { wzp_oboe_start(&config, &rings) };
|
||||
if ret != 0 {
|
||||
return Err(anyhow::anyhow!("wzp_oboe_start failed with code {}", ret));
|
||||
}
|
||||
|
||||
self.started = true;
|
||||
info!("Oboe backend started");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Stop Oboe audio streams.
|
||||
pub fn stop(&mut self) {
|
||||
if !self.started {
|
||||
return;
|
||||
}
|
||||
unsafe { wzp_oboe_stop() };
|
||||
self.started = false;
|
||||
info!("Oboe backend stopped");
|
||||
}
|
||||
|
||||
/// Read captured audio samples from the capture ring buffer.
|
||||
///
|
||||
/// Returns the number of samples actually read. The caller should
|
||||
/// provide a buffer of at least `FRAME_SAMPLES` (960) samples.
|
||||
pub fn read_capture(&self, out: &mut [i16]) -> usize {
|
||||
self.capture_ring.read(out)
|
||||
}
|
||||
|
||||
/// Write audio samples to the playout ring buffer.
|
||||
///
|
||||
/// Returns the number of samples actually written.
|
||||
pub fn write_playout(&self, samples: &[i16]) -> usize {
|
||||
self.playout_ring.write(samples)
|
||||
}
|
||||
|
||||
/// Get the current capture latency in milliseconds (from Oboe).
|
||||
#[allow(unused)]
|
||||
pub fn capture_latency_ms(&self) -> f32 {
|
||||
unsafe { wzp_oboe_capture_latency_ms() }
|
||||
}
|
||||
|
||||
/// Get the current playout latency in milliseconds (from Oboe).
|
||||
#[allow(unused)]
|
||||
pub fn playout_latency_ms(&self) -> f32 {
|
||||
unsafe { wzp_oboe_playout_latency_ms() }
|
||||
}
|
||||
|
||||
/// Check if the Oboe streams are currently running.
|
||||
#[allow(unused)]
|
||||
pub fn is_running(&self) -> bool {
|
||||
unsafe { wzp_oboe_is_running() != 0 }
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for OboeBackend {
|
||||
fn drop(&mut self) {
|
||||
self.stop();
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Thread affinity / priority helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Pin the current thread to the highest-numbered CPU cores (big cores on
/// ARM big.LITTLE architectures). Falls back silently on failure.
#[allow(unused)]
pub fn pin_to_big_core() {
    #[cfg(target_os = "android")]
    {
        unsafe {
            let online = libc::sysconf(libc::_SC_NPROCESSORS_ONLN);
            if online <= 0 {
                warn!("pin_to_big_core: could not determine CPU count");
                return;
            }
            let num_cpus = online as usize;

            // On most big.LITTLE SoCs the big cores occupy the upper half of
            // the CPU index range, so build an affinity mask for that half.
            let start = num_cpus / 2;
            let mut mask: libc::cpu_set_t = std::mem::zeroed();
            libc::CPU_ZERO(&mut mask);
            for cpu in start..num_cpus {
                libc::CPU_SET(cpu, &mut mask);
            }

            // pid 0 means "the calling thread" for sched_setaffinity.
            let rc = libc::sched_setaffinity(
                0,
                std::mem::size_of::<libc::cpu_set_t>(),
                &mask,
            );
            if rc != 0 {
                warn!("sched_setaffinity failed: {}", std::io::Error::last_os_error());
            } else {
                info!(start, num_cpus, "pinned to big cores");
            }
        }
    }
    #[cfg(not(target_os = "android"))]
    {
        // No-op on non-Android
    }
}
|
||||
|
||||
/// Attempt to set SCHED_FIFO real-time priority for the current thread.
/// Falls back silently on failure (requires appropriate permissions on Android).
#[allow(unused)]
pub fn set_realtime_priority() {
    #[cfg(target_os = "android")]
    {
        unsafe {
            // Priority 2 is a deliberately low RT level — enough to beat
            // normal threads for audio work without starving the system.
            let param = libc::sched_param { sched_priority: 2 };
            // pid 0 means "the calling thread" for sched_setscheduler.
            let rc = libc::sched_setscheduler(0, libc::SCHED_FIFO, &param);
            if rc != 0 {
                warn!(
                    "sched_setscheduler(SCHED_FIFO) failed: {}",
                    std::io::Error::last_os_error()
                );
            } else {
                info!("set SCHED_FIFO priority 2");
            }
        }
    }
    #[cfg(not(target_os = "android"))]
    {
        // No-op on non-Android
    }
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn ring_buffer_write_read() {
        let ring = RingBuffer::new(16);
        let input = [1i16, 2, 3, 4, 5];
        assert_eq!(ring.write(&input), 5);
        assert_eq!(ring.available_read(), 5);

        let mut out = [0i16; 5];
        assert_eq!(ring.read(&mut out), 5);
        assert_eq!(out, [1, 2, 3, 4, 5]);
        assert_eq!(ring.available_read(), 0);
    }

    #[test]
    fn ring_buffer_wraparound() {
        // Capacity 8 leaves 7 usable slots, so 6 samples fit directly.
        let ring = RingBuffer::new(8);
        assert_eq!(ring.write(&[10i16, 20, 30, 40, 50, 60]), 6);

        let mut head = [0i16; 4];
        assert_eq!(ring.read(&mut head), 4);
        assert_eq!(head, [10, 20, 30, 40]);

        // This second write must wrap past the end of the backing array.
        assert_eq!(ring.write(&[70i16, 80, 90, 100]), 4);

        let mut tail = [0i16; 6];
        assert_eq!(ring.read(&mut tail), 6);
        assert_eq!(tail, [50, 60, 70, 80, 90, 100]);
    }

    #[test]
    fn ring_buffer_full() {
        // One slot is sacrificed to distinguish full from empty.
        let ring = RingBuffer::new(4); // usable capacity = 3
        assert_eq!(ring.write(&[1i16, 2, 3, 4, 5]), 3); // Only 3 fit
        assert_eq!(ring.available_write(), 0);
    }

    #[test]
    fn oboe_backend_stub_start_stop() {
        let mut backend = OboeBackend::new();
        backend.start().expect("stub start should succeed");
        assert!(backend.started);
        backend.stop();
        assert!(!backend.started);
    }
}
|
||||
91
crates/wzp-android/src/audio_ring.rs
Normal file
91
crates/wzp-android/src/audio_ring.rs
Normal file
@@ -0,0 +1,91 @@
|
||||
//! Lock-free SPSC ring buffers for audio PCM transfer between
|
||||
//! Kotlin AudioRecord/AudioTrack threads and the Rust engine.
|
||||
//!
|
||||
//! These use a simple spin-free design: the producer writes and advances
|
||||
//! a write cursor, the consumer reads and advances a read cursor.
|
||||
//! Both cursors are atomic so no mutex is needed.
|
||||
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
|
||||
/// Ring buffer capacity in i16 samples.
/// 960 samples * 10 frames = ~200ms of audio at 48kHz mono.
const RING_CAPACITY: usize = 960 * 10;

/// Lock-free single-producer single-consumer ring buffer for i16 PCM samples.
///
/// `write_pos` and `read_pos` are free-running (monotonically increasing,
/// wrapping) sample counters; the actual array index is `pos % RING_CAPACITY`.
pub struct AudioRing {
    buf: Box<[i16; RING_CAPACITY]>,
    write_pos: AtomicUsize,
    read_pos: AtomicUsize,
}

// SAFETY: AudioRing is designed for SPSC — one thread writes, one reads.
// The atomics ensure visibility. The buffer itself is never accessed
// from the same index by both threads simultaneously because the
// producer only writes to positions between write_pos and read_pos,
// and the consumer only reads from positions between read_pos and write_pos.
// NOTE(review): the overrun path in `write` also stores read_pos from the
// producer side, which briefly makes read_pos dual-writer — confirm the
// consumer tolerates this before relying on heavy-overrun behavior.
unsafe impl Send for AudioRing {}
unsafe impl Sync for AudioRing {}

impl AudioRing {
    /// Create an empty ring with both cursors at zero.
    pub fn new() -> Self {
        Self {
            buf: Box::new([0i16; RING_CAPACITY]),
            write_pos: AtomicUsize::new(0),
            read_pos: AtomicUsize::new(0),
        }
    }

    /// Number of samples available to read.
    pub fn available(&self) -> usize {
        let w = self.write_pos.load(Ordering::Acquire);
        let r = self.read_pos.load(Ordering::Acquire);
        w.wrapping_sub(r)
    }

    /// Number of samples that can be written without overwriting.
    pub fn free_space(&self) -> usize {
        RING_CAPACITY - self.available()
    }

    /// Write samples into the ring. Returns the number of samples written.
    /// Drops oldest samples if the ring is full.
    ///
    /// Fix: when `samples` is longer than the whole ring, keep only the
    /// *newest* `RING_CAPACITY` samples. The previous code kept the first
    /// `RING_CAPACITY` samples and silently discarded the newest audio —
    /// the opposite of the documented drop-oldest policy.
    pub fn write(&self, samples: &[i16]) -> usize {
        let w = self.write_pos.load(Ordering::Relaxed);

        // Keep only the tail that can actually survive in the ring; anything
        // earlier would be overwritten within this same call anyway.
        let src = if samples.len() > RING_CAPACITY {
            &samples[samples.len() - RING_CAPACITY..]
        } else {
            samples
        };
        let count = src.len();

        for (i, &sample) in src.iter().enumerate() {
            // wrapping_add: positions are free-running counters and may
            // legitimately wrap around usize::MAX.
            let idx = w.wrapping_add(i) % RING_CAPACITY;
            // SAFETY: We're the only writer, and the reader won't read
            // past read_pos which we haven't advanced past yet.
            unsafe {
                let ptr = self.buf.as_ptr() as *mut i16;
                *ptr.add(idx) = sample;
            }
        }

        self.write_pos.store(w.wrapping_add(count), Ordering::Release);

        // If we overwrote unread data, advance read_pos so the reader
        // resumes at the oldest still-valid sample.
        if self.available() > RING_CAPACITY {
            let new_read = self.write_pos.load(Ordering::Relaxed).wrapping_sub(RING_CAPACITY);
            self.read_pos.store(new_read, Ordering::Release);
        }

        count
    }

    /// Read samples from the ring into `out`. Returns number of samples read.
    pub fn read(&self, out: &mut [i16]) -> usize {
        let avail = self.available();
        let count = out.len().min(avail);

        let r = self.read_pos.load(Ordering::Relaxed);
        for i in 0..count {
            let idx = r.wrapping_add(i) % RING_CAPACITY;
            // SAFETY: indices in [read_pos, write_pos) were fully written
            // before write_pos was published with Release ordering.
            out[i] = unsafe { *self.buf.as_ptr().add(idx) };
        }

        self.read_pos.store(r.wrapping_add(count), Ordering::Release);
        count
    }
}
|
||||
15
crates/wzp-android/src/commands.rs
Normal file
15
crates/wzp-android/src/commands.rs
Normal file
@@ -0,0 +1,15 @@
|
||||
//! Engine commands sent from the JNI/UI thread to the engine.
|
||||
|
||||
use wzp_proto::QualityProfile;
|
||||
|
||||
/// Commands that can be sent to the running engine.
|
||||
pub enum EngineCommand {
|
||||
/// Mute or unmute the microphone.
|
||||
SetMute(bool),
|
||||
/// Enable or disable speaker (loudspeaker) mode.
|
||||
SetSpeaker(bool),
|
||||
/// Force a specific quality profile (overrides adaptive logic).
|
||||
ForceProfile(QualityProfile),
|
||||
/// Stop the call and shut down the engine.
|
||||
Stop,
|
||||
}
|
||||
686
crates/wzp-android/src/engine.rs
Normal file
686
crates/wzp-android/src/engine.rs
Normal file
@@ -0,0 +1,686 @@
|
||||
//! Engine orchestrator — manages the call lifecycle.
|
||||
//!
|
||||
//! IMPORTANT: On Android, pthread_create crashes in shared libraries due to
|
||||
//! static bionic stubs in the Rust std prebuilt rlibs. ALL work must happen
|
||||
//! on the JNI calling thread or via the tokio current_thread runtime.
|
||||
//! No std::thread::spawn or tokio multi_thread allowed.
|
||||
//!
|
||||
//! Audio capture and playout happen on Kotlin JVM threads via AudioRecord
|
||||
//! and AudioTrack. PCM samples are transferred through lock-free ring buffers.
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::atomic::{AtomicBool, AtomicU16, AtomicU32, Ordering};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::Instant;
|
||||
|
||||
use bytes::Bytes;
|
||||
use tracing::{error, info, warn};
|
||||
use wzp_codec::agc::AutoGainControl;
|
||||
use wzp_codec::opus_dec::OpusDecoder;
|
||||
use wzp_codec::opus_enc::OpusEncoder;
|
||||
use wzp_crypto::{KeyExchange, WarzoneKeyExchange};
|
||||
use wzp_fec::{RaptorQFecDecoder, RaptorQFecEncoder};
|
||||
use wzp_proto::{
|
||||
AudioDecoder, AudioEncoder, CodecId, FecDecoder, FecEncoder,
|
||||
MediaHeader, MediaPacket, MediaTransport, QualityProfile, SignalMessage,
|
||||
};
|
||||
|
||||
use crate::audio_ring::AudioRing;
|
||||
use crate::commands::EngineCommand;
|
||||
use crate::stats::{CallState, CallStats};
|
||||
|
||||
/// Opus frame size at 48kHz mono: one 20ms frame = 48 samples/ms * 20 = 960 samples.
const FRAME_SAMPLES: usize = 960;
|
||||
|
||||
/// Configuration to start a call.
|
||||
pub struct CallStartConfig {
|
||||
pub profile: QualityProfile,
|
||||
pub relay_addr: String,
|
||||
pub room: String,
|
||||
pub auth_token: Vec<u8>,
|
||||
pub identity_seed: [u8; 32],
|
||||
pub alias: Option<String>,
|
||||
}
|
||||
|
||||
impl Default for CallStartConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
profile: QualityProfile::GOOD,
|
||||
relay_addr: String::new(),
|
||||
room: String::new(),
|
||||
auth_token: Vec::new(),
|
||||
identity_seed: [0u8; 32],
|
||||
alias: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct EngineState {
|
||||
pub running: AtomicBool,
|
||||
pub muted: AtomicBool,
|
||||
pub stats: Mutex<CallStats>,
|
||||
pub command_tx: std::sync::mpsc::Sender<EngineCommand>,
|
||||
pub command_rx: Mutex<Option<std::sync::mpsc::Receiver<EngineCommand>>>,
|
||||
/// Ring buffer: Kotlin AudioRecord → Rust encoder
|
||||
pub capture_ring: AudioRing,
|
||||
/// Ring buffer: Rust decoder → Kotlin AudioTrack
|
||||
pub playout_ring: AudioRing,
|
||||
/// Current audio level (RMS) for UI display, updated by capture path.
|
||||
pub audio_level_rms: AtomicU32,
|
||||
/// QUIC transport handle — stored so stop_call() can close it immediately,
|
||||
/// triggering relay-side leave + RoomUpdate broadcast.
|
||||
pub quic_transport: Mutex<Option<Arc<wzp_transport::QuinnTransport>>>,
|
||||
}
|
||||
|
||||
pub struct WzpEngine {
|
||||
pub(crate) state: Arc<EngineState>,
|
||||
tokio_runtime: Option<tokio::runtime::Runtime>,
|
||||
call_start: Option<Instant>,
|
||||
}
|
||||
|
||||
impl WzpEngine {
|
||||
pub fn new() -> Self {
|
||||
let (tx, rx) = std::sync::mpsc::channel();
|
||||
let state = Arc::new(EngineState {
|
||||
running: AtomicBool::new(false),
|
||||
muted: AtomicBool::new(false),
|
||||
stats: Mutex::new(CallStats::default()),
|
||||
command_tx: tx,
|
||||
command_rx: Mutex::new(Some(rx)),
|
||||
capture_ring: AudioRing::new(),
|
||||
playout_ring: AudioRing::new(),
|
||||
audio_level_rms: AtomicU32::new(0),
|
||||
quic_transport: Mutex::new(None),
|
||||
});
|
||||
Self {
|
||||
state,
|
||||
tokio_runtime: None,
|
||||
call_start: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn start_call(&mut self, config: CallStartConfig) -> Result<(), anyhow::Error> {
|
||||
if self.state.running.load(Ordering::Acquire) {
|
||||
return Err(anyhow::anyhow!("call already active"));
|
||||
}
|
||||
|
||||
{
|
||||
let mut stats = self.state.stats.lock().unwrap();
|
||||
*stats = CallStats {
|
||||
state: CallState::Connecting,
|
||||
..Default::default()
|
||||
};
|
||||
}
|
||||
|
||||
let runtime = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()?;
|
||||
|
||||
let relay_addr: SocketAddr = config.relay_addr.parse().map_err(|e| {
|
||||
anyhow::anyhow!("invalid relay address '{}': {e}", config.relay_addr)
|
||||
})?;
|
||||
|
||||
let room = config.room.clone();
|
||||
let identity_seed = config.identity_seed;
|
||||
let profile = config.profile;
|
||||
let alias = config.alias.clone();
|
||||
let state = self.state.clone();
|
||||
|
||||
self.state.running.store(true, Ordering::Release);
|
||||
self.call_start = Some(Instant::now());
|
||||
|
||||
let state_clone = state.clone();
|
||||
runtime.block_on(async move {
|
||||
if let Err(e) = run_call(relay_addr, &room, &identity_seed, profile, alias.as_deref(), state_clone).await
|
||||
{
|
||||
error!("call failed: {e}");
|
||||
}
|
||||
});
|
||||
|
||||
state.running.store(false, Ordering::Release);
|
||||
{
|
||||
let mut stats = state.stats.lock().unwrap();
|
||||
stats.state = CallState::Closed;
|
||||
}
|
||||
|
||||
self.tokio_runtime = Some(runtime);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn stop_call(&mut self) {
|
||||
info!("stop_call: setting running=false");
|
||||
self.state.running.store(false, Ordering::Release);
|
||||
// Close QUIC connection — this wakes up all blocked recv/send futures
|
||||
// inside block_on(run_call(...)) on the JNI thread. run_call will then
|
||||
// wait up to 500ms for the peer to acknowledge the close before returning.
|
||||
if let Some(transport) = self.state.quic_transport.lock().unwrap().take() {
|
||||
info!("stop_call: closing QUIC connection");
|
||||
transport.close_now();
|
||||
}
|
||||
let _ = self.state.command_tx.send(EngineCommand::Stop);
|
||||
// Note: the runtime is still blocked in block_on(run_call(...)) on the
|
||||
// start_call thread. Once run_call exits (triggered by running=false +
|
||||
// connection close above), block_on returns and stores the runtime in
|
||||
// self.tokio_runtime. We don't need to shut it down here.
|
||||
if let Some(rt) = self.tokio_runtime.take() {
|
||||
rt.shutdown_timeout(std::time::Duration::from_millis(100));
|
||||
}
|
||||
self.call_start = None;
|
||||
info!("stop_call: done");
|
||||
}
|
||||
|
||||
pub fn set_mute(&self, muted: bool) {
|
||||
self.state.muted.store(muted, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn set_speaker(&self, _enabled: bool) {}
|
||||
|
||||
pub fn force_profile(&self, _profile: QualityProfile) {}
|
||||
|
||||
pub fn get_stats(&self) -> CallStats {
|
||||
let mut stats = self.state.stats.lock().unwrap().clone();
|
||||
if let Some(start) = self.call_start {
|
||||
stats.duration_secs = start.elapsed().as_secs_f64();
|
||||
}
|
||||
stats.audio_level = self.state.audio_level_rms.load(Ordering::Relaxed);
|
||||
stats
|
||||
}
|
||||
|
||||
pub fn is_active(&self) -> bool {
|
||||
self.state.running.load(Ordering::Acquire)
|
||||
}
|
||||
|
||||
pub fn write_audio(&self, samples: &[i16]) -> usize {
|
||||
if self.state.muted.load(Ordering::Relaxed) {
|
||||
return samples.len();
|
||||
}
|
||||
// Compute RMS for audio level display
|
||||
if !samples.is_empty() {
|
||||
let sum_sq: f64 = samples.iter().map(|&s| (s as f64) * (s as f64)).sum();
|
||||
let rms = (sum_sq / samples.len() as f64).sqrt() as u32;
|
||||
self.state.audio_level_rms.store(rms, Ordering::Relaxed);
|
||||
}
|
||||
self.state.capture_ring.write(samples)
|
||||
}
|
||||
|
||||
pub fn read_audio(&self, out: &mut [i16]) -> usize {
|
||||
self.state.playout_ring.read(out)
|
||||
}
|
||||
|
||||
pub fn destroy(mut self) {
|
||||
self.stop_call();
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for WzpEngine {
|
||||
fn drop(&mut self) {
|
||||
self.stop_call();
|
||||
}
|
||||
}
|
||||
|
||||
/// Run the full call lifecycle: connect, handshake, send/recv media with Opus + FEC.
|
||||
async fn run_call(
|
||||
relay_addr: SocketAddr,
|
||||
room: &str,
|
||||
identity_seed: &[u8; 32],
|
||||
profile: QualityProfile,
|
||||
alias: Option<&str>,
|
||||
state: Arc<EngineState>,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||
|
||||
let bind_addr: SocketAddr = "0.0.0.0:0".parse().unwrap();
|
||||
let endpoint = wzp_transport::create_endpoint(bind_addr, None)?;
|
||||
|
||||
let sni = if room.is_empty() { "android" } else { room };
|
||||
info!(%relay_addr, sni, "connecting to relay...");
|
||||
let client_cfg = wzp_transport::client_config();
|
||||
let conn = wzp_transport::connect(&endpoint, relay_addr, sni, client_cfg).await?;
|
||||
info!("QUIC connected to relay");
|
||||
|
||||
let transport = Arc::new(wzp_transport::QuinnTransport::new(conn));
|
||||
|
||||
// Store transport handle so stop_call() can close the connection immediately
|
||||
*state.quic_transport.lock().unwrap() = Some(transport.clone());
|
||||
|
||||
// Crypto handshake
|
||||
let mut kx = WarzoneKeyExchange::from_identity_seed(identity_seed);
|
||||
let ephemeral_pub = kx.generate_ephemeral();
|
||||
let identity_pub = kx.identity_public_key();
|
||||
|
||||
let mut sign_data = Vec::with_capacity(42);
|
||||
sign_data.extend_from_slice(&ephemeral_pub);
|
||||
sign_data.extend_from_slice(b"call-offer");
|
||||
let signature = kx.sign(&sign_data);
|
||||
|
||||
let offer = SignalMessage::CallOffer {
|
||||
identity_pub,
|
||||
ephemeral_pub,
|
||||
signature,
|
||||
supported_profiles: vec![
|
||||
QualityProfile::GOOD,
|
||||
QualityProfile::DEGRADED,
|
||||
QualityProfile::CATASTROPHIC,
|
||||
],
|
||||
alias: alias.map(|s| s.to_string()),
|
||||
};
|
||||
transport.send_signal(&offer).await?;
|
||||
info!("CallOffer sent, waiting for CallAnswer...");
|
||||
|
||||
let answer = transport
|
||||
.recv_signal()
|
||||
.await?
|
||||
.ok_or_else(|| anyhow::anyhow!("connection closed before CallAnswer"))?;
|
||||
|
||||
let relay_ephemeral_pub = match answer {
|
||||
SignalMessage::CallAnswer { ephemeral_pub, .. } => ephemeral_pub,
|
||||
other => {
|
||||
return Err(anyhow::anyhow!(
|
||||
"expected CallAnswer, got {:?}",
|
||||
std::mem::discriminant(&other)
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let _session = kx.derive_session(&relay_ephemeral_pub)?;
|
||||
info!("handshake complete, call active");
|
||||
|
||||
{
|
||||
let mut stats = state.stats.lock().unwrap();
|
||||
stats.state = CallState::Active;
|
||||
}
|
||||
|
||||
// Initialize Opus codec
|
||||
let mut encoder =
|
||||
OpusEncoder::new(profile).map_err(|e| anyhow::anyhow!("opus encoder init: {e}"))?;
|
||||
let mut decoder =
|
||||
OpusDecoder::new(profile).map_err(|e| anyhow::anyhow!("opus decoder init: {e}"))?;
|
||||
|
||||
// Initialize FEC encoder/decoder
|
||||
let mut fec_enc = wzp_fec::create_encoder(&profile);
|
||||
let mut fec_dec = wzp_fec::create_decoder(&profile);
|
||||
|
||||
// AGC: normalize volume on both capture and playout paths
|
||||
let mut capture_agc = AutoGainControl::new();
|
||||
let mut playout_agc = AutoGainControl::new();
|
||||
|
||||
info!(
|
||||
fec_ratio = profile.fec_ratio,
|
||||
frames_per_block = profile.frames_per_block,
|
||||
"codec + FEC + AGC initialized (48kHz mono, 20ms frames)"
|
||||
);
|
||||
|
||||
let seq = AtomicU16::new(0);
|
||||
let ts = AtomicU32::new(0);
|
||||
let transport_recv = transport.clone();
|
||||
|
||||
// Pre-allocate buffers
|
||||
let mut capture_buf = vec![0i16; FRAME_SAMPLES];
|
||||
let mut encode_buf = vec![0u8; encoder.max_frame_bytes()];
|
||||
let mut frame_in_block: u8 = 0;
|
||||
let mut block_id: u8 = 0;
|
||||
|
||||
// Send task: capture ring → Opus encode → FEC → MediaPackets
|
||||
//
|
||||
// IMPORTANT: send_media() uses quinn's send_datagram() which is
|
||||
// synchronous and returns Err(Blocked) when the congestion window
|
||||
// is full. We MUST NOT break on send errors — that would kill the
|
||||
// entire call. Instead we drop the packet and keep going.
|
||||
let send_task = async {
|
||||
info!("send task started (Opus + RaptorQ FEC)");
|
||||
let mut send_errors: u64 = 0;
|
||||
let mut last_send_error_log = Instant::now();
|
||||
let mut last_stats_log = Instant::now();
|
||||
let mut frames_sent: u64 = 0;
|
||||
let mut frames_dropped: u64 = 0;
|
||||
loop {
|
||||
if !state.running.load(Ordering::Relaxed) {
|
||||
break;
|
||||
}
|
||||
|
||||
let avail = state.capture_ring.available();
|
||||
if avail < FRAME_SAMPLES {
|
||||
tokio::time::sleep(std::time::Duration::from_millis(5)).await;
|
||||
continue;
|
||||
}
|
||||
|
||||
let read = state.capture_ring.read(&mut capture_buf);
|
||||
if read < FRAME_SAMPLES {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Mute: zero out the buffer so Opus encodes silence.
|
||||
// We still read from the ring to prevent it from filling up.
|
||||
if state.muted.load(Ordering::Relaxed) {
|
||||
capture_buf.fill(0);
|
||||
}
|
||||
|
||||
// AGC: normalize capture volume before encoding
|
||||
capture_agc.process_frame(&mut capture_buf);
|
||||
|
||||
// Opus encode
|
||||
let encoded_len = match encoder.encode(&capture_buf, &mut encode_buf) {
|
||||
Ok(n) => n,
|
||||
Err(e) => {
|
||||
warn!("opus encode error: {e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let encoded = &encode_buf[..encoded_len];
|
||||
|
||||
// Build source packet
|
||||
let s = seq.fetch_add(1, Ordering::Relaxed);
|
||||
let t = ts.fetch_add(FRAME_SAMPLES as u32, Ordering::Relaxed);
|
||||
|
||||
let source_pkt = MediaPacket {
|
||||
header: MediaHeader {
|
||||
version: 0,
|
||||
is_repair: false,
|
||||
codec_id: profile.codec,
|
||||
has_quality_report: false,
|
||||
fec_ratio_encoded: MediaHeader::encode_fec_ratio(profile.fec_ratio),
|
||||
seq: s,
|
||||
timestamp: t,
|
||||
fec_block: block_id,
|
||||
fec_symbol: frame_in_block,
|
||||
reserved: 0,
|
||||
csrc_count: 0,
|
||||
},
|
||||
payload: Bytes::copy_from_slice(encoded),
|
||||
quality_report: None,
|
||||
};
|
||||
|
||||
// Send source packet — drop on error, never break
|
||||
if let Err(e) = transport.send_media(&source_pkt).await {
|
||||
send_errors += 1;
|
||||
frames_dropped += 1;
|
||||
// Log first few errors, then throttle to once per second
|
||||
if send_errors <= 3 || last_send_error_log.elapsed().as_secs() >= 1 {
|
||||
warn!(
|
||||
seq = s,
|
||||
send_errors,
|
||||
frames_dropped,
|
||||
"send_media error (dropping packet): {e}"
|
||||
);
|
||||
last_send_error_log = Instant::now();
|
||||
}
|
||||
// Don't feed to FEC either — the source is lost
|
||||
continue;
|
||||
}
|
||||
frames_sent += 1;
|
||||
|
||||
// Feed encoded frame to FEC encoder
|
||||
if let Err(e) = fec_enc.add_source_symbol(encoded) {
|
||||
warn!("fec add_source error: {e}");
|
||||
}
|
||||
frame_in_block += 1;
|
||||
|
||||
// When block is full, generate repair packets
|
||||
if frame_in_block >= profile.frames_per_block {
|
||||
match fec_enc.generate_repair(profile.fec_ratio) {
|
||||
Ok(repairs) => {
|
||||
let repair_count = repairs.len();
|
||||
for (sym_idx, repair_data) in repairs {
|
||||
let rs = seq.fetch_add(1, Ordering::Relaxed);
|
||||
let repair_pkt = MediaPacket {
|
||||
header: MediaHeader {
|
||||
version: 0,
|
||||
is_repair: true,
|
||||
codec_id: profile.codec,
|
||||
has_quality_report: false,
|
||||
fec_ratio_encoded: MediaHeader::encode_fec_ratio(
|
||||
profile.fec_ratio,
|
||||
),
|
||||
seq: rs,
|
||||
timestamp: t,
|
||||
fec_block: block_id,
|
||||
fec_symbol: sym_idx,
|
||||
reserved: 0,
|
||||
csrc_count: 0,
|
||||
},
|
||||
payload: Bytes::from(repair_data),
|
||||
quality_report: None,
|
||||
};
|
||||
// Drop repair packets on error — never break
|
||||
if let Err(_e) = transport.send_media(&repair_pkt).await {
|
||||
send_errors += 1;
|
||||
frames_dropped += 1;
|
||||
// Don't log every repair failure — source error log covers it
|
||||
}
|
||||
}
|
||||
if repair_count > 0 && (block_id % 50 == 0 || block_id == 0) {
|
||||
info!(
|
||||
block_id,
|
||||
repair_count,
|
||||
fec_ratio = profile.fec_ratio,
|
||||
"FEC block complete"
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("fec generate_repair error: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
let _ = fec_enc.finalize_block();
|
||||
block_id = block_id.wrapping_add(1);
|
||||
frame_in_block = 0;
|
||||
}
|
||||
|
||||
// Periodic stats every 5 seconds
|
||||
if last_stats_log.elapsed().as_secs() >= 5 {
|
||||
info!(
|
||||
seq = s,
|
||||
block_id,
|
||||
frames_sent,
|
||||
frames_dropped,
|
||||
send_errors,
|
||||
ring_avail = state.capture_ring.available(),
|
||||
"send stats"
|
||||
);
|
||||
last_stats_log = Instant::now();
|
||||
}
|
||||
}
|
||||
info!(frames_sent, frames_dropped, send_errors, "send task ended");
|
||||
};
|
||||
|
||||
// Pre-allocate decode buffer
|
||||
let mut decode_buf = vec![0i16; FRAME_SAMPLES];
|
||||
|
||||
// Recv task: MediaPackets → FEC decode → Opus decode → playout ring
|
||||
let recv_task = async {
|
||||
let mut frames_decoded: u64 = 0;
|
||||
let mut fec_recovered: u64 = 0;
|
||||
let mut recv_errors: u64 = 0;
|
||||
let mut last_recv_instant = Instant::now();
|
||||
let mut max_recv_gap_ms: u64 = 0;
|
||||
let mut last_stats_log = Instant::now();
|
||||
info!("recv task started (Opus + RaptorQ FEC)");
|
||||
loop {
|
||||
if !state.running.load(Ordering::Relaxed) {
|
||||
break;
|
||||
}
|
||||
match transport_recv.recv_media().await {
|
||||
Ok(Some(pkt)) => {
|
||||
// Track recv gaps — large gaps indicate network or relay issues
|
||||
let recv_gap_ms = last_recv_instant.elapsed().as_millis() as u64;
|
||||
last_recv_instant = Instant::now();
|
||||
if recv_gap_ms > max_recv_gap_ms {
|
||||
max_recv_gap_ms = recv_gap_ms;
|
||||
}
|
||||
if recv_gap_ms > 500 {
|
||||
warn!(
|
||||
recv_gap_ms,
|
||||
seq = pkt.header.seq,
|
||||
is_repair = pkt.header.is_repair,
|
||||
"large recv gap — possible network stall"
|
||||
);
|
||||
}
|
||||
|
||||
let is_repair = pkt.header.is_repair;
|
||||
let pkt_block = pkt.header.fec_block;
|
||||
let pkt_symbol = pkt.header.fec_symbol;
|
||||
|
||||
// Feed every packet (source + repair) to FEC decoder
|
||||
let _ = fec_dec.add_symbol(
|
||||
pkt_block,
|
||||
pkt_symbol,
|
||||
is_repair,
|
||||
&pkt.payload,
|
||||
);
|
||||
|
||||
// Source packets: decode directly
|
||||
if !is_repair {
|
||||
match decoder.decode(&pkt.payload, &mut decode_buf) {
|
||||
Ok(samples) => {
|
||||
playout_agc.process_frame(&mut decode_buf[..samples]);
|
||||
state.playout_ring.write(&decode_buf[..samples]);
|
||||
frames_decoded += 1;
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("opus decode error: {e}");
|
||||
if let Ok(samples) = decoder.decode_lost(&mut decode_buf) {
|
||||
playout_agc.process_frame(&mut decode_buf[..samples]);
|
||||
state.playout_ring.write(&decode_buf[..samples]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Try FEC recovery
|
||||
if let Ok(Some(recovered_frames)) = fec_dec.try_decode(pkt_block) {
|
||||
fec_recovered += recovered_frames.len() as u64;
|
||||
if fec_recovered % 50 == 1 {
|
||||
info!(
|
||||
fec_recovered,
|
||||
block = pkt_block,
|
||||
frames = recovered_frames.len(),
|
||||
"FEC block recovered"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Expire old blocks to prevent memory growth
|
||||
if pkt_block > 3 {
|
||||
fec_dec.expire_before(pkt_block.wrapping_sub(3));
|
||||
}
|
||||
|
||||
let mut stats = state.stats.lock().unwrap();
|
||||
stats.frames_decoded = frames_decoded;
|
||||
stats.fec_recovered = fec_recovered;
|
||||
drop(stats);
|
||||
|
||||
// Periodic stats every 5 seconds
|
||||
if last_stats_log.elapsed().as_secs() >= 5 {
|
||||
info!(
|
||||
frames_decoded,
|
||||
fec_recovered,
|
||||
recv_errors,
|
||||
max_recv_gap_ms,
|
||||
playout_avail = state.playout_ring.available(),
|
||||
"recv stats"
|
||||
);
|
||||
max_recv_gap_ms = 0;
|
||||
last_stats_log = Instant::now();
|
||||
}
|
||||
}
|
||||
Ok(None) => {
|
||||
info!(frames_decoded, fec_recovered, "relay disconnected (stream ended)");
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
recv_errors += 1;
|
||||
// Transient errors: log and keep going
|
||||
let msg = e.to_string();
|
||||
if msg.contains("closed") || msg.contains("reset") {
|
||||
error!(recv_errors, "recv fatal: {e}");
|
||||
break;
|
||||
}
|
||||
// Non-fatal: log throttled
|
||||
if recv_errors <= 3 || recv_errors % 50 == 0 {
|
||||
warn!(recv_errors, "recv error (continuing): {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
info!(frames_decoded, fec_recovered, recv_errors, "recv task ended");
|
||||
};
|
||||
|
||||
// Stats task — polls path quality + quinn RTT every 500ms
|
||||
let transport_stats = transport.clone();
|
||||
let stats_task = async {
|
||||
loop {
|
||||
if !state.running.load(Ordering::Relaxed) {
|
||||
break;
|
||||
}
|
||||
// Feed quinn's QUIC-level RTT into our path monitor
|
||||
let quic_rtt_ms = transport_stats.connection().stats().path.rtt.as_millis() as u32;
|
||||
if quic_rtt_ms > 0 {
|
||||
transport_stats.feed_rtt(quic_rtt_ms);
|
||||
}
|
||||
let pq = transport_stats.path_quality();
|
||||
{
|
||||
let mut stats = state.stats.lock().unwrap();
|
||||
stats.frames_encoded = seq.load(Ordering::Relaxed) as u64;
|
||||
stats.loss_pct = pq.loss_pct;
|
||||
stats.rtt_ms = quic_rtt_ms;
|
||||
stats.jitter_ms = pq.jitter_ms;
|
||||
}
|
||||
tokio::time::sleep(std::time::Duration::from_millis(500)).await;
|
||||
}
|
||||
};
|
||||
|
||||
// Signal recv task — listens for RoomUpdate and other signaling messages
|
||||
let transport_signal = transport.clone();
|
||||
let state_signal = state.clone();
|
||||
let signal_task = async {
|
||||
loop {
|
||||
match transport_signal.recv_signal().await {
|
||||
Ok(Some(SignalMessage::RoomUpdate { count, participants })) => {
|
||||
info!(count, "RoomUpdate received");
|
||||
let members: Vec<crate::stats::RoomMember> = participants
|
||||
.iter()
|
||||
.map(|p| crate::stats::RoomMember {
|
||||
fingerprint: p.fingerprint.clone(),
|
||||
alias: p.alias.clone(),
|
||||
})
|
||||
.collect();
|
||||
let mut stats = state_signal.stats.lock().unwrap();
|
||||
stats.room_participant_count = count;
|
||||
stats.room_participants = members;
|
||||
}
|
||||
Ok(Some(msg)) => {
|
||||
info!("signal received: {:?}", std::mem::discriminant(&msg));
|
||||
}
|
||||
Ok(None) => {
|
||||
info!("signal stream closed");
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("signal recv error: {e}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
tokio::select! {
|
||||
_ = send_task => info!("send task ended"),
|
||||
_ = recv_task => info!("recv task ended"),
|
||||
_ = stats_task => info!("stats task ended"),
|
||||
_ = signal_task => info!("signal task ended"),
|
||||
}
|
||||
|
||||
// Send CONNECTION_CLOSE and wait up to 500ms for the peer to acknowledge.
|
||||
// This ensures the relay sees the close even if the first packet is lost.
|
||||
info!("closing QUIC connection...");
|
||||
transport.close_now();
|
||||
match tokio::time::timeout(
|
||||
std::time::Duration::from_millis(500),
|
||||
transport.connection().closed(),
|
||||
).await {
|
||||
Ok(_) => info!("QUIC connection closed cleanly"),
|
||||
Err(_) => info!("QUIC close timed out (relay may not have ack'd)"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
256
crates/wzp-android/src/jni_bridge.rs
Normal file
256
crates/wzp-android/src/jni_bridge.rs
Normal file
@@ -0,0 +1,256 @@
|
||||
//! JNI bridge for Android — thin layer between Kotlin and the WzpEngine.
|
||||
|
||||
use std::panic;
|
||||
use std::sync::Once;
|
||||
|
||||
use jni::objects::{JClass, JObject, JString};
|
||||
use jni::sys::{jboolean, jint, jlong, jstring};
|
||||
use jni::JNIEnv;
|
||||
use tracing::{error, info};
|
||||
use wzp_proto::QualityProfile;
|
||||
|
||||
use crate::engine::{CallStartConfig, WzpEngine};
|
||||
|
||||
/// Opaque engine handle passed to/from Kotlin as a `jlong`.
///
/// Allocated on the heap by `nativeInit` (via `Box::into_raw`) and
/// reclaimed by `nativeDestroy` (via `Box::from_raw`); every other
/// native entry point borrows it through `handle_ref`.
struct EngineHandle {
    // The engine instance that owns all call state for this handle.
    engine: WzpEngine,
}
|
||||
|
||||
/// Recover the `EngineHandle` from a raw handle value.
///
/// # Safety
/// `handle` must be a non-zero value previously returned by `nativeInit`
/// and not yet passed to `nativeDestroy`. The `'static` mutable borrow is
/// a fiction forced by the JNI boundary: the Kotlin side must ensure no
/// two concurrent native calls alias the same handle mutably.
unsafe fn handle_ref(handle: jlong) -> &'static mut EngineHandle {
    // SAFETY: caller guarantees `handle` came from Box::into_raw in nativeInit.
    unsafe { &mut *(handle as *mut EngineHandle) }
}
|
||||
|
||||
fn profile_from_int(value: jint) -> QualityProfile {
|
||||
match value {
|
||||
1 => QualityProfile::DEGRADED,
|
||||
2 => QualityProfile::CATASTROPHIC,
|
||||
_ => QualityProfile::GOOD,
|
||||
}
|
||||
}
|
||||
|
||||
// Guards one-time tracing initialization across repeated JNI calls.
static INIT_LOGGING: Once = Once::new();

/// Initialize tracing → Android logcat (tag "wzp_android").
/// Safe to call multiple times — only the first call takes effect.
fn init_logging() {
    INIT_LOGGING.call_once(|| {
        use tracing_subscriber::layer::SubscriberExt;
        use tracing_subscriber::util::SubscriberInitExt;
        // layer() can fail (e.g. on non-Android hosts); in that case
        // logging is silently disabled rather than aborting init.
        if let Ok(layer) = tracing_android::layer("wzp_android") {
            // try_init: ignore the error if a global subscriber was
            // already installed by some other component.
            let _ = tracing_subscriber::registry().with(layer).try_init();
        }
    });
}
|
||||
|
||||
#[unsafe(no_mangle)]
|
||||
pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeInit(
|
||||
_env: JNIEnv,
|
||||
_class: JClass,
|
||||
) -> jlong {
|
||||
let result = panic::catch_unwind(|| {
|
||||
init_logging();
|
||||
let handle = Box::new(EngineHandle {
|
||||
engine: WzpEngine::new(),
|
||||
});
|
||||
Box::into_raw(handle) as jlong
|
||||
});
|
||||
match result {
|
||||
Ok(h) => h,
|
||||
Err(_) => 0,
|
||||
}
|
||||
}
|
||||
|
||||
#[unsafe(no_mangle)]
|
||||
pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeStartCall(
|
||||
mut env: JNIEnv,
|
||||
_class: JClass,
|
||||
handle: jlong,
|
||||
relay_addr_j: JString,
|
||||
room_j: JString,
|
||||
seed_hex_j: JString,
|
||||
token_j: JString,
|
||||
alias_j: JString,
|
||||
) -> jint {
|
||||
let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
|
||||
let relay_addr: String = env.get_string(&relay_addr_j).map(|s| s.into()).unwrap_or_default();
|
||||
let room: String = env.get_string(&room_j).map(|s| s.into()).unwrap_or_default();
|
||||
let seed_hex: String = env.get_string(&seed_hex_j).map(|s| s.into()).unwrap_or_default();
|
||||
let token: String = env.get_string(&token_j).map(|s| s.into()).unwrap_or_default();
|
||||
let alias: String = env.get_string(&alias_j).map(|s| s.into()).unwrap_or_default();
|
||||
|
||||
let h = unsafe { handle_ref(handle) };
|
||||
|
||||
// Parse hex seed
|
||||
let mut identity_seed = [0u8; 32];
|
||||
if seed_hex.len() == 64 {
|
||||
for i in 0..32 {
|
||||
if let Ok(byte) = u8::from_str_radix(&seed_hex[i * 2..i * 2 + 2], 16) {
|
||||
identity_seed[i] = byte;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Generate random seed if not provided
|
||||
use rand::RngCore;
|
||||
rand::thread_rng().fill_bytes(&mut identity_seed);
|
||||
}
|
||||
|
||||
let config = CallStartConfig {
|
||||
profile: QualityProfile::GOOD,
|
||||
relay_addr,
|
||||
room,
|
||||
auth_token: if token.is_empty() { Vec::new() } else { token.into_bytes() },
|
||||
identity_seed,
|
||||
alias: if alias.is_empty() { None } else { Some(alias) },
|
||||
};
|
||||
|
||||
match h.engine.start_call(config) {
|
||||
Ok(()) => 0,
|
||||
Err(e) => {
|
||||
error!("start_call failed: {e}");
|
||||
-1
|
||||
}
|
||||
}
|
||||
}));
|
||||
|
||||
match result {
|
||||
Ok(code) => code,
|
||||
Err(_) => -1,
|
||||
}
|
||||
}
|
||||
|
||||
/// Stop the active call on this engine handle.
/// Panics are caught so they never unwind across the JNI boundary (UB).
#[unsafe(no_mangle)]
pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeStopCall(
    _env: JNIEnv,
    _class: JClass,
    handle: jlong,
) {
    let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| {
        let h = unsafe { handle_ref(handle) };
        h.engine.stop_call();
    }));
}
|
||||
|
||||
#[unsafe(no_mangle)]
|
||||
pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeSetMute(
|
||||
_env: JNIEnv,
|
||||
_class: JClass,
|
||||
handle: jlong,
|
||||
muted: jboolean,
|
||||
) {
|
||||
let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| {
|
||||
let h = unsafe { handle_ref(handle) };
|
||||
h.engine.set_mute(muted != 0);
|
||||
}));
|
||||
}
|
||||
|
||||
#[unsafe(no_mangle)]
|
||||
pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeSetSpeaker(
|
||||
_env: JNIEnv,
|
||||
_class: JClass,
|
||||
handle: jlong,
|
||||
speaker: jboolean,
|
||||
) {
|
||||
let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| {
|
||||
let h = unsafe { handle_ref(handle) };
|
||||
h.engine.set_speaker(speaker != 0);
|
||||
}));
|
||||
}
|
||||
|
||||
/// Fetch current call statistics as a JSON string for the Kotlin side.
///
/// Returns "{}" if the engine panics or serialization fails; returns a
/// null jstring only if JVM string allocation itself fails.
#[unsafe(no_mangle)]
pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeGetStats<'a>(
    mut env: JNIEnv<'a>,
    _class: JClass,
    handle: jlong,
) -> jstring {
    // Run the engine access inside catch_unwind; only the already-built
    // String crosses back out, keeping `env` usage on the safe side.
    let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
        let h = unsafe { handle_ref(handle) };
        let stats = h.engine.get_stats();
        serde_json::to_string(&stats).unwrap_or_else(|_| "{}".to_string())
    }));

    let json = match result {
        Ok(s) => s,
        Err(_) => "{}".to_string(),
    };

    // JVM allocation failure yields a null jstring — Kotlin must handle it.
    env.new_string(&json)
        .map(|s| s.into_raw())
        .unwrap_or(JObject::null().into_raw())
}
|
||||
|
||||
#[unsafe(no_mangle)]
|
||||
pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeForceProfile(
|
||||
_env: JNIEnv,
|
||||
_class: JClass,
|
||||
handle: jlong,
|
||||
profile: jint,
|
||||
) {
|
||||
let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| {
|
||||
let h = unsafe { handle_ref(handle) };
|
||||
let qp = profile_from_int(profile);
|
||||
h.engine.force_profile(qp);
|
||||
}));
|
||||
}
|
||||
|
||||
/// Write captured PCM samples from Kotlin AudioRecord into the engine's capture ring.
/// pcm is a Java short[] array.
///
/// Returns the number of samples accepted by the engine, or 0 on any
/// failure (empty array, JNI copy error, or a panic in the engine).
#[unsafe(no_mangle)]
pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeWriteAudio(
    env: JNIEnv,
    _class: JClass,
    handle: jlong,
    pcm: jni::objects::JShortArray,
) -> jint {
    let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
        let h = unsafe { handle_ref(handle) };
        let len = env.get_array_length(&pcm).unwrap_or(0) as usize;
        if len == 0 {
            return 0;
        }
        // Copy out of the Java heap; the engine keeps no reference to `buf`.
        let mut buf = vec![0i16; len];
        // GetShortArrayRegion copies Java array into our buffer
        if env.get_short_array_region(&pcm, 0, &mut buf).is_err() {
            return 0;
        }
        h.engine.write_audio(&buf) as jint
    }));
    result.unwrap_or(0)
}
|
||||
|
||||
/// Read decoded PCM samples from the engine's playout ring for Kotlin AudioTrack.
/// pcm is a Java short[] array to fill. Returns number of samples actually read.
///
/// Returns 0 when the array is empty, nothing is available, or a panic occurs.
#[unsafe(no_mangle)]
pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeReadAudio(
    env: JNIEnv,
    _class: JClass,
    handle: jlong,
    pcm: jni::objects::JShortArray,
) -> jint {
    let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
        let h = unsafe { handle_ref(handle) };
        let len = env.get_array_length(&pcm).unwrap_or(0) as usize;
        if len == 0 {
            return 0;
        }
        // Stage into a local buffer, then copy only the filled prefix back
        // into the Java array.
        let mut buf = vec![0i16; len];
        let read = h.engine.read_audio(&mut buf);
        if read > 0 {
            // Copy failure is ignored; the caller still gets `read` and may
            // see stale data — acceptable for best-effort audio.
            let _ = env.set_short_array_region(&pcm, 0, &buf[..read]);
        }
        read as jint
    }));
    result.unwrap_or(0)
}
|
||||
|
||||
#[unsafe(no_mangle)]
|
||||
pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeDestroy(
|
||||
_env: JNIEnv,
|
||||
_class: JClass,
|
||||
handle: jlong,
|
||||
) {
|
||||
let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| {
|
||||
let h = unsafe { Box::from_raw(handle as *mut EngineHandle) };
|
||||
drop(h);
|
||||
}));
|
||||
}
|
||||
18
crates/wzp-android/src/lib.rs
Normal file
18
crates/wzp-android/src/lib.rs
Normal file
@@ -0,0 +1,18 @@
|
||||
//! WarzonePhone Android native VoIP engine.
|
||||
//!
|
||||
//! Provides:
|
||||
//! - Oboe audio backend with lock-free SPSC ring buffers
|
||||
//! - Engine orchestrator managing call lifecycle
|
||||
//! - Codec pipeline thread (encode/decode/FEC/jitter)
|
||||
//! - Call statistics and command interface
|
||||
//!
|
||||
//! On non-Android targets, the Oboe C++ layer compiles as a stub,
|
||||
//! allowing `cargo check` and unit tests on the host.
|
||||
|
||||
pub mod audio_android;
|
||||
pub mod audio_ring;
|
||||
pub mod commands;
|
||||
pub mod engine;
|
||||
pub mod pipeline;
|
||||
pub mod stats;
|
||||
pub mod jni_bridge;
|
||||
262
crates/wzp-android/src/pipeline.rs
Normal file
262
crates/wzp-android/src/pipeline.rs
Normal file
@@ -0,0 +1,262 @@
|
||||
//! Codec pipeline — encode/decode with FEC and jitter buffer.
|
||||
//!
|
||||
//! Runs on a dedicated thread, processing 20 ms frames at 48 kHz.
|
||||
//! The pipeline is NOT Send/Sync (Opus encoder state) — it is owned
|
||||
//! exclusively by the codec thread.
|
||||
|
||||
use tracing::{debug, warn};
|
||||
use wzp_codec::{AdaptiveDecoder, AdaptiveEncoder, AutoGainControl, EchoCanceller};
|
||||
use wzp_fec::{RaptorQFecDecoder, RaptorQFecEncoder};
|
||||
use wzp_proto::jitter::{JitterBuffer, PlayoutResult};
|
||||
use wzp_proto::quality::AdaptiveQualityController;
|
||||
use wzp_proto::traits::{AudioDecoder, AudioEncoder, FecDecoder, FecEncoder};
|
||||
use wzp_proto::traits::QualityController;
|
||||
use wzp_proto::{MediaPacket, QualityProfile};
|
||||
|
||||
use crate::audio_android::FRAME_SAMPLES;
|
||||
|
||||
/// Maximum encoded frame size (Opus worst case at highest bitrate).
/// 1275 bytes is the largest possible single Opus frame (RFC 6716 framing).
const MAX_ENCODED_BYTES: usize = 1275;
|
||||
|
||||
/// Pipeline statistics snapshot.
#[derive(Clone, Debug, Default)]
pub struct PipelineStats {
    /// Frames successfully encoded since pipeline creation.
    pub frames_encoded: u64,
    /// Frames decoded since creation (includes PLC-generated frames).
    pub frames_decoded: u64,
    /// Times the jitter buffer had nothing ready when a frame was requested.
    pub underruns: u64,
    /// Current jitter buffer depth in packets.
    pub jitter_depth: usize,
    /// Current quality tier from the adaptive controller.
    pub quality_tier: u8,
}
|
||||
|
||||
/// The codec pipeline: encode, FEC, jitter buffer, decode.
///
/// This struct is owned by the codec thread and not shared.
pub struct Pipeline {
    /// Opus encoder tracking the active quality profile.
    encoder: AdaptiveEncoder,
    /// Opus decoder; also provides packet loss concealment via `decode_lost`.
    decoder: AdaptiveDecoder,
    /// RaptorQ FEC encoder, fed every encoded source frame.
    fec_encoder: RaptorQFecEncoder,
    /// RaptorQ FEC decoder, fed received symbols in `feed_packet`.
    fec_decoder: RaptorQFecDecoder,
    /// Reorders incoming packets before decode; reports missing/not-ready.
    jitter_buffer: JitterBuffer,
    /// Adaptive tier controller driven by quality reports.
    quality_ctrl: AdaptiveQualityController,
    /// Acoustic echo canceller applied before encoding.
    aec: EchoCanceller,
    /// Automatic gain control applied before encoding.
    agc: AutoGainControl,
    /// Last decoded PCM frame, used as the AEC far-end reference.
    last_decoded_farend: Option<Vec<i16>>,
    // Pre-allocated scratch buffers (avoid per-frame allocation on hot path)
    capture_buf: Vec<i16>,
    #[allow(dead_code)]
    playout_buf: Vec<i16>,
    encode_out: Vec<u8>,
    // Stats counters (exported via `stats()`)
    frames_encoded: u64,
    frames_decoded: u64,
    underruns: u64,
}
|
||||
|
||||
impl Pipeline {
    /// Create a new pipeline configured for the given quality profile.
    ///
    /// Fails if the Opus encoder or decoder cannot be initialized.
    pub fn new(profile: QualityProfile) -> Result<Self, anyhow::Error> {
        let encoder = AdaptiveEncoder::new(profile)
            .map_err(|e| anyhow::anyhow!("encoder init: {e}"))?;
        let decoder = AdaptiveDecoder::new(profile)
            .map_err(|e| anyhow::anyhow!("decoder init: {e}"))?;
        // FEC block size follows the profile's frames-per-block setting.
        let fec_encoder =
            RaptorQFecEncoder::with_defaults(profile.frames_per_block as usize);
        let fec_decoder =
            RaptorQFecDecoder::with_defaults(profile.frames_per_block as usize);
        // NOTE(review): assumed (10, 250, 3) maps to (min, max, target)
        // jitter-buffer depths — confirm against JitterBuffer::new's docs.
        let jitter_buffer = JitterBuffer::new(10, 250, 3);
        let quality_ctrl = AdaptiveQualityController::new();

        Ok(Self {
            encoder,
            decoder,
            fec_encoder,
            fec_decoder,
            jitter_buffer,
            quality_ctrl,
            aec: EchoCanceller::new(48000, 100), // 100 ms echo tail
            agc: AutoGainControl::new(),
            last_decoded_farend: None,
            capture_buf: vec![0i16; FRAME_SAMPLES],
            playout_buf: vec![0i16; FRAME_SAMPLES],
            encode_out: vec![0u8; MAX_ENCODED_BYTES],
            frames_encoded: 0,
            frames_decoded: 0,
            underruns: 0,
        })
    }

    /// Encode a PCM frame into a compressed packet.
    ///
    /// If `muted` is true, a silence frame is encoded (all zeros).
    /// Returns the encoded bytes, or `None` on encoder error.
    /// On success the encoded frame is also fed to the FEC encoder.
    pub fn encode_frame(&mut self, pcm: &[i16], muted: bool) -> Option<Vec<u8>> {
        let input = if muted {
            // Zero the capture buffer for silence
            for s in self.capture_buf.iter_mut() {
                *s = 0;
            }
            &self.capture_buf[..]
        } else {
            // Feed the last decoded playout as AEC far-end reference.
            if let Some(ref farend) = self.last_decoded_farend {
                self.aec.feed_farend(farend);
            }

            // Apply AEC + AGC to the captured PCM.
            // Work on the scratch copy so the caller's slice is untouched;
            // input longer than one frame is truncated to FRAME_SAMPLES.
            let len = pcm.len().min(self.capture_buf.len());
            self.capture_buf[..len].copy_from_slice(&pcm[..len]);
            self.aec.process_frame(&mut self.capture_buf[..len]);
            self.agc.process_frame(&mut self.capture_buf[..len]);
            &self.capture_buf[..len]
        };

        match self.encoder.encode(input, &mut self.encode_out) {
            Ok(n) => {
                self.frames_encoded += 1;
                let encoded = self.encode_out[..n].to_vec();

                // Feed into FEC encoder
                if let Err(e) = self.fec_encoder.add_source_symbol(&encoded) {
                    warn!("FEC encode error: {e}");
                }

                Some(encoded)
            }
            Err(e) => {
                warn!("encode error: {e}");
                None
            }
        }
    }

    /// Feed a received media packet into the jitter buffer.
    ///
    /// If the packet carries a non-zero FEC block or symbol id, the payload
    /// is first offered to the FEC decoder, then the packet is enqueued.
    pub fn feed_packet(&mut self, packet: MediaPacket) {
        // Feed FEC symbols if present
        let header = &packet.header;
        if header.fec_block != 0 || header.fec_symbol != 0 {
            let is_repair = header.is_repair;
            if let Err(e) = self.fec_decoder.add_symbol(
                header.fec_block,
                header.fec_symbol,
                is_repair,
                &packet.payload,
            ) {
                // Logged at debug only; a failed symbol feed does not block
                // the packet from entering the jitter buffer below.
                debug!("FEC symbol feed error: {e}");
            }
        }

        self.jitter_buffer.push(packet);
    }

    /// Decode the next frame from the jitter buffer.
    ///
    /// Returns decoded PCM samples, or `None` if the buffer is not ready.
    /// Decoded PCM is also stored as the AEC far-end reference for the next
    /// encode cycle.
    pub fn decode_frame(&mut self) -> Option<Vec<i16>> {
        let result = match self.jitter_buffer.pop() {
            PlayoutResult::Packet(pkt) => {
                let mut pcm = vec![0i16; FRAME_SAMPLES];
                match self.decoder.decode(&pkt.payload, &mut pcm) {
                    Ok(n) => {
                        self.frames_decoded += 1;
                        pcm.truncate(n);
                        Some(pcm)
                    }
                    Err(e) => {
                        warn!("decode error: {e}");
                        // Attempt PLC
                        self.generate_plc()
                    }
                }
            }
            PlayoutResult::Missing { seq } => {
                debug!(seq, "jitter buffer: missing packet, generating PLC");
                self.generate_plc()
            }
            PlayoutResult::NotReady => {
                self.underruns += 1;
                None
            }
        };

        // Save decoded PCM as far-end reference for AEC.
        if let Some(ref pcm) = result {
            self.last_decoded_farend = Some(pcm.clone());
        }

        result
    }

    /// Generate packet loss concealment output via the decoder's
    /// loss-concealment mode; counts toward `frames_decoded` on success.
    fn generate_plc(&mut self) -> Option<Vec<i16>> {
        let mut pcm = vec![0i16; FRAME_SAMPLES];
        match self.decoder.decode_lost(&mut pcm) {
            Ok(n) => {
                self.frames_decoded += 1;
                pcm.truncate(n);
                Some(pcm)
            }
            Err(e) => {
                warn!("PLC error: {e}");
                None
            }
        }
    }

    /// Feed a quality report into the adaptive quality controller.
    ///
    /// Returns a new profile if a tier transition occurred; the encoder and
    /// decoder are reconfigured to the new profile as a side effect.
    #[allow(unused)]
    pub fn observe_quality(
        &mut self,
        report: &wzp_proto::QualityReport,
    ) -> Option<QualityProfile> {
        let new_profile = self.quality_ctrl.observe(report);
        if let Some(ref profile) = new_profile {
            // Reconfigure failures are logged but do not roll back the
            // controller's tier change.
            if let Err(e) = self.encoder.set_profile(*profile) {
                warn!("encoder set_profile error: {e}");
            }
            if let Err(e) = self.decoder.set_profile(*profile) {
                warn!("decoder set_profile error: {e}");
            }
        }
        new_profile
    }

    /// Force a specific quality profile on the controller and immediately
    /// reconfigure the encoder and decoder to match.
    #[allow(unused)]
    pub fn force_profile(&mut self, profile: QualityProfile) {
        self.quality_ctrl.force_profile(profile);
        if let Err(e) = self.encoder.set_profile(profile) {
            warn!("encoder set_profile error: {e}");
        }
        if let Err(e) = self.decoder.set_profile(profile) {
            warn!("decoder set_profile error: {e}");
        }
    }

    /// Get current pipeline statistics.
    pub fn stats(&self) -> PipelineStats {
        PipelineStats {
            frames_encoded: self.frames_encoded,
            frames_decoded: self.frames_decoded,
            underruns: self.underruns,
            jitter_depth: self.jitter_buffer.stats().current_depth,
            quality_tier: self.quality_ctrl.tier() as u8,
        }
    }

    /// Enable or disable acoustic echo cancellation.
    pub fn set_aec_enabled(&mut self, enabled: bool) {
        self.aec.set_enabled(enabled);
    }

    /// Enable or disable automatic gain control.
    pub fn set_agc_enabled(&mut self, enabled: bool) {
        self.agc.set_enabled(enabled);
    }
}
|
||||
67
crates/wzp-android/src/stats.rs
Normal file
67
crates/wzp-android/src/stats.rs
Normal file
@@ -0,0 +1,67 @@
|
||||
//! Call statistics for the Android engine.
|
||||
|
||||
/// State of the call.
/// Serializes as integer for easy parsing on the Kotlin side:
/// 0=Idle, 1=Connecting, 2=Active, 3=Reconnecting, 4=Closed
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub enum CallState {
    /// No call in progress (initial state).
    #[default]
    Idle,
    /// Call setup in progress.
    Connecting,
    /// Call established.
    Active,
    /// Attempting to re-establish a dropped connection.
    Reconnecting,
    /// Call ended.
    Closed,
}
|
||||
|
||||
impl serde::Serialize for CallState {
|
||||
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
|
||||
let n: u8 = match self {
|
||||
CallState::Idle => 0,
|
||||
CallState::Connecting => 1,
|
||||
CallState::Active => 2,
|
||||
CallState::Reconnecting => 3,
|
||||
CallState::Closed => 4,
|
||||
};
|
||||
serializer.serialize_u8(n)
|
||||
}
|
||||
}
|
||||
|
||||
/// Aggregated call statistics, serializable for JNI bridge.
///
/// Serialized to a JSON object; `state` serializes as an integer code
/// (see [`CallState`]'s `Serialize` impl).
#[derive(Clone, Debug, Default, serde::Serialize)]
pub struct CallStats {
    /// Current call state.
    pub state: CallState,
    /// Call duration in seconds.
    pub duration_secs: f64,
    /// Current quality tier (0=GOOD, 1=DEGRADED, 2=CATASTROPHIC).
    pub quality_tier: u8,
    /// Observed packet loss percentage.
    pub loss_pct: f32,
    /// Smoothed round-trip time in milliseconds.
    pub rtt_ms: u32,
    /// Jitter in milliseconds.
    pub jitter_ms: u32,
    /// Current jitter buffer depth in packets.
    pub jitter_buffer_depth: usize,
    /// Total frames encoded since call start.
    pub frames_encoded: u64,
    /// Total frames decoded since call start.
    pub frames_decoded: u64,
    /// Number of playout underruns (buffer empty when audio needed).
    pub underruns: u64,
    /// Frames recovered by FEC.
    pub fec_recovered: u64,
    /// Current mic audio level (RMS of i16 samples, 0-32767).
    pub audio_level: u32,
    /// Number of participants in the room (from last RoomUpdate).
    pub room_participant_count: u32,
    /// Participant list (fingerprint + optional alias) serialized as JSON array.
    pub room_participants: Vec<RoomMember>,
}
|
||||
|
||||
/// A room member entry, serialized into the stats JSON.
#[derive(Clone, Debug, Default, serde::Serialize)]
pub struct RoomMember {
    /// Participant identity fingerprint.
    pub fingerprint: String,
    /// Optional human-readable display name announced by the participant.
    pub alias: Option<String>,
}
|
||||
@@ -18,11 +18,18 @@ tracing-subscriber = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
bytes = { workspace = true }
|
||||
anyhow = "1"
|
||||
serde = { workspace = true }
|
||||
serde_json = "1"
|
||||
chrono = "0.4"
|
||||
rustls = { version = "0.23", default-features = false, features = ["ring", "std"] }
|
||||
cpal = { version = "0.15", optional = true }
|
||||
coreaudio-rs = { version = "0.11", optional = true }
|
||||
libc = "0.2"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
audio = ["cpal"]
|
||||
vpio = ["coreaudio-rs"]
|
||||
|
||||
[[bin]]
|
||||
name = "wzp-client"
|
||||
|
||||
@@ -3,12 +3,10 @@
|
||||
//! Both structs use 48 kHz, mono, i16 format to match the WarzonePhone codec
|
||||
//! pipeline. Frames are 960 samples (20 ms at 48 kHz).
|
||||
//!
|
||||
//! The cpal `Stream` type is not `Send`, so each struct spawns a dedicated OS
|
||||
//! thread that owns the stream. The public API exposes only `Send + Sync`
|
||||
//! channel handles.
|
||||
//! Audio callbacks are **lock-free**: they read/write directly to an `AudioRing`
|
||||
//! (atomic SPSC ring buffer). No Mutex, no channel, no allocation on the hot path.
|
||||
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::mpsc;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
@@ -16,6 +14,8 @@ use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
|
||||
use cpal::{SampleFormat, SampleRate, StreamConfig};
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::audio_ring::AudioRing;
|
||||
|
||||
/// Number of samples per 20 ms frame at 48 kHz mono.
|
||||
pub const FRAME_SAMPLES: usize = 960;
|
||||
|
||||
@@ -23,22 +23,24 @@ pub const FRAME_SAMPLES: usize = 960;
|
||||
// AudioCapture
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Captures microphone input and yields 960-sample PCM frames.
|
||||
/// Captures microphone input via CPAL and writes PCM into a lock-free ring buffer.
|
||||
///
|
||||
/// The cpal stream lives on a dedicated OS thread; this handle is `Send + Sync`.
|
||||
pub struct AudioCapture {
|
||||
rx: mpsc::Receiver<Vec<i16>>,
|
||||
ring: Arc<AudioRing>,
|
||||
running: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl AudioCapture {
|
||||
/// Create and start capturing from the default input device at 48 kHz mono.
|
||||
pub fn start() -> Result<Self, anyhow::Error> {
|
||||
let (tx, rx) = mpsc::sync_channel::<Vec<i16>>(64);
|
||||
let ring = Arc::new(AudioRing::new());
|
||||
let running = Arc::new(AtomicBool::new(true));
|
||||
let running_clone = running.clone();
|
||||
|
||||
let (init_tx, init_rx) = mpsc::sync_channel::<Result<(), String>>(1);
|
||||
let (init_tx, init_rx) = std::sync::mpsc::sync_channel::<Result<(), String>>(1);
|
||||
|
||||
let ring_cb = ring.clone();
|
||||
let running_clone = running.clone();
|
||||
|
||||
std::thread::Builder::new()
|
||||
.name("wzp-audio-capture".into())
|
||||
@@ -59,53 +61,51 @@ impl AudioCapture {
|
||||
|
||||
let use_f32 = !supports_i16_input(&device)?;
|
||||
|
||||
let buf = Arc::new(std::sync::Mutex::new(
|
||||
Vec::<i16>::with_capacity(FRAME_SAMPLES),
|
||||
));
|
||||
let err_cb = |e: cpal::StreamError| {
|
||||
warn!("input stream error: {e}");
|
||||
};
|
||||
|
||||
let logged_cb_size = Arc::new(AtomicBool::new(false));
|
||||
|
||||
let stream = if use_f32 {
|
||||
let buf = buf.clone();
|
||||
let tx = tx.clone();
|
||||
let ring = ring_cb.clone();
|
||||
let running = running_clone.clone();
|
||||
let logged = logged_cb_size.clone();
|
||||
device.build_input_stream(
|
||||
&config,
|
||||
move |data: &[f32], _: &cpal::InputCallbackInfo| {
|
||||
if !running.load(Ordering::Relaxed) {
|
||||
return;
|
||||
}
|
||||
let mut lock = buf.lock().unwrap();
|
||||
for &s in data {
|
||||
lock.push(f32_to_i16(s));
|
||||
if lock.len() == FRAME_SAMPLES {
|
||||
let frame = lock.drain(..).collect();
|
||||
let _ = tx.try_send(frame);
|
||||
if !logged.swap(true, Ordering::Relaxed) {
|
||||
eprintln!("[audio] capture callback: {} f32 samples", data.len());
|
||||
}
|
||||
let mut tmp = [0i16; FRAME_SAMPLES];
|
||||
for chunk in data.chunks(FRAME_SAMPLES) {
|
||||
let n = chunk.len();
|
||||
for i in 0..n {
|
||||
tmp[i] = f32_to_i16(chunk[i]);
|
||||
}
|
||||
ring.write(&tmp[..n]);
|
||||
}
|
||||
},
|
||||
err_cb,
|
||||
None,
|
||||
)?
|
||||
} else {
|
||||
let buf = buf.clone();
|
||||
let tx = tx.clone();
|
||||
let ring = ring_cb.clone();
|
||||
let running = running_clone.clone();
|
||||
let logged = logged_cb_size.clone();
|
||||
device.build_input_stream(
|
||||
&config,
|
||||
move |data: &[i16], _: &cpal::InputCallbackInfo| {
|
||||
if !running.load(Ordering::Relaxed) {
|
||||
return;
|
||||
}
|
||||
let mut lock = buf.lock().unwrap();
|
||||
for &s in data {
|
||||
lock.push(s);
|
||||
if lock.len() == FRAME_SAMPLES {
|
||||
let frame = lock.drain(..).collect();
|
||||
let _ = tx.try_send(frame);
|
||||
}
|
||||
if !logged.swap(true, Ordering::Relaxed) {
|
||||
eprintln!("[audio] capture callback: {} i16 samples", data.len());
|
||||
}
|
||||
ring.write(data);
|
||||
},
|
||||
err_cb,
|
||||
None,
|
||||
@@ -114,7 +114,6 @@ impl AudioCapture {
|
||||
|
||||
stream.play().context("failed to start input stream")?;
|
||||
|
||||
// Signal success to the caller before parking.
|
||||
let _ = init_tx.send(Ok(()));
|
||||
|
||||
// Keep stream alive until stopped.
|
||||
@@ -135,15 +134,12 @@ impl AudioCapture {
|
||||
.map_err(|_| anyhow!("capture thread exited before signaling"))?
|
||||
.map_err(|e| anyhow!("{e}"))?;
|
||||
|
||||
Ok(Self { rx, running })
|
||||
Ok(Self { ring, running })
|
||||
}
|
||||
|
||||
/// Read the next frame of 960 PCM samples (blocking until available).
|
||||
///
|
||||
/// Returns `None` when the stream has been stopped or the channel is
|
||||
/// disconnected.
|
||||
pub fn read_frame(&self) -> Option<Vec<i16>> {
|
||||
self.rx.recv().ok()
|
||||
/// Get a reference to the capture ring buffer for direct polling.
|
||||
pub fn ring(&self) -> &Arc<AudioRing> {
|
||||
&self.ring
|
||||
}
|
||||
|
||||
/// Stop capturing.
|
||||
@@ -152,26 +148,34 @@ impl AudioCapture {
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for AudioCapture {
|
||||
fn drop(&mut self) {
|
||||
self.stop();
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// AudioPlayback
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Plays PCM frames through the default output device at 48 kHz mono.
|
||||
/// Plays PCM through the default output device, reading from a lock-free ring buffer.
|
||||
///
|
||||
/// The cpal stream lives on a dedicated OS thread; this handle is `Send + Sync`.
|
||||
pub struct AudioPlayback {
|
||||
tx: mpsc::SyncSender<Vec<i16>>,
|
||||
ring: Arc<AudioRing>,
|
||||
running: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl AudioPlayback {
|
||||
/// Create and start playback on the default output device at 48 kHz mono.
|
||||
pub fn start() -> Result<Self, anyhow::Error> {
|
||||
let (tx, rx) = mpsc::sync_channel::<Vec<i16>>(64);
|
||||
let ring = Arc::new(AudioRing::new());
|
||||
let running = Arc::new(AtomicBool::new(true));
|
||||
let running_clone = running.clone();
|
||||
|
||||
let (init_tx, init_rx) = mpsc::sync_channel::<Result<(), String>>(1);
|
||||
let (init_tx, init_rx) = std::sync::mpsc::sync_channel::<Result<(), String>>(1);
|
||||
|
||||
let ring_cb = ring.clone();
|
||||
let running_clone = running.clone();
|
||||
|
||||
std::thread::Builder::new()
|
||||
.name("wzp-audio-playback".into())
|
||||
@@ -192,62 +196,40 @@ impl AudioPlayback {
|
||||
|
||||
let use_f32 = !supports_i16_output(&device)?;
|
||||
|
||||
// Shared ring of samples the cpal callback drains from.
|
||||
let ring = Arc::new(std::sync::Mutex::new(
|
||||
std::collections::VecDeque::<i16>::with_capacity(FRAME_SAMPLES * 8),
|
||||
));
|
||||
|
||||
// Background drainer: moves frames from the mpsc channel into the ring.
|
||||
{
|
||||
let ring = ring.clone();
|
||||
let running = running_clone.clone();
|
||||
std::thread::Builder::new()
|
||||
.name("wzp-playback-drain".into())
|
||||
.spawn(move || {
|
||||
while running.load(Ordering::Relaxed) {
|
||||
match rx.recv_timeout(std::time::Duration::from_millis(100)) {
|
||||
Ok(frame) => {
|
||||
let mut lock = ring.lock().unwrap();
|
||||
lock.extend(frame);
|
||||
while lock.len() > FRAME_SAMPLES * 16 {
|
||||
lock.pop_front();
|
||||
}
|
||||
}
|
||||
Err(mpsc::RecvTimeoutError::Timeout) => {}
|
||||
Err(mpsc::RecvTimeoutError::Disconnected) => break,
|
||||
}
|
||||
}
|
||||
})?;
|
||||
}
|
||||
|
||||
let err_cb = |e: cpal::StreamError| {
|
||||
warn!("output stream error: {e}");
|
||||
};
|
||||
|
||||
let stream = if use_f32 {
|
||||
let ring = ring.clone();
|
||||
let ring = ring_cb.clone();
|
||||
device.build_output_stream(
|
||||
&config,
|
||||
move |data: &mut [f32], _: &cpal::OutputCallbackInfo| {
|
||||
let mut lock = ring.lock().unwrap();
|
||||
for sample in data.iter_mut() {
|
||||
*sample = match lock.pop_front() {
|
||||
Some(s) => i16_to_f32(s),
|
||||
None => 0.0,
|
||||
};
|
||||
let mut tmp = [0i16; FRAME_SAMPLES];
|
||||
for chunk in data.chunks_mut(FRAME_SAMPLES) {
|
||||
let n = chunk.len();
|
||||
let read = ring.read(&mut tmp[..n]);
|
||||
for i in 0..read {
|
||||
chunk[i] = i16_to_f32(tmp[i]);
|
||||
}
|
||||
// Fill remainder with silence if ring underran
|
||||
for i in read..n {
|
||||
chunk[i] = 0.0;
|
||||
}
|
||||
}
|
||||
},
|
||||
err_cb,
|
||||
None,
|
||||
)?
|
||||
} else {
|
||||
let ring = ring.clone();
|
||||
let ring = ring_cb.clone();
|
||||
device.build_output_stream(
|
||||
&config,
|
||||
move |data: &mut [i16], _: &cpal::OutputCallbackInfo| {
|
||||
let mut lock = ring.lock().unwrap();
|
||||
for sample in data.iter_mut() {
|
||||
*sample = lock.pop_front().unwrap_or(0);
|
||||
let read = ring.read(data);
|
||||
// Fill remainder with silence if ring underran
|
||||
for sample in &mut data[read..] {
|
||||
*sample = 0;
|
||||
}
|
||||
},
|
||||
err_cb,
|
||||
@@ -257,7 +239,6 @@ impl AudioPlayback {
|
||||
|
||||
stream.play().context("failed to start output stream")?;
|
||||
|
||||
// Signal success to the caller before parking.
|
||||
let _ = init_tx.send(Ok(()));
|
||||
|
||||
// Keep stream alive until stopped.
|
||||
@@ -278,12 +259,12 @@ impl AudioPlayback {
|
||||
.map_err(|_| anyhow!("playback thread exited before signaling"))?
|
||||
.map_err(|e| anyhow!("{e}"))?;
|
||||
|
||||
Ok(Self { tx, running })
|
||||
Ok(Self { ring, running })
|
||||
}
|
||||
|
||||
/// Write a frame of PCM samples for playback.
|
||||
pub fn write_frame(&self, pcm: &[i16]) {
|
||||
let _ = self.tx.try_send(pcm.to_vec());
|
||||
/// Get a reference to the playout ring buffer for direct writing.
|
||||
pub fn ring(&self) -> &Arc<AudioRing> {
|
||||
&self.ring
|
||||
}
|
||||
|
||||
/// Stop playback.
|
||||
@@ -292,11 +273,16 @@ impl AudioPlayback {
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for AudioPlayback {
    fn drop(&mut self) {
        // Best-effort cleanup: delegate to `stop()` so dropping the handle
        // shuts playback down even if the caller never called `stop()`.
        self.stop();
    }
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Check if the input device supports i16 at 48 kHz mono.
|
||||
fn supports_i16_input(device: &cpal::Device) -> Result<bool, anyhow::Error> {
|
||||
let supported = device
|
||||
.supported_input_configs()
|
||||
@@ -313,7 +299,6 @@ fn supports_i16_input(device: &cpal::Device) -> Result<bool, anyhow::Error> {
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
/// Check if the output device supports i16 at 48 kHz mono.
|
||||
fn supports_i16_output(device: &cpal::Device) -> Result<bool, anyhow::Error> {
|
||||
let supported = device
|
||||
.supported_output_configs()
|
||||
|
||||
122
crates/wzp-client/src/audio_ring.rs
Normal file
122
crates/wzp-client/src/audio_ring.rs
Normal file
@@ -0,0 +1,122 @@
|
||||
//! Lock-free SPSC ring buffer — "Reader-Detects-Lap" architecture.
|
||||
//!
|
||||
//! SPSC invariant: the producer ONLY writes `write_pos`, the consumer
|
||||
//! ONLY writes `read_pos`. Neither thread touches the other's cursor.
|
||||
//!
|
||||
//! On overflow (writer laps the reader), the writer simply overwrites
|
||||
//! old buffer data. The reader detects the lap via `available() >
|
||||
//! RING_CAPACITY` and snaps its own `read_pos` forward.
|
||||
//!
|
||||
//! Capacity is a power of 2 for bitmask indexing (no modulo).
|
||||
|
||||
use std::cell::UnsafeCell;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
|
||||
|
||||
/// Ring buffer capacity — power of 2 for bitmask indexing.
/// 16384 samples = 341.3ms at 48kHz mono.
const RING_CAPACITY: usize = 16384; // 2^14
const RING_MASK: usize = RING_CAPACITY - 1;

/// Lock-free single-producer single-consumer ring buffer for i16 PCM samples.
///
/// Cursors increase monotonically and are reduced modulo `RING_CAPACITY`
/// (via `RING_MASK`) only when indexing, so `write_pos - read_pos` is the
/// true number of unread samples even across wrap-around.
pub struct AudioRing {
    /// Sample storage. `UnsafeCell` is required because the producer mutates
    /// slots through a shared `&self`; writing through a `*mut` obtained by
    /// casting `buf.as_ptr()` on a plain `Box<[i16]>` would be undefined
    /// behavior (mutation behind a shared reference without interior
    /// mutability).
    buf: Box<[UnsafeCell<i16>]>,
    /// Monotonically increasing write cursor. ONLY written by producer.
    write_pos: AtomicUsize,
    /// Monotonically increasing read cursor. ONLY written by consumer.
    read_pos: AtomicUsize,
    /// Incremented by reader when it detects it was lapped (overflow).
    overflow_count: AtomicU64,
    /// Incremented by reader when ring is empty (underrun).
    underrun_count: AtomicU64,
}

// SAFETY: AudioRing is SPSC — one thread writes (producer), one reads (consumer).
// The producer only writes write_pos. The consumer only writes read_pos.
// Neither thread writes the other's cursor. Buffer indices are derived from
// the owning thread's cursor, so the two threads do not touch the same slot
// while it is inside the valid [read_pos, write_pos) region. (A lapped reader
// may momentarily race the writer on stale slots; `read` snaps its cursor
// past them before copying data out.)
unsafe impl Send for AudioRing {}
unsafe impl Sync for AudioRing {}

impl AudioRing {
    /// Create an empty ring with all slots zeroed.
    pub fn new() -> Self {
        debug_assert!(RING_CAPACITY.is_power_of_two());
        let buf: Box<[UnsafeCell<i16>]> =
            (0..RING_CAPACITY).map(|_| UnsafeCell::new(0i16)).collect();
        Self {
            buf,
            write_pos: AtomicUsize::new(0),
            read_pos: AtomicUsize::new(0),
            overflow_count: AtomicU64::new(0),
            underrun_count: AtomicU64::new(0),
        }
    }

    /// Number of samples available to read (clamped to capacity).
    pub fn available(&self) -> usize {
        let w = self.write_pos.load(Ordering::Acquire);
        let r = self.read_pos.load(Ordering::Relaxed);
        w.wrapping_sub(r).min(RING_CAPACITY)
    }

    /// Write samples into the ring. Returns number of samples written.
    ///
    /// At most `RING_CAPACITY` samples are taken (from the front of
    /// `samples`). If the ring is full, old data is silently overwritten;
    /// the reader will detect the lap and self-correct. The writer NEVER
    /// touches `read_pos`.
    pub fn write(&self, samples: &[i16]) -> usize {
        let count = samples.len().min(RING_CAPACITY);
        let w = self.write_pos.load(Ordering::Relaxed);

        for (i, &s) in samples[..count].iter().enumerate() {
            // SAFETY: SPSC — only the producer derives indices from
            // write_pos, so no other thread writes this slot concurrently.
            unsafe { *self.buf[w.wrapping_add(i) & RING_MASK].get() = s };
        }

        // Release pairs with the Acquire load in `read`/`available`,
        // publishing the sample data before the new cursor becomes visible.
        self.write_pos
            .store(w.wrapping_add(count), Ordering::Release);
        count
    }

    /// Read samples from the ring into `out`. Returns number of samples read.
    ///
    /// If the writer has lapped the reader (overflow), `read_pos` is snapped
    /// forward to the oldest valid data before copying.
    pub fn read(&self, out: &mut [i16]) -> usize {
        let w = self.write_pos.load(Ordering::Acquire);
        let mut r = self.read_pos.load(Ordering::Relaxed);

        let mut avail = w.wrapping_sub(r);

        // Lap detection: writer has overwritten our unread data.
        if avail > RING_CAPACITY {
            r = w.wrapping_sub(RING_CAPACITY);
            avail = RING_CAPACITY;
            self.overflow_count.fetch_add(1, Ordering::Relaxed);
        }

        let count = out.len().min(avail);
        if count == 0 {
            // Only count it as an underrun when the ring is truly empty
            // (not when the caller passed an empty `out`).
            if w == r {
                self.underrun_count.fetch_add(1, Ordering::Relaxed);
            }
            return 0;
        }

        for (i, slot) in out[..count].iter_mut().enumerate() {
            // SAFETY: SPSC — these indices lie in [r, w), which the producer
            // will not rewrite until read_pos advances past them.
            *slot = unsafe { *self.buf[r.wrapping_add(i) & RING_MASK].get() };
        }

        self.read_pos
            .store(r.wrapping_add(count), Ordering::Release);
        count
    }

    /// Number of overflow events (reader was lapped by writer).
    pub fn overflow_count(&self) -> u64 {
        self.overflow_count.load(Ordering::Relaxed)
    }

    /// Number of underrun events (reader found empty buffer).
    pub fn underrun_count(&self) -> u64 {
        self.underrun_count.load(Ordering::Relaxed)
    }
}
|
||||
179
crates/wzp-client/src/audio_vpio.rs
Normal file
179
crates/wzp-client/src/audio_vpio.rs
Normal file
@@ -0,0 +1,179 @@
|
||||
//! macOS Voice Processing I/O — uses Apple's VoiceProcessingIO audio unit
|
||||
//! for hardware-accelerated echo cancellation, AGC, and noise suppression.
|
||||
//!
|
||||
//! VoiceProcessingIO is a combined input+output unit that knows what's going
|
||||
//! to the speaker, so it can cancel the echo from the mic signal internally.
|
||||
//! This is the same engine FaceTime and other Apple apps use.
|
||||
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Context;
|
||||
use coreaudio::audio_unit::audio_format::LinearPcmFlags;
|
||||
use coreaudio::audio_unit::render_callback::{self, data};
|
||||
use coreaudio::audio_unit::{AudioUnit, Element, IOType, SampleFormat, Scope, StreamFormat};
|
||||
use coreaudio::sys;
|
||||
use tracing::info;
|
||||
|
||||
use crate::audio_ring::AudioRing;
|
||||
|
||||
/// Number of samples per 20 ms frame at 48 kHz mono.
|
||||
pub const FRAME_SAMPLES: usize = 960;
|
||||
|
||||
/// Combined capture + playback via macOS VoiceProcessingIO.
|
||||
///
|
||||
/// The OS handles AEC internally — no manual far-end feeding needed.
|
||||
pub struct VpioAudio {
    /// Mic samples (post-AEC) — written by the input callback, drained by the app.
    capture_ring: Arc<AudioRing>,
    /// Samples queued for the speaker — written by the app, drained by the render callback.
    playout_ring: Arc<AudioRing>,
    /// Keeps the CoreAudio unit alive; dropping it tears the stream down.
    _audio_unit: AudioUnit,
    /// Cleared by `stop()`; polled by the capture callback to mute input.
    running: Arc<AtomicBool>,
}
|
||||
|
||||
impl VpioAudio {
    /// Start VoiceProcessingIO with AEC enabled.
    ///
    /// Creates the audio unit, enables mic input and speaker output, fixes
    /// the stream format to 48 kHz mono f32 (non-interleaved), installs the
    /// capture and render callbacks, then initializes and starts the unit.
    ///
    /// # Errors
    ///
    /// Returns an error if any CoreAudio step (creation, property setup,
    /// callback installation, initialize, start) fails; each is wrapped
    /// with a descriptive `context`.
    pub fn start() -> Result<Self, anyhow::Error> {
        let capture_ring = Arc::new(AudioRing::new());
        let playout_ring = Arc::new(AudioRing::new());
        let running = Arc::new(AtomicBool::new(true));

        let mut au = AudioUnit::new(IOType::VoiceProcessingIO)
            .context("failed to create VoiceProcessingIO audio unit")?;

        // Must uninitialize before configuring properties.
        au.uninitialize()
            .context("failed to uninitialize VPIO for configuration")?;

        // Enable input (mic) on Element::Input (bus 1).
        let enable: u32 = 1;
        au.set_property(
            sys::kAudioOutputUnitProperty_EnableIO,
            Scope::Input,
            Element::Input,
            Some(&enable),
        )
        .context("failed to enable VPIO input")?;

        // Output (speaker) is enabled by default on VPIO, but be explicit.
        au.set_property(
            sys::kAudioOutputUnitProperty_EnableIO,
            Scope::Output,
            Element::Output,
            Some(&enable),
        )
        .context("failed to enable VPIO output")?;

        // Configure stream format: 48kHz mono f32 non-interleaved
        let stream_format = StreamFormat {
            sample_rate: 48_000.0,
            sample_format: SampleFormat::F32,
            flags: LinearPcmFlags::IS_FLOAT
                | LinearPcmFlags::IS_PACKED
                | LinearPcmFlags::IS_NON_INTERLEAVED,
            channels: 1,
        };

        let asbd = stream_format.to_asbd();

        // Input: set format on Output scope of Input element
        // (= the format the AU delivers to us from the mic)
        au.set_property(
            sys::kAudioUnitProperty_StreamFormat,
            Scope::Output,
            Element::Input,
            Some(&asbd),
        )
        .context("failed to set input stream format")?;

        // Output: set format on Input scope of Output element
        // (= the format we feed to the AU for the speaker)
        au.set_property(
            sys::kAudioUnitProperty_StreamFormat,
            Scope::Input,
            Element::Output,
            Some(&asbd),
        )
        .context("failed to set output stream format")?;

        // Set up input callback (mic capture with AEC applied)
        let cap_ring = capture_ring.clone();
        let cap_running = running.clone();
        // One-shot flag so the format probe below logs exactly once.
        let logged = Arc::new(AtomicBool::new(false));
        au.set_input_callback(
            move |args: render_callback::Args<data::NonInterleaved<f32>>| {
                // After stop(): keep the callback alive but discard samples.
                if !cap_running.load(Ordering::Relaxed) {
                    return Ok(());
                }
                let mut buffers = args.data.channels();
                // Mono format — only the first channel carries data.
                if let Some(ch) = buffers.next() {
                    if !logged.swap(true, Ordering::Relaxed) {
                        // NOTE(review): leftover debug probe on stderr —
                        // consider routing through `tracing` instead.
                        eprintln!("[vpio] capture callback: {} f32 samples", ch.len());
                    }
                    // Convert f32 [-1, 1] → i16 in frame-sized chunks and
                    // push into the lock-free capture ring.
                    let mut tmp = [0i16; FRAME_SAMPLES];
                    for chunk in ch.chunks(FRAME_SAMPLES) {
                        let n = chunk.len();
                        for i in 0..n {
                            tmp[i] = (chunk[i].clamp(-1.0, 1.0) * i16::MAX as f32) as i16;
                        }
                        cap_ring.write(&tmp[..n]);
                    }
                }
                Ok(())
            },
        )
        .context("failed to set input callback")?;

        // Set up output callback (speaker playback — AEC uses this as reference)
        // NOTE(review): this callback does not check `running`, so playout
        // continues draining the ring after stop() — confirm that is intended.
        let play_ring = playout_ring.clone();
        au.set_render_callback(
            move |mut args: render_callback::Args<data::NonInterleaved<f32>>| {
                let mut buffers = args.data.channels_mut();
                if let Some(ch) = buffers.next() {
                    let mut tmp = [0i16; FRAME_SAMPLES];
                    for chunk in ch.chunks_mut(FRAME_SAMPLES) {
                        let n = chunk.len();
                        let read = play_ring.read(&mut tmp[..n]);
                        // Convert i16 → f32 for however many samples we got…
                        for i in 0..read {
                            chunk[i] = tmp[i] as f32 / i16::MAX as f32;
                        }
                        // …and pad the remainder with silence on underrun.
                        for i in read..n {
                            chunk[i] = 0.0;
                        }
                    }
                }
                Ok(())
            },
        )
        .context("failed to set render callback")?;

        au.initialize().context("failed to initialize VoiceProcessingIO")?;
        au.start().context("failed to start VoiceProcessingIO")?;

        info!("VoiceProcessingIO started (OS-level AEC enabled)");

        Ok(Self {
            capture_ring,
            playout_ring,
            _audio_unit: au,
            running,
        })
    }

    /// Ring the input callback writes AEC-processed mic samples into.
    pub fn capture_ring(&self) -> &Arc<AudioRing> {
        &self.capture_ring
    }

    /// Ring the render callback drains for speaker playback.
    pub fn playout_ring(&self) -> &Arc<AudioRing> {
        &self.playout_ring
    }

    /// Mute capture by clearing the running flag (callbacks stay installed;
    /// the audio unit itself is torn down when `self` is dropped).
    pub fn stop(&self) {
        self.running.store(false, Ordering::Relaxed);
    }
}
|
||||
|
||||
impl Drop for VpioAudio {
    fn drop(&mut self) {
        // Clear the running flag; the CoreAudio unit itself is released when
        // the `_audio_unit` field drops right after this.
        self.stop();
    }
}
|
||||
@@ -2,17 +2,21 @@
|
||||
//!
|
||||
//! Pipeline: mic → encode → FEC → encrypt → send / recv → decrypt → FEC → decode → speaker
|
||||
|
||||
use bytes::Bytes;
|
||||
use tracing::{debug, warn};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use bytes::Bytes;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use wzp_codec::{AutoGainControl, ComfortNoise, EchoCanceller, NoiseSupressor, SilenceDetector};
|
||||
use wzp_fec::{RaptorQFecDecoder, RaptorQFecEncoder};
|
||||
use wzp_proto::jitter::{JitterBuffer, PlayoutResult};
|
||||
use wzp_proto::packet::{MediaHeader, MediaPacket};
|
||||
use wzp_proto::packet::{MediaHeader, MediaPacket, MiniFrameContext};
|
||||
use wzp_proto::quality::AdaptiveQualityController;
|
||||
use wzp_proto::traits::{
|
||||
AudioDecoder, AudioEncoder, FecDecoder, FecEncoder,
|
||||
};
|
||||
use wzp_proto::QualityProfile;
|
||||
use wzp_proto::packet::QualityReport;
|
||||
use wzp_proto::{CodecId, QualityProfile};
|
||||
|
||||
/// Configuration for a call session.
|
||||
pub struct CallConfig {
|
||||
@@ -24,6 +28,28 @@ pub struct CallConfig {
|
||||
pub jitter_max: usize,
|
||||
/// Jitter buffer min depth before playout.
|
||||
pub jitter_min: usize,
|
||||
/// Enable silence suppression (default: true).
|
||||
pub suppression_enabled: bool,
|
||||
/// RMS threshold for silence detection (default: 100.0 for i16 PCM).
|
||||
pub silence_threshold_rms: f64,
|
||||
/// Hangover frames before suppression begins (default: 5 = 100ms at 20ms frames).
|
||||
pub silence_hangover_frames: u32,
|
||||
/// Comfort noise amplitude (default: 50).
|
||||
pub comfort_noise_level: i16,
|
||||
/// Enable ML-based noise suppression via RNNoise (default: true).
|
||||
pub noise_suppression: bool,
|
||||
/// Enable mini-frame header compression (default: true).
|
||||
/// When enabled, only every 50th frame carries a full 12-byte MediaHeader;
|
||||
/// intermediate frames use a compact 4-byte MiniHeader.
|
||||
pub mini_frames_enabled: bool,
|
||||
/// AEC far-end delay compensation in milliseconds (default: 40).
|
||||
/// Compensates for the round-trip audio latency from playout to mic capture.
|
||||
pub aec_delay_ms: u32,
|
||||
/// Enable adaptive jitter buffer (default: true).
|
||||
///
|
||||
/// When true, the jitter buffer target depth is automatically adjusted
|
||||
/// based on observed inter-arrival jitter (NetEq-inspired algorithm).
|
||||
pub adaptive_jitter: bool,
|
||||
}
|
||||
|
||||
impl Default for CallConfig {
|
||||
@@ -33,6 +59,138 @@ impl Default for CallConfig {
|
||||
jitter_target: 10,
|
||||
jitter_max: 250,
|
||||
jitter_min: 3, // 60ms — low latency start, still smooths jitter
|
||||
suppression_enabled: true,
|
||||
silence_threshold_rms: 100.0,
|
||||
silence_hangover_frames: 5,
|
||||
comfort_noise_level: 50,
|
||||
noise_suppression: true,
|
||||
mini_frames_enabled: true,
|
||||
adaptive_jitter: true,
|
||||
aec_delay_ms: 40,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CallConfig {
|
||||
/// Build a `CallConfig` tuned for the given quality profile.
|
||||
pub fn from_profile(profile: QualityProfile) -> Self {
|
||||
let (jitter_target, jitter_max, jitter_min) = if profile == QualityProfile::CATASTROPHIC {
|
||||
// Catastrophic: larger jitter buffer to absorb spikes
|
||||
(20, 500, 8)
|
||||
} else if profile == QualityProfile::DEGRADED {
|
||||
// Degraded: moderately deeper buffer
|
||||
(15, 350, 5)
|
||||
} else {
|
||||
// Good: low-latency defaults
|
||||
(10, 250, 3)
|
||||
};
|
||||
Self {
|
||||
profile,
|
||||
jitter_target,
|
||||
jitter_max,
|
||||
jitter_min,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Sliding-window quality adapter that reacts to relay `QualityReport`s.
|
||||
///
|
||||
/// Thresholds (per-report):
|
||||
/// - loss > 15% OR rtt > 200ms => CATASTROPHIC
|
||||
/// - loss > 5% OR rtt > 100ms => DEGRADED
|
||||
/// - otherwise => GOOD
|
||||
///
|
||||
/// Hysteresis: a profile switch is only recommended after the new profile
|
||||
/// has been the recommendation for 3 or more consecutive reports.
|
||||
pub struct QualityAdapter {
    /// Sliding window of the last N reports.
    /// NOTE(review): only the newest entry (`window.back()`) currently drives
    /// `recommended_profile`; the rest of the window is retained but unused —
    /// confirm whether averaging over it was intended.
    window: std::collections::VecDeque<QualityReport>,
    /// Maximum window size (entries beyond this are evicted oldest-first).
    max_window: usize,
    /// Number of consecutive reports recommending the same (non-current) profile.
    consecutive_same: u32,
    /// The profile that the last `consecutive_same` reports recommended.
    pending_profile: Option<QualityProfile>,
}

/// Number of consecutive reports required before accepting a switch.
const HYSTERESIS_COUNT: u32 = 3;
/// Default sliding window capacity.
const ADAPTER_WINDOW: usize = 10;
|
||||
|
||||
impl QualityAdapter {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
window: std::collections::VecDeque::with_capacity(ADAPTER_WINDOW),
|
||||
max_window: ADAPTER_WINDOW,
|
||||
consecutive_same: 0,
|
||||
pending_profile: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Record a new quality report from the relay.
|
||||
pub fn ingest(&mut self, report: &QualityReport) {
|
||||
if self.window.len() >= self.max_window {
|
||||
self.window.pop_front();
|
||||
}
|
||||
self.window.push_back(*report);
|
||||
}
|
||||
|
||||
/// Classify a single report into a recommended profile.
|
||||
fn classify(report: &QualityReport) -> QualityProfile {
|
||||
let loss = report.loss_percent();
|
||||
let rtt = report.rtt_ms();
|
||||
|
||||
if loss > 15.0 || rtt > 200 {
|
||||
QualityProfile::CATASTROPHIC
|
||||
} else if loss > 5.0 || rtt > 100 {
|
||||
QualityProfile::DEGRADED
|
||||
} else {
|
||||
QualityProfile::GOOD
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the best profile based on the most recent report in the window.
|
||||
pub fn recommended_profile(&self) -> QualityProfile {
|
||||
match self.window.back() {
|
||||
Some(report) => Self::classify(report),
|
||||
None => QualityProfile::GOOD,
|
||||
}
|
||||
}
|
||||
|
||||
/// Determine if a profile switch should happen, applying hysteresis.
|
||||
///
|
||||
/// Returns `Some(new_profile)` only when the recommendation has differed
|
||||
/// from `current` for at least `HYSTERESIS_COUNT` consecutive reports.
|
||||
pub fn should_switch(&mut self, current: &QualityProfile) -> Option<QualityProfile> {
|
||||
let recommended = self.recommended_profile();
|
||||
|
||||
if recommended == *current {
|
||||
// Conditions match current profile — reset pending state.
|
||||
self.consecutive_same = 0;
|
||||
self.pending_profile = None;
|
||||
return None;
|
||||
}
|
||||
|
||||
// Recommended differs from current.
|
||||
match self.pending_profile {
|
||||
Some(pending) if pending == recommended => {
|
||||
self.consecutive_same += 1;
|
||||
}
|
||||
_ => {
|
||||
// New or changed recommendation — restart counter.
|
||||
self.pending_profile = Some(recommended);
|
||||
self.consecutive_same = 1;
|
||||
}
|
||||
}
|
||||
|
||||
if self.consecutive_same >= HYSTERESIS_COUNT {
|
||||
self.consecutive_same = 0;
|
||||
self.pending_profile = None;
|
||||
Some(recommended)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -53,6 +211,28 @@ pub struct CallEncoder {
|
||||
frame_in_block: u8,
|
||||
/// Timestamp counter (ms).
|
||||
timestamp_ms: u32,
|
||||
/// Acoustic echo canceller (removes speaker echo from mic signal).
|
||||
aec: EchoCanceller,
|
||||
/// Automatic gain control (normalises mic level).
|
||||
agc: AutoGainControl,
|
||||
/// Silence detector for suppression.
|
||||
silence_detector: SilenceDetector,
|
||||
/// Whether silence suppression is enabled.
|
||||
suppression_enabled: bool,
|
||||
/// Total frames suppressed (telemetry).
|
||||
frames_suppressed: u64,
|
||||
/// Frames since last CN packet was sent.
|
||||
cn_counter: u32,
|
||||
/// Comfort noise amplitude level (stored for CN packet payload).
|
||||
cn_level: i16,
|
||||
/// ML-based noise suppressor (RNNoise).
|
||||
denoiser: NoiseSupressor,
|
||||
/// Mini-frame compression context (tracks last full header).
|
||||
mini_context: MiniFrameContext,
|
||||
/// Whether mini-frame header compression is enabled.
|
||||
mini_frames_enabled: bool,
|
||||
/// Frames encoded since the last full header was emitted.
|
||||
frames_since_full: u32,
|
||||
}
|
||||
|
||||
impl CallEncoder {
|
||||
@@ -65,6 +245,37 @@ impl CallEncoder {
|
||||
block_id: 0,
|
||||
frame_in_block: 0,
|
||||
timestamp_ms: 0,
|
||||
aec: EchoCanceller::with_delay(48000, 60, config.aec_delay_ms),
|
||||
agc: AutoGainControl::new(),
|
||||
silence_detector: SilenceDetector::new(
|
||||
config.silence_threshold_rms,
|
||||
config.silence_hangover_frames,
|
||||
),
|
||||
suppression_enabled: config.suppression_enabled,
|
||||
frames_suppressed: 0,
|
||||
cn_counter: 0,
|
||||
cn_level: config.comfort_noise_level,
|
||||
denoiser: {
|
||||
let mut d = NoiseSupressor::new();
|
||||
d.set_enabled(config.noise_suppression);
|
||||
d
|
||||
},
|
||||
mini_context: MiniFrameContext::default(),
|
||||
mini_frames_enabled: config.mini_frames_enabled,
|
||||
frames_since_full: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Serialize a `MediaPacket` for transmission, applying mini-frame
|
||||
/// compression when enabled.
|
||||
///
|
||||
/// Returns compact wire bytes: either `[FRAME_TYPE_FULL][MediaHeader][payload]`
|
||||
/// or `[FRAME_TYPE_MINI][MiniHeader][payload]`.
|
||||
pub fn serialize_compact(&mut self, packet: &MediaPacket) -> Bytes {
|
||||
if self.mini_frames_enabled {
|
||||
packet.encode_compact(&mut self.mini_context, &mut self.frames_since_full)
|
||||
} else {
|
||||
packet.to_bytes()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -73,6 +284,61 @@ impl CallEncoder {
|
||||
/// Input: 48kHz mono PCM, frame size depends on profile (960 for 20ms, 1920 for 40ms).
|
||||
/// Output: one or more MediaPackets to send.
|
||||
pub fn encode_frame(&mut self, pcm: &[i16]) -> Result<Vec<MediaPacket>, anyhow::Error> {
|
||||
// Copy PCM into a mutable buffer for the processing pipeline.
|
||||
let mut pcm_buf = pcm.to_vec();
|
||||
|
||||
// Step 1: Echo cancellation (far-end reference must have been fed already).
|
||||
self.aec.process_frame(&mut pcm_buf);
|
||||
|
||||
// Step 2: Automatic gain control (normalise mic level).
|
||||
self.agc.process_frame(&mut pcm_buf);
|
||||
|
||||
// Step 3: Noise suppression (RNNoise).
|
||||
if self.denoiser.is_enabled() {
|
||||
self.denoiser.process(&mut pcm_buf);
|
||||
}
|
||||
|
||||
let pcm = &pcm_buf[..];
|
||||
|
||||
// Silence suppression: skip encoding silent frames, periodically send CN.
|
||||
if self.suppression_enabled && self.silence_detector.is_silent(pcm) {
|
||||
self.frames_suppressed += 1;
|
||||
self.cn_counter += 1;
|
||||
|
||||
// Advance timestamp even for suppressed frames.
|
||||
self.timestamp_ms = self
|
||||
.timestamp_ms
|
||||
.wrapping_add(self.profile.frame_duration_ms as u32);
|
||||
|
||||
// Every 10 frames (~200ms), send a comfort noise packet.
|
||||
if self.cn_counter % 10 == 0 {
|
||||
let cn_pkt = MediaPacket {
|
||||
header: MediaHeader {
|
||||
version: 0,
|
||||
is_repair: false,
|
||||
codec_id: CodecId::ComfortNoise,
|
||||
has_quality_report: false,
|
||||
fec_ratio_encoded: 0,
|
||||
seq: self.seq,
|
||||
timestamp: self.timestamp_ms,
|
||||
fec_block: self.block_id,
|
||||
fec_symbol: 0,
|
||||
reserved: 0,
|
||||
csrc_count: 0,
|
||||
},
|
||||
payload: Bytes::from(vec![self.cn_level as u8]),
|
||||
quality_report: None,
|
||||
};
|
||||
self.seq = self.seq.wrapping_add(1);
|
||||
return Ok(vec![cn_pkt]);
|
||||
}
|
||||
|
||||
return Ok(vec![]);
|
||||
}
|
||||
|
||||
// Not silent — reset CN counter and proceed with normal encoding.
|
||||
self.cn_counter = 0;
|
||||
|
||||
// Encode audio
|
||||
let mut encoded = vec![0u8; self.audio_enc.max_frame_bytes()];
|
||||
let enc_len = self.audio_enc.encode(pcm, &mut encoded)?;
|
||||
@@ -150,6 +416,24 @@ impl CallEncoder {
|
||||
self.frame_in_block = 0;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
    /// Feed decoded playout audio as the echo reference signal.
    ///
    /// Must be called with each decoded frame BEFORE the corresponding
    /// microphone frame is processed.
    pub fn feed_aec_farend(&mut self, farend: &[i16]) {
        // Ordering matters: `encode_frame` runs the AEC against whatever
        // far-end reference has been fed so far.
        self.aec.feed_farend(farend);
    }
|
||||
|
||||
    /// Enable or disable acoustic echo cancellation.
    ///
    /// Delegates to the underlying `EchoCanceller`'s enable flag.
    pub fn set_aec_enabled(&mut self, enabled: bool) {
        self.aec.set_enabled(enabled);
    }
|
||||
|
||||
    /// Enable or disable automatic gain control.
    ///
    /// Delegates to the underlying `AutoGainControl`'s enable flag.
    pub fn set_agc_enabled(&mut self, enabled: bool) {
        self.agc.set_enabled(enabled);
    }
|
||||
}
|
||||
|
||||
/// Manages the recv/decode side of a call.
|
||||
@@ -164,19 +448,42 @@ pub struct CallDecoder {
|
||||
pub quality: AdaptiveQualityController,
|
||||
/// Current profile.
|
||||
profile: QualityProfile,
|
||||
/// Comfort noise generator for filling silent gaps.
|
||||
comfort_noise: ComfortNoise,
|
||||
/// Whether the last decoded frame was comfort noise.
|
||||
last_was_cn: bool,
|
||||
/// Mini-frame decompression context (tracks last full header baseline).
|
||||
mini_context: MiniFrameContext,
|
||||
}
|
||||
|
||||
impl CallDecoder {
|
||||
pub fn new(config: &CallConfig) -> Self {
|
||||
let jitter = if config.adaptive_jitter {
|
||||
JitterBuffer::new_adaptive(config.jitter_min, config.jitter_max)
|
||||
} else {
|
||||
JitterBuffer::new(config.jitter_target, config.jitter_max, config.jitter_min)
|
||||
};
|
||||
Self {
|
||||
audio_dec: wzp_codec::create_decoder(config.profile),
|
||||
fec_dec: wzp_fec::create_decoder(&config.profile),
|
||||
jitter: JitterBuffer::new(config.jitter_target, config.jitter_max, config.jitter_min),
|
||||
jitter,
|
||||
quality: AdaptiveQualityController::new(),
|
||||
profile: config.profile,
|
||||
comfort_noise: ComfortNoise::new(50),
|
||||
last_was_cn: false,
|
||||
mini_context: MiniFrameContext::default(),
|
||||
}
|
||||
}
|
||||
|
||||
    /// Deserialize a compact wire-format buffer into a `MediaPacket`,
    /// auto-detecting full vs mini headers.
    ///
    /// Returns `None` on malformed data or if a mini-frame arrives before
    /// any full header baseline has been established.
    ///
    /// Counterpart of `CallEncoder::serialize_compact`; the decoder-side
    /// `mini_context` tracks the last full header as the mini-frame baseline.
    pub fn deserialize_compact(&mut self, buf: &[u8]) -> Option<MediaPacket> {
        MediaPacket::decode_compact(buf, &mut self.mini_context)
    }
|
||||
|
||||
/// Feed a received media packet into the decode pipeline.
|
||||
pub fn ingest(&mut self, packet: MediaPacket) {
|
||||
// Feed to FEC decoder
|
||||
@@ -193,31 +500,98 @@ impl CallDecoder {
|
||||
}
|
||||
}
|
||||
|
||||
    /// Switch the decoder to match an incoming packet's codec if it differs
    /// from the current profile. This enables cross-codec interop (e.g. one
    /// client sends Opus, the other sends Codec2).
    fn switch_decoder_if_needed(&mut self, incoming_codec: CodecId) {
        // ComfortNoise packets are codec-agnostic filler and never trigger
        // a decoder switch.
        if incoming_codec == self.profile.codec || incoming_codec == CodecId::ComfortNoise {
            return;
        }
        let new_profile = Self::profile_for_codec(incoming_codec);
        info!(
            from = ?self.profile.codec,
            to = ?incoming_codec,
            "decoder switching codec to match incoming packet"
        );
        // Switch the audio decoder first; if it rejects the new profile we
        // bail out without touching fec_dec/profile, keeping state consistent.
        if let Err(e) = self.audio_dec.set_profile(new_profile) {
            warn!("failed to switch decoder profile: {e}");
            return;
        }
        self.fec_dec = wzp_fec::create_decoder(&new_profile);
        self.profile = new_profile;
    }
|
||||
|
||||
    /// Map a `CodecId` to a reasonable `QualityProfile` for decoding.
    ///
    /// Codecs with a matching named preset use it; the two intermediate
    /// codecs get inline profiles.
    fn profile_for_codec(codec: CodecId) -> QualityProfile {
        match codec {
            CodecId::Opus24k => QualityProfile::GOOD,
            // NOTE(review): inline fec_ratio/frame values for Opus16k and
            // Codec2_3200 — confirm they match the preset tables in wzp_proto.
            CodecId::Opus16k => QualityProfile {
                codec: CodecId::Opus16k,
                fec_ratio: 0.3,
                frame_duration_ms: 20,
                frames_per_block: 5,
            },
            CodecId::Opus6k => QualityProfile::DEGRADED,
            CodecId::Codec2_3200 => QualityProfile {
                codec: CodecId::Codec2_3200,
                fec_ratio: 0.5,
                frame_duration_ms: 20,
                frames_per_block: 5,
            },
            CodecId::Codec2_1200 => QualityProfile::CATASTROPHIC,
            // CN never reaches the audio decoder (see decode path), so the
            // mapping here is a harmless placeholder.
            CodecId::ComfortNoise => QualityProfile::GOOD,
        }
    }
|
||||
|
||||
/// Decode the next audio frame from the jitter buffer.
|
||||
///
|
||||
/// Returns PCM samples (48kHz mono) or None if not ready.
|
||||
pub fn decode_next(&mut self, pcm: &mut [i16]) -> Option<usize> {
|
||||
match self.jitter.pop() {
|
||||
PlayoutResult::Packet(pkt) => {
|
||||
match self.audio_dec.decode(&pkt.payload, pcm) {
|
||||
// Comfort noise packet: generate CN instead of decoding audio.
|
||||
if pkt.header.codec_id == CodecId::ComfortNoise {
|
||||
self.comfort_noise.generate(pcm);
|
||||
self.last_was_cn = true;
|
||||
self.jitter.record_decode();
|
||||
return Some(pcm.len());
|
||||
}
|
||||
|
||||
// Auto-switch decoder if incoming codec differs from current.
|
||||
self.switch_decoder_if_needed(pkt.header.codec_id);
|
||||
|
||||
self.last_was_cn = false;
|
||||
let result = match self.audio_dec.decode(&pkt.payload, pcm) {
|
||||
Ok(n) => Some(n),
|
||||
Err(e) => {
|
||||
warn!("decode error: {e}, using PLC");
|
||||
self.audio_dec.decode_lost(pcm).ok()
|
||||
}
|
||||
};
|
||||
if result.is_some() {
|
||||
self.jitter.record_decode();
|
||||
}
|
||||
result
|
||||
}
|
||||
PlayoutResult::Missing { seq } => {
|
||||
// Only generate PLC if there are still packets buffered ahead.
|
||||
// Otherwise we've drained everything — return None to stop.
|
||||
if self.jitter.depth() > 0 {
|
||||
debug!(seq, "packet loss, generating PLC");
|
||||
self.audio_dec.decode_lost(pcm).ok()
|
||||
let result = self.audio_dec.decode_lost(pcm).ok();
|
||||
if result.is_some() {
|
||||
self.jitter.record_decode();
|
||||
}
|
||||
result
|
||||
} else {
|
||||
self.jitter.record_underrun();
|
||||
None
|
||||
}
|
||||
}
|
||||
PlayoutResult::NotReady => None,
|
||||
PlayoutResult::NotReady => {
|
||||
self.jitter.record_underrun();
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -227,8 +601,54 @@ impl CallDecoder {
|
||||
}
|
||||
|
||||
/// Get jitter buffer statistics.
|
||||
pub fn jitter_stats(&self) -> wzp_proto::jitter::JitterStats {
|
||||
self.jitter.stats().clone()
|
||||
pub fn stats(&self) -> &wzp_proto::jitter::JitterStats {
|
||||
self.jitter.stats()
|
||||
}
|
||||
|
||||
/// Reset jitter buffer statistics counters.
|
||||
pub fn reset_stats(&mut self) {
|
||||
self.jitter.reset_stats();
|
||||
}
|
||||
}
|
||||
|
||||
/// Periodic telemetry logger for jitter buffer statistics.
///
/// Call `maybe_log` on each decode tick; it will emit a `tracing::info!` event
/// no more frequently than the configured interval.
pub struct JitterTelemetry {
    /// Minimum time between emitted log events.
    interval: Duration,
    /// Instant of the most recent emitted event (initialized to construction
    /// time, so the first report happens one full interval after `new`).
    last_report: Instant,
}
|
||||
|
||||
impl JitterTelemetry {
|
||||
/// Create a new telemetry logger that reports at most once per `interval_secs`.
|
||||
pub fn new(interval_secs: u64) -> Self {
|
||||
Self {
|
||||
interval: Duration::from_secs(interval_secs),
|
||||
last_report: Instant::now(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Log jitter statistics if the interval has elapsed. Returns `true` when a
|
||||
/// log line was emitted.
|
||||
pub fn maybe_log(&mut self, stats: &wzp_proto::jitter::JitterStats) -> bool {
|
||||
let now = Instant::now();
|
||||
if now.duration_since(self.last_report) >= self.interval {
|
||||
info!(
|
||||
buffer_depth = stats.current_depth,
|
||||
underruns = stats.underruns,
|
||||
overruns = stats.overruns,
|
||||
late_packets = stats.packets_late,
|
||||
total_received = stats.packets_received,
|
||||
total_decoded = stats.total_decoded,
|
||||
max_depth_seen = stats.max_depth_seen,
|
||||
"jitter buffer telemetry"
|
||||
);
|
||||
self.last_report = now;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -301,4 +721,279 @@ mod tests {
|
||||
let mut pcm = vec![0i16; 960];
|
||||
assert!(dec.decode_next(&mut pcm).is_none());
|
||||
}
|
||||
|
||||
// ---- QualityAdapter tests ----
|
||||
|
||||
/// Helper: build a QualityReport from human-readable loss% and RTT ms.
fn make_report(loss_pct_f: f32, rtt_ms: u16) -> QualityReport {
    QualityReport {
        // Wire format packs loss as a u8 fraction of 255 == 100%.
        loss_pct: (loss_pct_f / 100.0 * 255.0) as u8,
        // RTT is carried on the wire in 4 ms units.
        rtt_4ms: (rtt_ms / 4) as u8,
        jitter_ms: 10,
        bitrate_cap_kbps: 200,
    }
}
|
||||
|
||||
/// Steady low-loss / low-RTT reports keep the recommendation at GOOD and
/// never trigger a switch away from it.
#[test]
fn good_conditions_stays_good() {
    let mut adapter = QualityAdapter::new();
    let good = make_report(1.0, 40);
    for _ in 0..10 {
        adapter.ingest(&good);
    }
    assert_eq!(adapter.recommended_profile(), QualityProfile::GOOD);

    // Already on GOOD, so no switch should ever be suggested.
    let current = QualityProfile::GOOD;
    for _ in 0..10 {
        adapter.ingest(&good);
        assert!(adapter.should_switch(&current).is_none());
    }
}
|
||||
|
||||
/// Sustained 8% loss should step the profile down to DEGRADED once the
/// hysteresis window (3 consecutive reports) is satisfied.
#[test]
fn high_loss_degrades() {
    let mut adapter = QualityAdapter::new();
    // 8% loss, low RTT => DEGRADED
    let degraded = make_report(8.0, 40);
    let mut current = QualityProfile::GOOD;

    // Feed 3 consecutive degraded reports to pass hysteresis
    for _ in 0..3 {
        adapter.ingest(&degraded);
        if let Some(new) = adapter.should_switch(&current) {
            current = new;
        }
    }
    assert_eq!(current, QualityProfile::DEGRADED);
}
|
||||
|
||||
/// Either very high loss (20%) or very high RTT (>200 ms) alone should drive
/// the profile to CATASTROPHIC after the hysteresis window.
#[test]
fn catastrophic_conditions() {
    let mut adapter = QualityAdapter::new();
    // 20% loss => CATASTROPHIC
    let terrible = make_report(20.0, 50);
    let mut current = QualityProfile::GOOD;

    for _ in 0..3 {
        adapter.ingest(&terrible);
        if let Some(new) = adapter.should_switch(&current) {
            current = new;
        }
    }
    assert_eq!(current, QualityProfile::CATASTROPHIC);

    // Also test via high RTT alone (250ms > 200ms threshold)
    let mut adapter2 = QualityAdapter::new();
    let high_rtt = make_report(1.0, 252); // rtt_4ms rounds to 63 => 252ms
    let mut current2 = QualityProfile::GOOD;

    for _ in 0..3 {
        adapter2.ingest(&high_rtt);
        if let Some(new) = adapter2.should_switch(&current2) {
            current2 = new;
        }
    }
    assert_eq!(current2, QualityProfile::CATASTROPHIC);
}
|
||||
|
||||
/// Alternating good/bad reports must never produce a switch: the adapter
/// requires consecutive same-recommendation reports before acting.
#[test]
fn hysteresis_prevents_flapping() {
    let mut adapter = QualityAdapter::new();
    let good = make_report(1.0, 40);
    let bad = make_report(8.0, 40); // DEGRADED
    let current = QualityProfile::GOOD;

    // Alternate good/bad — should never trigger a switch because
    // we never get 3 consecutive same-recommendation reports.
    for _ in 0..20 {
        adapter.ingest(&bad);
        assert!(adapter.should_switch(&current).is_none());
        adapter.ingest(&good);
        assert!(adapter.should_switch(&current).is_none());
    }
    assert_eq!(current, QualityProfile::GOOD);
}
|
||||
|
||||
/// After being driven to CATASTROPHIC, three consecutive good reports should
/// recover the profile back to GOOD.
#[test]
fn recovery_to_good() {
    let mut adapter = QualityAdapter::new();
    let bad = make_report(20.0, 50);
    let good = make_report(1.0, 40);

    // Drive to CATASTROPHIC first
    let mut current = QualityProfile::GOOD;
    for _ in 0..3 {
        adapter.ingest(&bad);
        if let Some(new) = adapter.should_switch(&current) {
            current = new;
        }
    }
    assert_eq!(current, QualityProfile::CATASTROPHIC);

    // Now feed good reports — should recover to GOOD after 3 consecutive
    for _ in 0..3 {
        adapter.ingest(&good);
        if let Some(new) = adapter.should_switch(&current) {
            current = new;
        }
    }
    assert_eq!(current, QualityProfile::GOOD);
}
|
||||
|
||||
/// `CallConfig::from_profile` should scale jitter-buffer bounds up as the
/// profile degrades (GOOD < DEGRADED < CATASTROPHIC).
#[test]
fn call_config_from_profile() {
    let good = CallConfig::from_profile(QualityProfile::GOOD);
    assert_eq!(good.profile, QualityProfile::GOOD);
    assert_eq!(good.jitter_min, 3);

    let degraded = CallConfig::from_profile(QualityProfile::DEGRADED);
    assert_eq!(degraded.profile, QualityProfile::DEGRADED);
    assert!(degraded.jitter_target > good.jitter_target);

    let catastrophic = CallConfig::from_profile(QualityProfile::CATASTROPHIC);
    assert_eq!(catastrophic.profile, QualityProfile::CATASTROPHIC);
    assert!(catastrophic.jitter_max > degraded.jitter_max);
}
|
||||
|
||||
// ---- JitterStats telemetry tests ----
|
||||
|
||||
/// Helper: build a minimal Opus24k source packet whose timestamp/FEC symbol
/// are derived from `seq`, with a 60-byte zero payload.
fn make_test_packet(seq: u16) -> MediaPacket {
    MediaPacket {
        header: MediaHeader {
            version: 0,
            is_repair: false,
            codec_id: CodecId::Opus24k,
            has_quality_report: false,
            fec_ratio_encoded: 0,
            // 20 ms per frame => timestamp advances by 20 per sequence number.
            seq,
            timestamp: seq as u32 * 20,
            fec_block: 0,
            fec_symbol: seq as u8,
            reserved: 0,
            csrc_count: 0,
        },
        payload: Bytes::from(vec![0u8; 60]),
        quality_report: None,
    }
}
|
||||
|
||||
/// Ingesting packets should be reflected in received / depth / max-depth
/// counters of the jitter stats.
#[test]
fn stats_track_ingestion() {
    let config = CallConfig::default();
    let mut dec = CallDecoder::new(&config);

    for i in 0..5u16 {
        dec.ingest(make_test_packet(i));
    }

    let stats = dec.stats();
    assert_eq!(stats.packets_received, 5);
    assert_eq!(stats.current_depth, 5);
    assert_eq!(stats.max_depth_seen, 5);
}
|
||||
|
||||
/// Calling `decode_next` on an empty buffer should count one underrun per call.
#[test]
fn stats_track_underruns() {
    let config = CallConfig::default();
    let mut dec = CallDecoder::new(&config);

    // Empty buffer — decode_next should record underruns
    let mut pcm = vec![0i16; 960];
    dec.decode_next(&mut pcm);
    dec.decode_next(&mut pcm);
    dec.decode_next(&mut pcm);

    assert_eq!(dec.stats().underruns, 3);
}
|
||||
|
||||
/// `reset_stats` must zero every counter, both after packet ingestion and
/// after underruns.
#[test]
fn stats_reset() {
    let config = CallConfig::default();
    let mut dec = CallDecoder::new(&config);

    // Generate some stats: ingest packets and trigger underruns on empty buffer
    for i in 0..3u16 {
        dec.ingest(make_test_packet(i));
    }
    // Also call decode on empty decoder to get underruns
    let config2 = CallConfig::default();
    let mut dec2 = CallDecoder::new(&config2);
    let mut pcm = vec![0i16; 960];
    dec2.decode_next(&mut pcm); // underrun — nothing in buffer

    assert!(dec.stats().packets_received > 0);
    assert!(dec2.stats().underruns > 0);

    // Test reset on the decoder with ingested packets
    dec.reset_stats();
    let stats = dec.stats();
    assert_eq!(stats.packets_received, 0);
    assert_eq!(stats.underruns, 0);
    assert_eq!(stats.overruns, 0);
    assert_eq!(stats.total_decoded, 0);
    assert_eq!(stats.packets_late, 0);
    assert_eq!(stats.max_depth_seen, 0);

    // Test reset on the decoder with underruns
    dec2.reset_stats();
    assert_eq!(dec2.stats().underruns, 0);
}
|
||||
|
||||
/// A freshly constructed telemetry logger must not log before its interval
/// has elapsed.
#[test]
fn telemetry_respects_interval() {
    use wzp_proto::jitter::JitterStats;

    let mut telemetry = JitterTelemetry::new(60); // 60-second interval
    let stats = JitterStats::default();

    // First call right after creation — should not log because no time has passed
    // (the interval hasn't elapsed since construction)
    let logged = telemetry.maybe_log(&stats);
    assert!(!logged, "should not log before interval elapses");
}
|
||||
|
||||
/// With suppression enabled, feeding 20 silent frames should emit fewer than
/// 20 packets, include at least one 1-byte comfort-noise packet, and bump the
/// encoder's suppressed-frame counter.
#[test]
fn silence_suppression_skips_silent_frames() {
    let config = CallConfig {
        suppression_enabled: true,
        silence_threshold_rms: 100.0,
        silence_hangover_frames: 5,
        comfort_noise_level: 50,
        ..Default::default()
    };
    let mut enc = CallEncoder::new(&config);

    let silence = vec![0i16; 960];
    let mut total_packets = 0;
    let mut cn_packets = 0;

    for _ in 0..20 {
        let packets = enc.encode_frame(&silence).unwrap();
        for p in &packets {
            if p.header.codec_id == CodecId::ComfortNoise {
                cn_packets += 1;
                // CN payload should be a single byte with the noise level.
                assert_eq!(p.payload.len(), 1);
            }
        }
        total_packets += packets.len();
    }

    // First 5 frames are hangover (not suppressed) => 5 normal source packets
    // (plus potential repair packets from FEC block completion).
    // Remaining 15 frames are suppressed; CN every 10 frames => 1 CN packet
    // (cn_counter hits 10 on the 10th suppressed frame).
    assert!(
        total_packets < 20,
        "suppression should reduce packet count, got {total_packets}"
    );
    assert!(
        cn_packets >= 1,
        "should have at least one CN packet, got {cn_packets}"
    );
    assert!(
        enc.frames_suppressed > 0,
        "frames_suppressed should be > 0"
    );
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
293
crates/wzp-client/src/drift_test.rs
Normal file
293
crates/wzp-client/src/drift_test.rs
Normal file
@@ -0,0 +1,293 @@
|
||||
//! Automated clock-drift measurement tool.
|
||||
//!
|
||||
//! Sends N seconds of a known 440 Hz tone through the transport, records
|
||||
//! received frame timestamps on the other side, and compares actual received
|
||||
//! duration vs expected duration to quantify timing drift and packet loss.
|
||||
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use tracing::info;
|
||||
|
||||
use wzp_proto::MediaTransport;
|
||||
|
||||
use crate::call::{CallConfig, CallDecoder, CallEncoder};
|
||||
|
||||
const FRAME_SAMPLES: usize = 960; // 20ms @ 48kHz
const SAMPLE_RATE: u32 = 48_000; // fixed audio sample rate (Hz) used throughout this module
|
||||
|
||||
/// Configuration for a drift measurement run.
#[derive(Debug, Clone)]
pub struct DriftTestConfig {
    /// How many seconds of tone to send.
    pub duration_secs: u32,
    /// Frequency of the test tone (Hz).
    pub tone_freq_hz: f32,
}

impl Default for DriftTestConfig {
    /// Defaults to 10 seconds of a 440 Hz (concert-pitch A) tone.
    fn default() -> Self {
        DriftTestConfig {
            duration_secs: 10,
            tone_freq_hz: 440.0,
        }
    }
}
|
||||
|
||||
/// Results from a drift measurement run.
#[derive(Debug, Clone)]
pub struct DriftResult {
    /// Expected duration in milliseconds (`duration_secs * 1000`).
    pub expected_duration_ms: u64,
    /// Actual measured duration in milliseconds (last_recv - first_recv).
    pub actual_duration_ms: u64,
    /// Drift: `actual - expected` (positive = receiver clock ran slow / packets delayed).
    pub drift_ms: i64,
    /// Drift as a percentage of expected duration.
    pub drift_pct: f64,
    /// Total frames sent by the sender.
    pub frames_sent: u64,
    /// Total frames successfully received and decoded.
    pub frames_received: u64,
    /// Packet loss percentage: `(1 - frames_received / frames_sent) * 100`.
    pub loss_pct: f64,
}

impl DriftResult {
    /// Compute a `DriftResult` from raw counters and timestamps.
    ///
    /// Both percentage fields fall back to 0.0 when their denominator
    /// (`expected_duration_ms` / `frames_sent`) is zero.
    pub fn compute(
        expected_duration_ms: u64,
        actual_duration_ms: u64,
        frames_sent: u64,
        frames_received: u64,
    ) -> Self {
        let drift_ms = actual_duration_ms as i64 - expected_duration_ms as i64;
        let drift_pct = match expected_duration_ms {
            0 => 0.0,
            expected => drift_ms as f64 / expected as f64 * 100.0,
        };
        let loss_pct = match frames_sent {
            0 => 0.0,
            sent => (1.0 - frames_received as f64 / sent as f64) * 100.0,
        };
        DriftResult {
            expected_duration_ms,
            actual_duration_ms,
            drift_ms,
            drift_pct,
            frames_sent,
            frames_received,
            loss_pct,
        }
    }
}
|
||||
|
||||
/// Generate a sine wave frame at a given frequency.
|
||||
fn sine_frame(freq_hz: f32, frame_offset: u64) -> Vec<i16> {
|
||||
let start = frame_offset * FRAME_SAMPLES as u64;
|
||||
(0..FRAME_SAMPLES)
|
||||
.map(|i| {
|
||||
let t = (start + i as u64) as f32 / SAMPLE_RATE as f32;
|
||||
(f32::sin(2.0 * std::f32::consts::PI * freq_hz * t) * 16000.0) as i16
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Run the drift measurement test.
///
/// 1. Sends `duration_secs` of tone, one encoded frame every 20 ms, polling
///    for received packets in a short window after each send (single
///    interleaved loop — no separate tasks are spawned).
/// 2. Counts decoded frames and tracks first/last receive timestamps.
/// 3. After the sender finishes, drains trailing packets for 2 seconds.
/// 4. Computes and returns the `DriftResult`.
pub async fn run_drift_test(
    transport: &(dyn MediaTransport + Send + Sync),
    config: &DriftTestConfig,
) -> anyhow::Result<DriftResult> {
    let call_config = CallConfig::default();
    let mut encoder = CallEncoder::new(&call_config);
    let mut decoder = CallDecoder::new(&call_config);

    let total_frames: u64 = config.duration_secs as u64 * 50; // 50 frames/s at 20 ms
    let frame_duration = Duration::from_millis(20);
    let mut pcm_buf = vec![0i16; FRAME_SAMPLES];

    // Raw counters/timestamps folded into the final DriftResult.
    let mut frames_sent: u64 = 0;
    let mut frames_received: u64 = 0;
    let mut first_recv_time: Option<Instant> = None;
    let mut last_recv_time: Option<Instant> = None;

    info!(
        duration_secs = config.duration_secs,
        tone_hz = config.tone_freq_hz,
        total_frames = total_frames,
        "starting drift measurement"
    );

    let start = Instant::now();

    // Send + interleaved receive loop (same pattern as echo_test)
    for frame_idx in 0..total_frames {
        // --- send ---
        let pcm = sine_frame(config.tone_freq_hz, frame_idx);
        let packets = encoder.encode_frame(&pcm)?;
        for pkt in &packets {
            transport.send_media(pkt).await?;
        }
        frames_sent += 1;

        // --- try to receive (short window so we don't block the sender) ---
        let recv_deadline = Instant::now() + Duration::from_millis(5);
        loop {
            if Instant::now() >= recv_deadline {
                break;
            }
            match tokio::time::timeout(Duration::from_millis(2), transport.recv_media()).await {
                Ok(Ok(Some(pkt))) => {
                    let is_repair = pkt.header.is_repair;
                    decoder.ingest(pkt);
                    // Repair (FEC) packets are ingested but only source
                    // packets trigger a decode/count here.
                    if !is_repair {
                        if let Some(_n) = decoder.decode_next(&mut pcm_buf) {
                            let now = Instant::now();
                            if first_recv_time.is_none() {
                                first_recv_time = Some(now);
                            }
                            last_recv_time = Some(now);
                            frames_received += 1;
                        }
                    }
                }
                // Timeout or transport error/close: stop polling this window.
                _ => break,
            }
        }

        // Progress log every 250 frames (= 5 s of audio).
        if (frame_idx + 1) % 250 == 0 {
            info!(
                frame = frame_idx + 1,
                sent = frames_sent,
                recv = frames_received,
                elapsed = format!("{:.1}s", start.elapsed().as_secs_f64()),
                "drift-test progress"
            );
        }

        tokio::time::sleep(frame_duration).await;
    }

    // Drain trailing packets for 2 seconds
    info!("sender done, draining trailing packets for 2s...");
    let drain_deadline = Instant::now() + Duration::from_secs(2);
    while Instant::now() < drain_deadline {
        match tokio::time::timeout(Duration::from_millis(100), transport.recv_media()).await {
            Ok(Ok(Some(pkt))) => {
                let is_repair = pkt.header.is_repair;
                decoder.ingest(pkt);
                if !is_repair {
                    if let Some(_n) = decoder.decode_next(&mut pcm_buf) {
                        let now = Instant::now();
                        if first_recv_time.is_none() {
                            first_recv_time = Some(now);
                        }
                        last_recv_time = Some(now);
                        frames_received += 1;
                    }
                }
            }
            _ => break,
        }
    }

    // Compute result
    let expected_duration_ms = config.duration_secs as u64 * 1000;
    // Zero frames received => no timestamps => report 0 ms actual duration.
    let actual_duration_ms = match (first_recv_time, last_recv_time) {
        (Some(first), Some(last)) => last.duration_since(first).as_millis() as u64,
        _ => 0,
    };

    let result = DriftResult::compute(
        expected_duration_ms,
        actual_duration_ms,
        frames_sent,
        frames_received,
    );

    info!(
        expected_ms = result.expected_duration_ms,
        actual_ms = result.actual_duration_ms,
        drift_ms = result.drift_ms,
        drift_pct = format!("{:.4}%", result.drift_pct),
        loss_pct = format!("{:.1}%", result.loss_pct),
        "drift measurement complete"
    );

    Ok(result)
}
|
||||
|
||||
/// Pretty-print the drift measurement results.
///
/// Writes a human-readable report to stdout, ending with a one-line verdict
/// bucketed by the absolute drift magnitude (<5 / <20 / <100 / >=100 ms).
pub fn print_drift_report(result: &DriftResult) {
    println!();
    println!("=== Drift Measurement Report ===");
    println!();
    println!("Frames sent: {}", result.frames_sent);
    println!("Frames received: {}", result.frames_received);
    println!("Packet loss: {:.1}%", result.loss_pct);
    println!();
    println!("Expected duration: {} ms", result.expected_duration_ms);
    println!("Actual duration: {} ms", result.actual_duration_ms);
    println!("Drift: {} ms ({:+.4}%)", result.drift_ms, result.drift_pct);
    println!();

    // Interpretation
    let abs_drift = result.drift_ms.unsigned_abs();
    if result.frames_received == 0 {
        println!("WARNING: No frames received. Transport may be non-functional.");
    } else if abs_drift < 5 {
        println!("Result: EXCELLENT -- drift is negligible (<5 ms).");
    } else if abs_drift < 20 {
        println!("Result: GOOD -- drift is within acceptable bounds (<20 ms).");
    } else if abs_drift < 100 {
        println!("Result: FAIR -- noticeable drift ({} ms). Clock sync may be needed.", abs_drift);
    } else {
        println!("Result: POOR -- significant drift ({} ms). Investigate clock sources.", abs_drift);
    }
    println!();
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises `DriftResult::compute` arithmetic, including the
    /// division-by-zero guards.
    #[test]
    fn drift_result_calculations() {
        // Perfect case: no drift, no loss
        let r = DriftResult::compute(10_000, 10_000, 500, 500);
        assert_eq!(r.drift_ms, 0);
        assert!((r.drift_pct - 0.0).abs() < f64::EPSILON);
        assert!((r.loss_pct - 0.0).abs() < f64::EPSILON);

        // Positive drift (receiver duration longer than expected)
        let r = DriftResult::compute(10_000, 10_050, 500, 490);
        assert_eq!(r.drift_ms, 50);
        assert!((r.drift_pct - 0.5).abs() < 1e-9); // 50/10000 * 100 = 0.5%
        assert!((r.loss_pct - 2.0).abs() < 1e-9); // (1 - 490/500) * 100 = 2.0%

        // Negative drift (receiver duration shorter than expected)
        let r = DriftResult::compute(10_000, 9_900, 500, 450);
        assert_eq!(r.drift_ms, -100);
        assert!((r.drift_pct - (-1.0)).abs() < 1e-9); // -100/10000 * 100 = -1.0%
        assert!((r.loss_pct - 10.0).abs() < 1e-9); // (1 - 450/500) * 100 = 10.0%

        // Edge: zero frames sent (avoid division by zero)
        let r = DriftResult::compute(0, 0, 0, 0);
        assert_eq!(r.drift_ms, 0);
        assert!((r.drift_pct - 0.0).abs() < f64::EPSILON);
        assert!((r.loss_pct - 0.0).abs() < f64::EPSILON);
    }

    /// Confirms the documented default configuration values.
    #[test]
    fn drift_config_defaults() {
        let cfg = DriftTestConfig::default();
        assert_eq!(cfg.duration_secs, 10);
        assert!((cfg.tone_freq_hz - 440.0).abs() < f32::EPSILON);
    }
}
|
||||
@@ -266,7 +266,7 @@ pub async fn run_echo_test(
|
||||
}
|
||||
}
|
||||
|
||||
let jitter_stats = decoder.jitter_stats();
|
||||
let jitter_stats = decoder.stats().clone();
|
||||
let total_frames_received = recv_pcm.len() as u64 / FRAME_SAMPLES as u64;
|
||||
let overall_loss = if total_frames > 0 {
|
||||
(1.0 - total_frames_received as f32 / total_frames as f32) * 100.0
|
||||
|
||||
167
crates/wzp-client/src/featherchat.rs
Normal file
167
crates/wzp-client/src/featherchat.rs
Normal file
@@ -0,0 +1,167 @@
|
||||
//! featherChat signaling bridge.
|
||||
//!
|
||||
//! Sends WZP call signaling (Offer/Answer/Hangup) through featherChat's
|
||||
//! E2E encrypted WebSocket channel as `WireMessage::CallSignal`.
|
||||
//!
|
||||
//! Flow:
|
||||
//! 1. Client connects to featherChat WS with bearer token
|
||||
//! 2. Sends CallOffer as CallSignal(signal_type=Offer, payload=JSON SignalMessage)
|
||||
//! 3. Receives CallAnswer as CallSignal(signal_type=Answer, payload=JSON SignalMessage)
|
||||
//! 4. Extracts relay address from the answer
|
||||
//! 5. Connects QUIC to relay for media
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use wzp_proto::packet::SignalMessage;
|
||||
|
||||
/// featherChat CallSignal types (mirrors warzone-protocol::message::CallSignalType).
///
/// Must stay in sync with the featherChat side's enum. `Offer` also serves as
/// the fallback for WZP signals with no dedicated variant (see
/// `signal_to_call_type`).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum CallSignalType {
    Offer,
    Answer,
    IceCandidate,
    Hangup,
    Reject,
    Ringing,
    Busy,
    Hold,
    Unhold,
    Mute,
    Unmute,
    Transfer,
}
|
||||
|
||||
/// A CallSignal as sent through featherChat's WireMessage.
/// This is what goes in the `payload` field of `WireMessage::CallSignal`.
///
/// Round-tripped through JSON by `encode_call_payload` / `decode_call_payload`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct WzpCallPayload {
    /// The WZP SignalMessage (CallOffer, CallAnswer, etc.) serialized as JSON.
    pub signal: SignalMessage,
    /// The relay address to connect to for media (host:port).
    pub relay_addr: Option<String>,
    /// Room name on the relay.
    pub room: Option<String>,
}
|
||||
|
||||
/// Parameters for initiating a call through featherChat.
// NOTE(review): the consuming call-setup routine is outside this view —
// confirm usage against the signaling bridge implementation.
pub struct CallInitParams {
    /// featherChat server URL (e.g., "wss://chat.example.com/ws").
    pub server_url: String,
    /// Bearer token for authentication.
    pub token: String,
    /// Target peer fingerprint (who to call).
    pub target_fingerprint: String,
    /// Relay address for media transport.
    pub relay_addr: String,
    /// Room name on the relay.
    pub room: String,
    /// Our identity seed for crypto.
    pub seed: [u8; 32],
}
|
||||
|
||||
/// Result of a successful call setup.
pub struct CallSetupResult {
    /// Relay address to connect to.
    pub relay_addr: String,
    /// Room name.
    pub room: String,
    /// The peer's CallAnswer signal (contains ephemeral key, etc.)
    pub answer: SignalMessage,
}
|
||||
|
||||
/// Serialize a WZP SignalMessage into a featherChat CallSignal payload string.
|
||||
pub fn encode_call_payload(
|
||||
signal: &SignalMessage,
|
||||
relay_addr: Option<&str>,
|
||||
room: Option<&str>,
|
||||
) -> String {
|
||||
let payload = WzpCallPayload {
|
||||
signal: signal.clone(),
|
||||
relay_addr: relay_addr.map(|s| s.to_string()),
|
||||
room: room.map(|s| s.to_string()),
|
||||
};
|
||||
serde_json::to_string(&payload).unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Deserialize a featherChat CallSignal payload back to WZP types.
|
||||
pub fn decode_call_payload(payload: &str) -> Result<WzpCallPayload, String> {
|
||||
serde_json::from_str(payload).map_err(|e| format!("invalid call payload: {e}"))
|
||||
}
|
||||
|
||||
/// Map WZP SignalMessage type to featherChat CallSignalType.
|
||||
pub fn signal_to_call_type(signal: &SignalMessage) -> CallSignalType {
|
||||
match signal {
|
||||
SignalMessage::CallOffer { .. } => CallSignalType::Offer,
|
||||
SignalMessage::CallAnswer { .. } => CallSignalType::Answer,
|
||||
SignalMessage::IceCandidate { .. } => CallSignalType::IceCandidate,
|
||||
SignalMessage::Hangup { .. } => CallSignalType::Hangup,
|
||||
SignalMessage::Rekey { .. } => CallSignalType::Offer, // reuse
|
||||
SignalMessage::QualityUpdate { .. } => CallSignalType::Offer, // reuse
|
||||
SignalMessage::Ping { .. } | SignalMessage::Pong { .. } => CallSignalType::Offer,
|
||||
SignalMessage::AuthToken { .. } => CallSignalType::Offer,
|
||||
SignalMessage::Hold => CallSignalType::Hold,
|
||||
SignalMessage::Unhold => CallSignalType::Unhold,
|
||||
SignalMessage::Mute => CallSignalType::Mute,
|
||||
SignalMessage::Unmute => CallSignalType::Unmute,
|
||||
SignalMessage::Transfer { .. } => CallSignalType::Transfer,
|
||||
SignalMessage::TransferAck => CallSignalType::Offer, // reuse
|
||||
SignalMessage::PresenceUpdate { .. } => CallSignalType::Offer, // reuse
|
||||
SignalMessage::RouteQuery { .. } => CallSignalType::Offer, // reuse
|
||||
SignalMessage::RouteResponse { .. } => CallSignalType::Offer, // reuse
|
||||
SignalMessage::SessionForward { .. } => CallSignalType::Offer, // reuse
|
||||
SignalMessage::SessionForwardAck { .. } => CallSignalType::Offer, // reuse
|
||||
SignalMessage::RoomUpdate { .. } => CallSignalType::Offer, // reuse
|
||||
SignalMessage::SetAlias { .. } => CallSignalType::Offer, // reuse
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use wzp_proto::QualityProfile;

    /// Encoding then decoding a payload must preserve relay address, room,
    /// and the signal variant.
    #[test]
    fn payload_roundtrip() {
        let signal = SignalMessage::CallOffer {
            identity_pub: [1u8; 32],
            ephemeral_pub: [2u8; 32],
            signature: vec![3u8; 64],
            supported_profiles: vec![QualityProfile::GOOD],
            alias: None,
        };

        let encoded = encode_call_payload(&signal, Some("relay.example.com:4433"), Some("myroom"));
        let decoded = decode_call_payload(&encoded).unwrap();

        assert_eq!(decoded.relay_addr.unwrap(), "relay.example.com:4433");
        assert_eq!(decoded.room.unwrap(), "myroom");
        assert!(matches!(decoded.signal, SignalMessage::CallOffer { .. }));
    }

    /// Spot-checks the SignalMessage -> CallSignalType mapping, including the
    /// dedicated variants (Hold/Unhold/Mute/Unmute/Transfer).
    #[test]
    fn signal_type_mapping() {
        let offer = SignalMessage::CallOffer {
            identity_pub: [0; 32],
            ephemeral_pub: [0; 32],
            signature: vec![],
            supported_profiles: vec![],
            alias: None,
        };
        assert!(matches!(signal_to_call_type(&offer), CallSignalType::Offer));

        let hangup = SignalMessage::Hangup {
            reason: wzp_proto::HangupReason::Normal,
        };
        assert!(matches!(signal_to_call_type(&hangup), CallSignalType::Hangup));

        assert!(matches!(signal_to_call_type(&SignalMessage::Hold), CallSignalType::Hold));
        assert!(matches!(signal_to_call_type(&SignalMessage::Unhold), CallSignalType::Unhold));
        assert!(matches!(signal_to_call_type(&SignalMessage::Mute), CallSignalType::Mute));
        assert!(matches!(signal_to_call_type(&SignalMessage::Unmute), CallSignalType::Unmute));

        let transfer = SignalMessage::Transfer {
            target_fingerprint: "abc".to_string(),
            relay_addr: None,
        };
        assert!(matches!(signal_to_call_type(&transfer), CallSignalType::Transfer));
    }
}
|
||||
@@ -17,6 +17,7 @@ use wzp_proto::{MediaTransport, QualityProfile, SignalMessage};
|
||||
pub async fn perform_handshake(
|
||||
transport: &dyn MediaTransport,
|
||||
seed: &[u8; 32],
|
||||
alias: Option<&str>,
|
||||
) -> Result<Box<dyn CryptoSession>, anyhow::Error> {
|
||||
// 1. Create key exchange from identity seed
|
||||
let mut kx = WarzoneKeyExchange::from_identity_seed(seed);
|
||||
@@ -41,6 +42,7 @@ pub async fn perform_handshake(
|
||||
QualityProfile::DEGRADED,
|
||||
QualityProfile::CATASTROPHIC,
|
||||
],
|
||||
alias: alias.map(|s| s.to_string()),
|
||||
};
|
||||
transport.send_signal(&offer).await?;
|
||||
|
||||
|
||||
@@ -8,10 +8,18 @@
|
||||
|
||||
#[cfg(feature = "audio")]
|
||||
pub mod audio_io;
|
||||
#[cfg(feature = "audio")]
|
||||
pub mod audio_ring;
|
||||
#[cfg(feature = "vpio")]
|
||||
pub mod audio_vpio;
|
||||
pub mod bench;
|
||||
pub mod call;
|
||||
pub mod drift_test;
|
||||
pub mod echo_test;
|
||||
pub mod featherchat;
|
||||
pub mod handshake;
|
||||
pub mod metrics;
|
||||
pub mod sweep;
|
||||
|
||||
#[cfg(feature = "audio")]
|
||||
pub use audio_io::{AudioCapture, AudioPlayback};
|
||||
|
||||
186
crates/wzp-client/src/metrics.rs
Normal file
186
crates/wzp-client/src/metrics.rs
Normal file
@@ -0,0 +1,186 @@
|
||||
//! Client-side JSONL metrics export.
|
||||
//!
|
||||
//! When `--metrics-file <path>` is passed, the client writes one JSON object
|
||||
//! per second to the specified file. Each line is a self-contained JSON object
|
||||
//! (JSONL format) containing jitter buffer stats, loss, and quality profile.
|
||||
|
||||
use std::fs::{File, OpenOptions};
|
||||
use std::io::Write;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use serde::Serialize;
|
||||
|
||||
use wzp_proto::jitter::JitterStats;
|
||||
|
||||
/// A single metrics snapshot written as one JSONL line.
#[derive(Serialize)]
pub struct ClientMetricsSnapshot {
    /// Timestamp string for this sample (format chosen by the caller —
    /// presumably RFC 3339; confirm against the snapshot producer).
    pub ts: String,
    /// Current jitter buffer depth.
    pub buffer_depth: usize,
    /// Cumulative jitter-buffer underrun count.
    pub underruns: u64,
    /// Cumulative jitter-buffer overrun count.
    pub overruns: u64,
    /// Packet loss percentage (0-100).
    pub loss_pct: f64,
    /// Round-trip time in ms (0 when unknown — see `snapshot_from_stats`).
    pub rtt_ms: u64,
    /// Jitter in ms (0 when unknown — see `snapshot_from_stats`).
    pub jitter_ms: u64,
    /// Frames sent so far (0 when unknown — see `snapshot_from_stats`).
    pub frames_sent: u64,
    /// Frames received so far.
    pub frames_received: u64,
    /// Name of the active quality profile.
    pub quality_profile: String,
}
|
||||
|
||||
/// Periodic JSONL writer that respects a configurable interval.
pub struct MetricsWriter {
    /// Destination file: truncated on creation, then appended to one line at a time.
    file: File,
    /// Minimum elapsed time between two written lines.
    interval: Duration,
    /// Moment the previous line was written; initialized in the past so the
    /// first `maybe_write` call emits immediately.
    last_write: Instant,
}
|
||||
|
||||
impl MetricsWriter {
|
||||
/// Create a new `MetricsWriter` that appends JSONL to the given path.
|
||||
///
|
||||
/// The file is created (or truncated) immediately.
|
||||
pub fn new(path: &str, interval_secs: u64) -> Result<Self, anyhow::Error> {
|
||||
let file = OpenOptions::new()
|
||||
.create(true)
|
||||
.write(true)
|
||||
.truncate(true)
|
||||
.open(path)?;
|
||||
Ok(Self {
|
||||
file,
|
||||
interval: Duration::from_secs(interval_secs),
|
||||
// Set last_write far in the past so the first call writes immediately.
|
||||
last_write: Instant::now() - Duration::from_secs(interval_secs + 1),
|
||||
})
|
||||
}
|
||||
|
||||
/// Write a JSONL line if the interval has elapsed since the last write.
|
||||
///
|
||||
/// Returns `Ok(true)` when a line was written, `Ok(false)` when skipped.
|
||||
pub fn maybe_write(&mut self, snapshot: &ClientMetricsSnapshot) -> Result<bool, anyhow::Error> {
|
||||
let now = Instant::now();
|
||||
if now.duration_since(self.last_write) >= self.interval {
|
||||
let line = serde_json::to_string(snapshot)?;
|
||||
writeln!(self.file, "{}", line)?;
|
||||
self.file.flush()?;
|
||||
self.last_write = now;
|
||||
Ok(true)
|
||||
} else {
|
||||
Ok(false)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a `ClientMetricsSnapshot` from jitter buffer stats and a quality profile name.
|
||||
///
|
||||
/// Fields not available from `JitterStats` alone (rtt_ms, jitter_ms, frames_sent)
|
||||
/// are set to zero — the caller can override them if the data is available.
|
||||
pub fn snapshot_from_stats(stats: &JitterStats, profile: &str) -> ClientMetricsSnapshot {
|
||||
let loss_pct = if stats.packets_received > 0 {
|
||||
(stats.packets_lost as f64 / stats.packets_received as f64) * 100.0
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
ClientMetricsSnapshot {
|
||||
ts: chrono::Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Secs, true),
|
||||
buffer_depth: stats.current_depth,
|
||||
underruns: stats.underruns,
|
||||
overruns: stats.overruns,
|
||||
loss_pct,
|
||||
rtt_ms: 0,
|
||||
jitter_ms: 0,
|
||||
frames_sent: 0,
|
||||
frames_received: stats.total_decoded,
|
||||
quality_profile: profile.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture with a known value in every `JitterStats` field.
    fn make_test_stats() -> JitterStats {
        JitterStats {
            packets_received: 100,
            packets_played: 95,
            packets_lost: 5,
            packets_late: 2,
            packets_duplicate: 0,
            current_depth: 8,
            total_decoded: 93,
            underruns: 1,
            overruns: 0,
            max_depth_seen: 12,
        }
    }

    /// Build a temp-file path unique per process so concurrent test runs
    /// (e.g. parallel CI jobs sharing the same temp dir) cannot clobber
    /// each other's files.
    fn unique_temp_path(tag: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!("wzp_{}_{}.jsonl", tag, std::process::id()))
    }

    #[test]
    fn snapshot_serializes_to_json() {
        let stats = make_test_stats();
        let snap = snapshot_from_stats(&stats, "GOOD");
        let json = serde_json::to_string(&snap).unwrap();

        // Verify expected fields are present in the JSON string.
        assert!(json.contains("\"ts\""));
        assert!(json.contains("\"buffer_depth\":8"));
        assert!(json.contains("\"underruns\":1"));
        assert!(json.contains("\"overruns\":0"));
        assert!(json.contains("\"loss_pct\":5."));
        assert!(json.contains("\"rtt_ms\":0"));
        assert!(json.contains("\"jitter_ms\":0"));
        assert!(json.contains("\"frames_sent\":0"));
        assert!(json.contains("\"frames_received\":93"));
        assert!(json.contains("\"quality_profile\":\"GOOD\""));

        // Verify it round-trips as valid JSON.
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(value["buffer_depth"], 8);
        assert_eq!(value["quality_profile"], "GOOD");
    }

    #[test]
    fn metrics_writer_creates_file() {
        let path = unique_temp_path("metrics_test");
        let path_str = path.to_str().unwrap();

        let mut writer = MetricsWriter::new(path_str, 1).unwrap();
        let stats = make_test_stats();
        let snap = snapshot_from_stats(&stats, "DEGRADED");

        let wrote = writer.maybe_write(&snap).unwrap();
        assert!(wrote, "first write should succeed immediately");

        // Read the file back and verify it contains valid JSONL.
        let contents = std::fs::read_to_string(&path).unwrap();
        let lines: Vec<&str> = contents.lines().collect();
        assert_eq!(lines.len(), 1, "should have exactly one JSONL line");

        let value: serde_json::Value = serde_json::from_str(lines[0]).unwrap();
        assert_eq!(value["quality_profile"], "DEGRADED");
        assert_eq!(value["buffer_depth"], 8);

        // Clean up.
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn metrics_writer_respects_interval() {
        let path = unique_temp_path("metrics_interval_test");
        let path_str = path.to_str().unwrap();

        let mut writer = MetricsWriter::new(path_str, 60).unwrap();
        let stats = make_test_stats();
        let snap = snapshot_from_stats(&stats, "GOOD");

        // First write succeeds (last_write is set far in the past).
        let first = writer.maybe_write(&snap).unwrap();
        assert!(first, "first write should succeed");

        // Immediate second write should be skipped (60s interval).
        let second = writer.maybe_write(&snap).unwrap();
        assert!(!second, "second write should be skipped — interval not elapsed");

        // Clean up.
        let _ = std::fs::remove_file(&path);
    }
}
|
||||
254
crates/wzp-client/src/sweep.rs
Normal file
254
crates/wzp-client/src/sweep.rs
Normal file
@@ -0,0 +1,254 @@
|
||||
//! Parameter sweep tool for jitter buffer configurations.
|
||||
//!
|
||||
//! Tests different (target_depth, max_depth) combinations in a local
|
||||
//! encoder-to-decoder pipeline (no network) and reports frame loss,
|
||||
//! estimated latency, underruns, and overruns for each configuration.
|
||||
|
||||
use crate::call::{CallConfig, CallDecoder, CallEncoder};
|
||||
use wzp_proto::QualityProfile;
|
||||
|
||||
// Fixed audio framing used by the sweep pipeline. These three constants must
// stay mutually consistent: FRAME_SAMPLES = SAMPLE_RATE * FRAME_DURATION_MS / 1000.
const FRAME_SAMPLES: usize = 960; // 20ms @ 48kHz
const SAMPLE_RATE: u32 = 48_000;
const FRAME_DURATION_MS: u32 = 20;
|
||||
|
||||
/// Configuration for a parameter sweep.
///
/// See the `Default` impl for the stock grid used by the `--sweep` CLI flag.
pub struct SweepConfig {
    /// Target jitter buffer depths to test (in packets).
    pub target_depths: Vec<usize>,
    /// Maximum jitter buffer depths to test (in packets).
    pub max_depths: Vec<usize>,
    /// Duration in seconds to run each configuration.
    pub test_duration_secs: u32,
    /// Frequency of the test tone in Hz.
    pub tone_freq_hz: f32,
}
|
||||
|
||||
impl Default for SweepConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
target_depths: vec![10, 25, 50, 100, 200],
|
||||
max_depths: vec![50, 100, 250, 500],
|
||||
test_duration_secs: 2,
|
||||
tone_freq_hz: 440.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Result from one (target_depth, max_depth) configuration.
///
/// Produced by `run_local_sweep`; one entry per valid combination.
#[derive(Debug, Clone)]
pub struct SweepResult {
    /// Jitter buffer target depth used.
    pub target_depth: usize,
    /// Jitter buffer max depth used.
    pub max_depth: usize,
    /// Total frames sent into the encoder.
    pub frames_sent: u64,
    /// Total frames successfully decoded.
    pub frames_received: u64,
    /// Frame loss percentage.
    pub loss_pct: f64,
    /// Estimated latency in ms (target_depth * frame_duration).
    pub avg_latency_ms: f64,
    /// Number of jitter buffer underruns.
    pub underruns: u64,
    /// Number of jitter buffer overruns (packets dropped due to full buffer).
    pub overruns: u64,
}
|
||||
|
||||
/// Generate a sine wave frame at the given frequency and frame offset.
|
||||
fn sine_frame(freq_hz: f32, frame_offset: u64) -> Vec<i16> {
|
||||
let start = frame_offset * FRAME_SAMPLES as u64;
|
||||
(0..FRAME_SAMPLES)
|
||||
.map(|i| {
|
||||
let t = (start + i as u64) as f32 / SAMPLE_RATE as f32;
|
||||
(f32::sin(2.0 * std::f32::consts::PI * freq_hz * t) * 16000.0) as i16
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Run a local parameter sweep (no network).
///
/// For each (target_depth, max_depth) combination, creates an encoder and
/// decoder, pushes frames through the pipeline, and collects statistics.
/// Combinations where `target_depth > max_depth` are skipped.
pub fn run_local_sweep(config: &SweepConfig) -> Vec<SweepResult> {
    // Frames per configuration: duration * 50 fps (20ms frames).
    let frames_per_config =
        (config.test_duration_secs as u64) * (1000 / FRAME_DURATION_MS as u64);

    let mut results = Vec::new();

    for &target in &config.target_depths {
        for &max in &config.max_depths {
            // Skip invalid combinations where target exceeds max.
            if target > max {
                continue;
            }

            // jitter_min is clamped into [1, 3]: never larger than the target,
            // never above 3, never zero.
            let call_cfg = CallConfig {
                profile: QualityProfile::GOOD,
                jitter_target: target,
                jitter_max: max,
                jitter_min: target.min(3).max(1),
                ..Default::default()
            };

            // Fresh encoder/decoder per configuration so stats don't leak
            // between runs.
            let mut encoder = CallEncoder::new(&call_cfg);
            let mut decoder = CallDecoder::new(&call_cfg);

            let mut pcm_out = vec![0i16; FRAME_SAMPLES];
            let mut frames_decoded = 0u64;

            // One encode + at most one decode per tick, mirroring the
            // real-time 50 fps cadence.
            for frame_idx in 0..frames_per_config {
                // Encode a tone frame.
                let pcm_in = sine_frame(config.tone_freq_hz, frame_idx);
                let packets = match encoder.encode_frame(&pcm_in) {
                    Ok(p) => p,
                    Err(_) => continue,
                };

                // Feed all packets (source + repair) into the decoder.
                for pkt in packets {
                    decoder.ingest(pkt);
                }

                // Attempt to decode one frame.
                if decoder.decode_next(&mut pcm_out).is_some() {
                    frames_decoded += 1;
                }
            }

            // Drain: keep decoding until the jitter buffer is empty.
            // Bounded by `max` since the buffer cannot hold more than that.
            for _ in 0..max {
                if decoder.decode_next(&mut pcm_out).is_some() {
                    frames_decoded += 1;
                } else {
                    break;
                }
            }

            let stats = decoder.stats().clone();

            // Loss relative to source frames fed in; can go slightly negative
            // if the drain yields extra (e.g. concealment) frames, hence the
            // .max(0.0) clamp below.
            let loss_pct = if frames_per_config > 0 {
                (1.0 - frames_decoded as f64 / frames_per_config as f64) * 100.0
            } else {
                0.0
            };

            results.push(SweepResult {
                target_depth: target,
                max_depth: max,
                frames_sent: frames_per_config,
                frames_received: frames_decoded,
                loss_pct: loss_pct.max(0.0),
                avg_latency_ms: target as f64 * FRAME_DURATION_MS as f64,
                underruns: stats.underruns,
                overruns: stats.overruns,
            });
        }
    }

    results
}
|
||||
|
||||
/// Print a formatted ASCII table of sweep results.
|
||||
pub fn print_sweep_table(results: &[SweepResult]) {
|
||||
println!();
|
||||
println!("=== Jitter Buffer Parameter Sweep ===");
|
||||
println!();
|
||||
println!(
|
||||
" {:>6} | {:>4} | {:>6} | {:>6} | {:>6} | {:>10} | {:>9} | {:>8}",
|
||||
"target", "max", "sent", "recv", "loss%", "latency_ms", "underruns", "overruns"
|
||||
);
|
||||
println!(
|
||||
" {:-<6}-+-{:-<4}-+-{:-<6}-+-{:-<6}-+-{:-<6}-+-{:-<10}-+-{:-<9}-+-{:-<8}",
|
||||
"", "", "", "", "", "", "", ""
|
||||
);
|
||||
for r in results {
|
||||
println!(
|
||||
" {:>6} | {:>4} | {:>6} | {:>6} | {:>5.1}% | {:>10.0} | {:>9} | {:>8}",
|
||||
r.target_depth,
|
||||
r.max_depth,
|
||||
r.frames_sent,
|
||||
r.frames_received,
|
||||
r.loss_pct,
|
||||
r.avg_latency_ms,
|
||||
r.underruns,
|
||||
r.overruns,
|
||||
);
|
||||
}
|
||||
println!();
|
||||
}
|
||||
|
||||
/// Run a default sweep and print the results.
|
||||
///
|
||||
/// This is the entry point for the `--sweep` CLI flag.
|
||||
pub fn run_and_print_default_sweep() {
|
||||
let config = SweepConfig::default();
|
||||
let results = run_local_sweep(&config);
|
||||
print_sweep_table(&results);
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Sanity-check the stock sweep grid shape and value ranges.
    #[test]
    fn sweep_config_default() {
        let cfg = SweepConfig::default();
        assert_eq!(cfg.target_depths.len(), 5);
        assert_eq!(cfg.max_depths.len(), 4);
        assert!(cfg.test_duration_secs > 0);
        assert!(cfg.tone_freq_hz > 0.0);
        // All default targets should be positive.
        assert!(cfg.target_depths.iter().all(|&d| d > 0));
        assert!(cfg.max_depths.iter().all(|&d| d > 0));
    }

    /// Run a small 2x2 sweep end-to-end and check every result is populated.
    #[test]
    fn local_sweep_runs() {
        let cfg = SweepConfig {
            target_depths: vec![3, 10],
            max_depths: vec![50, 100],
            test_duration_secs: 1,
            tone_freq_hz: 440.0,
        };
        let results = run_local_sweep(&cfg);
        // 2 targets x 2 maxes = 4 configs (all valid since targets < maxes).
        assert_eq!(results.len(), 4);
        for r in &results {
            assert!(r.frames_sent > 0, "frames_sent should be > 0");
            assert!(r.frames_received > 0, "frames_received should be > 0");
            assert!(r.avg_latency_ms > 0.0, "latency should be > 0");
        }
    }

    /// Table printing is output-only; this just verifies it never panics,
    /// for the empty slice and for hand-built rows.
    #[test]
    fn sweep_table_formats() {
        // Verify print_sweep_table doesn't panic with various inputs.
        print_sweep_table(&[]);

        let results = vec![
            SweepResult {
                target_depth: 10,
                max_depth: 50,
                frames_sent: 100,
                frames_received: 98,
                loss_pct: 2.0,
                avg_latency_ms: 200.0,
                underruns: 2,
                overruns: 0,
            },
            SweepResult {
                target_depth: 25,
                max_depth: 100,
                frames_sent: 100,
                frames_received: 100,
                loss_pct: 0.0,
                avg_latency_ms: 500.0,
                underruns: 0,
                overruns: 0,
            },
        ];
        print_sweep_table(&results);
    }
}
|
||||
190
crates/wzp-client/tests/long_session.rs
Normal file
190
crates/wzp-client/tests/long_session.rs
Normal file
@@ -0,0 +1,190 @@
|
||||
//! WZP-P2-T1-S5: 60-second long-session regression tests.
|
||||
//!
|
||||
//! Verifies that the full codec + FEC + jitter buffer pipeline does not drift
|
||||
//! or degrade over a sustained 60-second (3000-frame) session. Runs entirely
|
||||
//! in-process with no network — packets flow directly from encoder to decoder.
|
||||
|
||||
use wzp_client::call::{CallConfig, CallDecoder, CallEncoder};
|
||||
use wzp_proto::QualityProfile;
|
||||
|
||||
// Session framing: 20ms frames at 48 kHz, 50 frames per second, for one minute.
const FRAME_SAMPLES: usize = 960; // 20ms @ 48kHz
const SAMPLE_RATE: f32 = 48_000.0;
const TOTAL_FRAMES: u64 = 3_000; // 60 seconds at 50 fps
|
||||
|
||||
/// Build a CallConfig tuned for direct-loopback testing (no network).
|
||||
///
|
||||
/// Disables silence suppression and noise suppression (which would mangle
|
||||
/// or squelch the synthetic tone), uses a fixed (non-adaptive) jitter buffer
|
||||
/// with min_depth=1 so that packets are played out as soon as they arrive.
|
||||
fn test_config() -> CallConfig {
|
||||
CallConfig {
|
||||
profile: QualityProfile::GOOD,
|
||||
jitter_target: 4,
|
||||
jitter_max: 500,
|
||||
jitter_min: 1,
|
||||
suppression_enabled: false,
|
||||
noise_suppression: false,
|
||||
adaptive_jitter: false,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate a 20ms frame of 440 Hz sine tone.
|
||||
fn sine_frame(frame_offset: u64) -> Vec<i16> {
|
||||
let start_sample = frame_offset * FRAME_SAMPLES as u64;
|
||||
(0..FRAME_SAMPLES)
|
||||
.map(|i| {
|
||||
let t = (start_sample + i as u64) as f32 / SAMPLE_RATE;
|
||||
(f32::sin(2.0 * std::f32::consts::PI * 440.0 * t) * 16000.0) as i16
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// 60-second session with a perfect (lossless, in-order) channel.
///
/// Encodes 3000 frames of 440 Hz tone, feeds every packet directly into the
/// decoder, and verifies:
/// - frame loss < 5% (>2850 of 3000 source frames decoded or PLC'd)
/// - no panics
///
/// Note: the encoder shares a single sequence counter between source and
/// repair packets. Since repair packets are NOT pushed into the jitter
/// buffer, each FEC block creates a gap in the playout sequence. GOOD
/// profile (5 frames/block, fec_ratio=0.2) generates 1 repair per block,
/// so every 6th seq number is a "phantom" Missing in the jitter buffer.
/// The jitter buffer correctly fills these gaps with PLC. We call
/// `decode_next` once per encode tick; the buffer stays shallow because
/// PLC frames consume the phantom seqs at the same rate they're created.
#[test]
fn long_session_no_drift() {
    let config = test_config();
    let mut encoder = CallEncoder::new(&config);
    let mut decoder = CallDecoder::new(&config);

    let mut frames_decoded = 0u64;
    let mut pcm_buf = vec![0i16; FRAME_SAMPLES];

    for i in 0..TOTAL_FRAMES {
        let pcm = sine_frame(i);
        let packets = encoder.encode_frame(&pcm).expect("encode should not fail");

        // Lossless, in-order delivery: every packet goes straight in.
        for pkt in packets {
            decoder.ingest(pkt);
        }

        // Decode one frame per tick (mirrors real-time 50 fps cadence).
        if decoder.decode_next(&mut pcm_buf).is_some() {
            frames_decoded += 1;
        }
    }

    let stats = decoder.stats();

    // Diagnostics for when the assertion below fails in CI.
    println!(
        "long_session_no_drift: decoded={frames_decoded}/{TOTAL_FRAMES}, \
         underruns={}, overruns={}, depth={}, max_depth={}, late={}, lost={}",
        stats.underruns, stats.overruns, stats.current_depth, stats.max_depth_seen,
        stats.packets_late, stats.packets_lost,
    );

    // With 1 decode per tick over 3000 ticks, we expect ~3000 decoded frames
    // (some via PLC for repair-seq gaps). Allow up to 5% gap.
    assert!(
        frames_decoded > 2850,
        "frame loss too high: decoded {frames_decoded}/3000 (need >2850 = <5% loss)"
    );
}
|
||||
|
||||
/// 60-second session with simulated 5% packet loss and reordering.
///
/// Every 20th source packet is dropped; pairs of adjacent packets are swapped
/// every 7 frames. Verifies that FEC + jitter buffer recover gracefully:
/// - frame loss < 10% (FEC should recover some of the 5% artificial loss)
/// - no panics
#[test]
fn long_session_with_simulated_loss() {
    let config = test_config();
    let mut encoder = CallEncoder::new(&config);
    let mut decoder = CallDecoder::new(&config);

    let mut frames_decoded = 0u64;
    let mut pcm_buf = vec![0i16; FRAME_SAMPLES];

    for i in 0..TOTAL_FRAMES {
        let pcm = sine_frame(i);
        let packets = encoder.encode_frame(&pcm).expect("encode should not fail");

        // Collect the tick's packets so we can reorder before delivery.
        let mut batch: Vec<_> = packets.into_iter().collect();

        // Simulate reordering: swap first two packets in the batch every 7 frames.
        if i % 7 == 0 && batch.len() >= 2 {
            batch.swap(0, 1);
        }

        for (j, pkt) in batch.into_iter().enumerate() {
            // Drop every 20th *source* (non-repair) packet to simulate ~5% loss.
            // j == 0 restricts the drop to one packet per affected tick.
            if !pkt.header.is_repair && i % 20 == 0 && j == 0 {
                continue; // drop this packet
            }
            decoder.ingest(pkt);
        }

        // One decode per tick, mirroring the real-time cadence.
        if decoder.decode_next(&mut pcm_buf).is_some() {
            frames_decoded += 1;
        }
    }

    let stats = decoder.stats();

    // Diagnostics for when the assertion below fails in CI.
    println!(
        "long_session_with_simulated_loss: decoded={frames_decoded}/{TOTAL_FRAMES}, \
         underruns={}, overruns={}, depth={}, max_depth={}, late={}, lost={}",
        stats.underruns, stats.overruns, stats.current_depth, stats.max_depth_seen,
        stats.packets_late, stats.packets_lost,
    );

    // With 5% artificial loss + FEC recovery + PLC, we should still get >90% decoded.
    assert!(
        frames_decoded > 2700,
        "frame loss too high under simulated loss: decoded {frames_decoded}/3000 (need >2700 = <10%)"
    );
}
|
||||
|
||||
/// Verify that the jitter buffer's decoded-frame count is consistent with its
/// own internal statistics over a long session.
#[test]
fn long_session_stats_consistency() {
    let config = test_config();
    let mut encoder = CallEncoder::new(&config);
    let mut decoder = CallDecoder::new(&config);

    let mut frames_decoded = 0u64;
    let mut pcm_buf = vec![0i16; FRAME_SAMPLES];

    // Lossless in-order pipeline: encode, ingest everything, decode once per tick.
    for tick in 0..TOTAL_FRAMES {
        let tone = sine_frame(tick);
        let packets = encoder.encode_frame(&tone).expect("encode");

        for pkt in packets {
            decoder.ingest(pkt);
        }
        frames_decoded += u64::from(decoder.decode_next(&mut pcm_buf).is_some());
    }

    let stats = decoder.stats();

    // total_decoded should match our manual counter.
    assert_eq!(
        stats.total_decoded, frames_decoded,
        "stats.total_decoded ({}) != manually counted frames_decoded ({frames_decoded})",
        stats.total_decoded,
    );

    // packets_received should be > 0.
    assert!(
        stats.packets_received > 0,
        "stats.packets_received should be > 0"
    );
}
|
||||
@@ -16,4 +16,10 @@ audiopus = { workspace = true }
|
||||
# Pure-Rust Codec2 implementation
|
||||
codec2 = { workspace = true }
|
||||
|
||||
# RNG for comfort noise generation
|
||||
rand = { workspace = true }
|
||||
|
||||
# ML-based noise suppression (pure-Rust port of RNNoise)
|
||||
nnnoiseless = "0.5"
|
||||
|
||||
[dev-dependencies]
|
||||
|
||||
@@ -14,7 +14,7 @@ use crate::codec2_dec::Codec2Decoder;
|
||||
use crate::codec2_enc::Codec2Encoder;
|
||||
use crate::opus_dec::OpusDecoder;
|
||||
use crate::opus_enc::OpusEncoder;
|
||||
use crate::resample;
|
||||
use crate::resample::{Downsampler48to8, Upsampler8to48};
|
||||
|
||||
// ─── Helpers ─────────────────────────────────────────────────────────────────
|
||||
|
||||
@@ -54,6 +54,7 @@ pub struct AdaptiveEncoder {
|
||||
opus: OpusEncoder,
|
||||
codec2: Codec2Encoder,
|
||||
active: CodecId,
|
||||
downsampler: Downsampler48to8,
|
||||
}
|
||||
|
||||
impl AdaptiveEncoder {
|
||||
@@ -66,6 +67,7 @@ impl AdaptiveEncoder {
|
||||
opus,
|
||||
codec2,
|
||||
active: profile.codec,
|
||||
downsampler: Downsampler48to8::new(),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -74,7 +76,7 @@ impl AudioEncoder for AdaptiveEncoder {
|
||||
fn encode(&mut self, pcm: &[i16], out: &mut [u8]) -> Result<usize, CodecError> {
|
||||
if is_codec2(self.active) {
|
||||
// Downsample 48 kHz → 8 kHz then encode via Codec2.
|
||||
let pcm_8k = resample::resample_48k_to_8k(pcm);
|
||||
let pcm_8k = self.downsampler.process(pcm);
|
||||
self.codec2.encode(&pcm_8k, out)
|
||||
} else {
|
||||
self.opus.encode(pcm, out)
|
||||
@@ -126,6 +128,7 @@ pub struct AdaptiveDecoder {
|
||||
opus: OpusDecoder,
|
||||
codec2: Codec2Decoder,
|
||||
active: CodecId,
|
||||
upsampler: Upsampler8to48,
|
||||
}
|
||||
|
||||
impl AdaptiveDecoder {
|
||||
@@ -138,6 +141,7 @@ impl AdaptiveDecoder {
|
||||
opus,
|
||||
codec2,
|
||||
active: profile.codec,
|
||||
upsampler: Upsampler8to48::new(),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -149,7 +153,7 @@ impl AudioDecoder for AdaptiveDecoder {
|
||||
let c2_samples = self.codec2_frame_samples();
|
||||
let mut buf_8k = vec![0i16; c2_samples];
|
||||
let n = self.codec2.decode(encoded, &mut buf_8k)?;
|
||||
let pcm_48k = resample::resample_8k_to_48k(&buf_8k[..n]);
|
||||
let pcm_48k = self.upsampler.process(&buf_8k[..n]);
|
||||
let out_len = pcm_48k.len().min(pcm.len());
|
||||
pcm[..out_len].copy_from_slice(&pcm_48k[..out_len]);
|
||||
Ok(out_len)
|
||||
@@ -163,7 +167,7 @@ impl AudioDecoder for AdaptiveDecoder {
|
||||
let c2_samples = self.codec2_frame_samples();
|
||||
let mut buf_8k = vec![0i16; c2_samples];
|
||||
let n = self.codec2.decode_lost(&mut buf_8k)?;
|
||||
let pcm_48k = resample::resample_8k_to_48k(&buf_8k[..n]);
|
||||
let pcm_48k = self.upsampler.process(&buf_8k[..n]);
|
||||
let out_len = pcm_48k.len().min(pcm.len());
|
||||
pcm[..out_len].copy_from_slice(&pcm_48k[..out_len]);
|
||||
Ok(out_len)
|
||||
|
||||
335
crates/wzp-codec/src/aec.rs
Normal file
335
crates/wzp-codec/src/aec.rs
Normal file
@@ -0,0 +1,335 @@
|
||||
//! Acoustic Echo Cancellation — delay-compensated leaky NLMS with
|
||||
//! Geigel double-talk detection.
|
||||
//!
|
||||
//! Key insight: on a laptop, the round-trip audio latency (playout → speaker
|
||||
//! → air → mic → capture) is 30–50ms. The far-end reference must be delayed
|
||||
//! by this amount so the adaptive filter models the *echo path*, not the
|
||||
//! *system delay + echo path*.
|
||||
//!
|
||||
//! The leaky coefficient decay prevents the filter from diverging when the
|
||||
//! echo path changes (e.g. hand near laptop) or when the delay estimate
|
||||
//! is slightly off.
|
||||
|
||||
/// Delay-compensated leaky NLMS echo canceller with Geigel DTD.
pub struct EchoCanceller {
    // --- Adaptive filter ---
    /// NLMS filter taps (length `filter_len`), modeling the echo path.
    filter: Vec<f32>,
    /// Number of taps; also the length of `far_buf`.
    filter_len: usize,
    /// Circular buffer of far-end reference samples (after delay).
    far_buf: Vec<f32>,
    /// Write cursor into `far_buf` (wraps at `filter_len`).
    far_pos: usize,
    /// NLMS step size.
    mu: f32,
    /// Leakage factor: coefficients are multiplied by (1 - leak) each frame.
    /// Prevents unbounded growth / divergence. 0.0001 is gentle.
    leak: f32,
    /// When false, `process_frame` is a passthrough.
    enabled: bool,

    // --- Delay buffer ---
    /// Raw far-end samples before delay compensation.
    delay_ring: Vec<f32>,
    /// Monotonic write counter into `delay_ring` (index = counter % delay_cap).
    delay_write: usize,
    /// Monotonic read counter into `delay_ring` (index = counter % delay_cap).
    delay_read: usize,
    /// Delay in samples (e.g. 1920 = 40ms at 48kHz).
    delay_samples: usize,
    /// Capacity of the delay ring.
    delay_cap: usize,

    // --- Double-talk detection (Geigel) ---
    /// Peak far-end level over the last filter_len samples.
    far_peak: f32,
    /// Geigel threshold: if |near| > threshold * far_peak, assume double-talk.
    geigel_threshold: f32,
    /// Holdover counter: keep DTD active for a few frames after detection.
    dtd_holdover: u32,
    /// Number of frames the DTD holdover lasts once triggered.
    dtd_hold_frames: u32,
}
|
||||
|
||||
impl EchoCanceller {
    /// Create a new echo canceller with the default 40ms delay compensation.
    ///
    /// * `sample_rate` — typically 48000
    /// * `filter_ms` — echo-tail length in milliseconds (60ms recommended)
    pub fn new(sample_rate: u32, filter_ms: u32) -> Self {
        Self::with_delay(sample_rate, filter_ms, 40)
    }

    /// Create an echo canceller with an explicit far-end delay.
    ///
    /// * `delay_ms` — far-end delay compensation in milliseconds (40ms for laptops)
    pub fn with_delay(sample_rate: u32, filter_ms: u32, delay_ms: u32) -> Self {
        let filter_len = (sample_rate as usize) * (filter_ms as usize) / 1000;
        let delay_samples = (sample_rate as usize) * (delay_ms as usize) / 1000;
        // Delay ring must hold at least delay_samples + one frame (960) of headroom.
        let delay_cap = delay_samples + (sample_rate as usize / 10); // +100ms headroom
        Self {
            filter: vec![0.0; filter_len],
            filter_len,
            far_buf: vec![0.0; filter_len],
            far_pos: 0,
            mu: 0.01,
            leak: 0.0001,
            enabled: true,

            delay_ring: vec![0.0; delay_cap],
            delay_write: 0,
            delay_read: 0,
            delay_samples,
            delay_cap,

            far_peak: 0.0,
            geigel_threshold: 0.7,
            dtd_holdover: 0,
            dtd_hold_frames: 5,
        }
    }

    /// Feed far-end (speaker) samples. These go into the delay buffer first;
    /// once enough samples have accumulated, they are released to the filter's
    /// circular buffer with the correct delay offset.
    ///
    /// NOTE(review): a single call longer than the ring's headroom (~100ms of
    /// samples) would overwrite not-yet-released samples — callers appear to
    /// feed one 20ms frame at a time; confirm against call sites.
    pub fn feed_farend(&mut self, farend: &[i16]) {
        // Write raw samples into the delay ring.
        for &s in farend {
            self.delay_ring[self.delay_write % self.delay_cap] = s as f32;
            self.delay_write += 1;
        }

        // Release delayed samples to the filter's far-end buffer.
        while self.delay_available() >= 1 {
            let sample = self.delay_ring[self.delay_read % self.delay_cap];
            self.delay_read += 1;

            self.far_buf[self.far_pos] = sample;
            self.far_pos = (self.far_pos + 1) % self.filter_len;

            // Track peak far-end level for Geigel DTD.
            let abs_s = sample.abs();
            if abs_s > self.far_peak {
                self.far_peak = abs_s;
            }
        }

        // Decay far_peak slowly (avoids stale peak from a loud burst long ago).
        self.far_peak *= 0.9995;
    }

    /// Number of delayed samples available to release.
    ///
    /// Samples beyond the first `delay_samples` buffered are releasable; the
    /// first `delay_samples` stay held back to implement the delay.
    fn delay_available(&self) -> usize {
        let buffered = self.delay_write - self.delay_read;
        if buffered > self.delay_samples {
            buffered - self.delay_samples
        } else {
            0
        }
    }

    /// Process a near-end (microphone) frame, removing the estimated echo.
    ///
    /// Mutates `nearend` in place to the residual (echo-removed) signal.
    /// Returns the ratio of input RMS to residual RMS (an ERLE-like figure):
    /// 1.0 when disabled, capped at 100.0 when the residual is near-silent.
    pub fn process_frame(&mut self, nearend: &mut [i16]) -> f32 {
        if !self.enabled {
            return 1.0;
        }

        let n = nearend.len();
        let fl = self.filter_len;

        // --- Geigel double-talk detection ---
        // If any near-end sample exceeds threshold * far_peak, assume
        // the local speaker is active and freeze adaptation.
        let mut is_doubletalk = self.dtd_holdover > 0;
        if !is_doubletalk {
            let threshold_level = self.geigel_threshold * self.far_peak;
            for &s in nearend.iter() {
                // far_peak > 100.0 gates out DTD triggers on near-silence.
                if (s as f32).abs() > threshold_level && self.far_peak > 100.0 {
                    is_doubletalk = true;
                    self.dtd_holdover = self.dtd_hold_frames;
                    break;
                }
            }
        }
        if self.dtd_holdover > 0 {
            self.dtd_holdover -= 1;
        }

        // Check if far-end is active (otherwise nothing to cancel).
        let far_active = self.far_peak > 100.0;

        // --- Leaky coefficient decay ---
        // Applied once per frame for efficiency.
        let decay = 1.0 - self.leak;
        for c in self.filter.iter_mut() {
            *c *= decay;
        }

        // Accumulate input/output power in f64 to avoid precision loss.
        let mut sum_near_sq: f64 = 0.0;
        let mut sum_err_sq: f64 = 0.0;

        for i in 0..n {
            let near_f = nearend[i] as f32;

            // Position of far-end "now" for this near-end sample.
            // The fl * ((n / fl) + 2) term keeps the usize arithmetic from
            // underflowing before the final `- n` and `% fl`.
            let base = (self.far_pos + fl * ((n / fl) + 2) + i - n) % fl;

            // --- Echo estimation: dot(filter, far_end_window) ---
            let mut echo_est: f32 = 0.0;
            let mut power: f32 = 0.0;

            for k in 0..fl {
                let fe_idx = (base + fl - k) % fl;
                let fe = self.far_buf[fe_idx];
                echo_est += self.filter[k] * fe;
                power += fe * fe;
            }

            let error = near_f - echo_est;

            // --- NLMS adaptation (only when far-end active & no double-talk) ---
            // power + 1.0 regularizes the normalized step against tiny power.
            if far_active && !is_doubletalk && power > 10.0 {
                let step = self.mu * error / (power + 1.0);
                for k in 0..fl {
                    let fe_idx = (base + fl - k) % fl;
                    self.filter[k] += step * self.far_buf[fe_idx];
                }
            }

            // Clamp to i16 range before writing the residual back.
            let out = error.clamp(-32768.0, 32767.0);
            nearend[i] = out as i16;

            sum_near_sq += (near_f as f64).powi(2);
            sum_err_sq += (out as f64).powi(2);
        }

        // Near-silent residual: report the cap instead of dividing by ~0.
        if sum_err_sq < 1.0 {
            100.0
        } else {
            (sum_near_sq / sum_err_sq).sqrt() as f32
        }
    }

    /// Enable or disable cancellation (disabled = passthrough).
    pub fn set_enabled(&mut self, enabled: bool) {
        self.enabled = enabled;
    }

    /// Whether cancellation is currently enabled.
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }

    /// Clear all adaptive state (filter taps, delay buffer, DTD) back to the
    /// freshly-constructed state. Configuration (mu, leak, delay) is kept.
    pub fn reset(&mut self) {
        self.filter.iter_mut().for_each(|c| *c = 0.0);
        self.far_buf.iter_mut().for_each(|s| *s = 0.0);
        self.far_pos = 0;
        self.far_peak = 0.0;
        self.delay_ring.iter_mut().for_each(|s| *s = 0.0);
        self.delay_write = 0;
        self.delay_read = 0;
        self.dtd_holdover = 0;
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Sizing sanity: filter and delay lengths are derived from ms at 48 kHz.
    #[test]
    fn creates_with_correct_sizes() {
        let aec = EchoCanceller::with_delay(48000, 60, 40);
        assert_eq!(aec.filter_len, 2880); // 60ms @ 48kHz
        assert_eq!(aec.delay_samples, 1920); // 40ms @ 48kHz
    }

    /// A disabled canceller must be a byte-exact passthrough.
    #[test]
    fn passthrough_when_disabled() {
        let mut aec = EchoCanceller::new(48000, 60);
        aec.set_enabled(false);

        let original: Vec<i16> = (0..960).map(|i| (i * 10) as i16).collect();
        let mut frame = original.clone();
        aec.process_frame(&mut frame);
        assert_eq!(frame, original);
    }

    /// All-zero near-end with all-zero far-end must stay silent.
    #[test]
    fn silence_passthrough() {
        let mut aec = EchoCanceller::with_delay(48000, 30, 0);
        aec.feed_farend(&vec![0i16; 960]);
        let mut frame = vec![0i16; 960];
        aec.process_frame(&mut frame);
        assert!(frame.iter().all(|&s| s == 0));
    }

    #[test]
    fn reduces_echo_with_no_delay() {
        // Simulate: far-end plays, echo arrives at mic attenuated by ~50%
        // (realistic — speaker to mic on laptop loses volume).
        let mut aec = EchoCanceller::with_delay(48000, 10, 0);

        let frame_len = 480;
        // 300 Hz tone generator; `offset` keeps phase continuous across frames.
        let make_tone = |offset: usize| -> Vec<i16> {
            (0..frame_len)
                .map(|i| {
                    let t = (offset + i) as f64 / 48000.0;
                    (5000.0 * (2.0 * std::f64::consts::PI * 300.0 * t).sin()) as i16
                })
                .collect()
        };

        let mut last_erle = 1.0f32;
        for frame_idx in 0..100 {
            let farend = make_tone(frame_idx * frame_len);
            aec.feed_farend(&farend);

            // Near-end = attenuated copy of far-end (echo at ~50% volume).
            let mut nearend: Vec<i16> = farend.iter().map(|&s| s / 2).collect();
            last_erle = aec.process_frame(&mut nearend);
        }

        // After 100 frames of adaptation, echo-return-loss enhancement
        // (near/error RMS ratio returned by process_frame) must exceed 1.
        assert!(
            last_erle > 1.0,
            "expected ERLE > 1.0 after adaptation, got {last_erle}"
        );
    }

    /// With a silent far-end there is no echo to remove, so near-end speech
    /// energy must survive nearly intact.
    #[test]
    fn preserves_nearend_during_doubletalk() {
        let mut aec = EchoCanceller::with_delay(48000, 30, 0);

        let frame_len = 960;
        let nearend: Vec<i16> = (0..frame_len)
            .map(|i| {
                let t = i as f64 / 48000.0;
                (10000.0 * (2.0 * std::f64::consts::PI * 440.0 * t).sin()) as i16
            })
            .collect();

        // Feed silence as far-end (no echo source).
        aec.feed_farend(&vec![0i16; frame_len]);

        let mut frame = nearend.clone();
        aec.process_frame(&mut frame);

        let input_energy: f64 = nearend.iter().map(|&s| (s as f64).powi(2)).sum();
        let output_energy: f64 = frame.iter().map(|&s| (s as f64).powi(2)).sum();
        let ratio = output_energy / input_energy;

        assert!(
            ratio > 0.8,
            "near-end speech should be preserved, energy ratio = {ratio:.3}"
        );
    }

    #[test]
    fn delay_buffer_holds_samples() {
        let mut aec = EchoCanceller::with_delay(48000, 10, 20);
        // 20ms delay = 960 samples @ 48kHz.
        // After feeding, feed_farend auto-drains available samples to far_buf.
        // So delay_available() is always 0 after feed_farend returns.
        // Instead, verify far_pos advances only after the delay is filled.

        // Feed 960 samples (= delay amount). No samples released yet.
        aec.feed_farend(&vec![1i16; 960]);
        // far_buf should still be all zeros (nothing released).
        assert!(aec.far_buf.iter().all(|&s| s == 0.0), "nothing should be released yet");

        // Feed 480 more. 480 should be released to far_buf.
        aec.feed_farend(&vec![2i16; 480]);
        let non_zero = aec.far_buf.iter().filter(|&&s| s != 0.0).count();
        assert!(non_zero > 0, "samples should have been released to far_buf");
    }
}
|
||||
219
crates/wzp-codec/src/agc.rs
Normal file
219
crates/wzp-codec/src/agc.rs
Normal file
@@ -0,0 +1,219 @@
|
||||
//! Automatic Gain Control (AGC) with two-stage smoothing.
|
||||
//!
|
||||
//! Uses a fast attack / slow release envelope follower to keep the
|
||||
//! output signal near a configurable target RMS level. This prevents
|
||||
//! both clipping (when the speaker is too loud) and inaudibility (when
|
||||
//! the speaker is too quiet or far from the mic).
|
||||
|
||||
/// Two-stage automatic gain control.
///
/// Each frame's RMS is compared against a target level and the applied gain
/// is smoothed towards the level that would hit that target: a fast attack
/// pulls the gain down quickly when the signal gets louder, and a slow
/// release raises it gradually when the signal gets quieter.
pub struct AutoGainControl {
    /// RMS level the controller steers output frames towards.
    target_rms: f64,
    /// Smoothed gain currently applied to samples.
    current_gain: f64,
    /// Lower bound on the applied gain.
    min_gain: f64,
    /// Upper bound on the applied gain.
    max_gain: f64,
    /// Smoothing factor used when the gain must drop (fast attack).
    attack_alpha: f64,
    /// Smoothing factor used when the gain may rise (slow release).
    release_alpha: f64,
    /// When `false`, `process_frame` is a no-op passthrough.
    enabled: bool,
}

impl AutoGainControl {
    /// Create a new AGC with sensible VoIP defaults.
    pub fn new() -> Self {
        Self {
            enabled: true,
            // ~-20 dBFS for i16 samples.
            target_rms: 3000.0,
            current_gain: 1.0,
            min_gain: 0.5,
            max_gain: 32.0,
            // Fast attack, slow release.
            attack_alpha: 0.3,
            release_alpha: 0.02,
        }
    }

    /// Process a frame of PCM audio in-place, applying gain adjustment.
    ///
    /// Near-silent frames (RMS below a small floor) are left untouched so
    /// that background noise is never boosted.
    pub fn process_frame(&mut self, pcm: &mut [i16]) {
        // Hard limit applied after amplification (~0.946 * i16::MAX).
        const LIMIT: f64 = 31000.0;

        if !self.enabled {
            return;
        }

        let frame_rms = Self::compute_rms(pcm);
        // Don't amplify near-silence — it would just boost noise.
        if frame_rms < 10.0 {
            return;
        }

        // Gain that would bring this frame exactly to the target level.
        let wanted = (self.target_rms / frame_rms).clamp(self.min_gain, self.max_gain);

        // Pick the smoothing speed: attack when reducing, release when raising.
        let alpha = if wanted < self.current_gain {
            self.attack_alpha
        } else {
            self.release_alpha
        };
        self.current_gain = self.current_gain * (1.0 - alpha) + wanted * alpha;

        // Apply the smoothed gain with a hard limiter.
        let gain = self.current_gain;
        for sample in pcm.iter_mut() {
            let boosted = (*sample as f64) * gain;
            *sample = boosted.clamp(-LIMIT, LIMIT) as i16;
        }
    }

    /// Enable or disable the AGC.
    pub fn set_enabled(&mut self, enabled: bool) {
        self.enabled = enabled;
    }

    /// Returns whether the AGC is currently enabled.
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }

    /// Current gain expressed in dB (0 dB at unity gain).
    pub fn current_gain_db(&self) -> f64 {
        self.current_gain.log10() * 20.0
    }

    /// Compute the RMS (root mean square) of a PCM buffer; 0.0 when empty.
    fn compute_rms(pcm: &[i16]) -> f64 {
        if pcm.is_empty() {
            return 0.0;
        }
        let energy: f64 = pcm.iter().map(|&s| f64::from(s) * f64::from(s)).sum();
        (energy / pcm.len() as f64).sqrt()
    }
}

impl Default for AutoGainControl {
    /// Equivalent to [`AutoGainControl::new`].
    fn default() -> Self {
        AutoGainControl::new()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Defaults: enabled, unity gain.
    #[test]
    fn agc_creates_with_defaults() {
        let agc = AutoGainControl::new();
        assert!(agc.is_enabled());
        assert!((agc.current_gain - 1.0).abs() < f64::EPSILON);
    }

    /// A disabled AGC must be a byte-exact passthrough.
    #[test]
    fn agc_passthrough_when_disabled() {
        let mut agc = AutoGainControl::new();
        agc.set_enabled(false);

        let original: Vec<i16> = (0..960).map(|i| (i * 5) as i16).collect();
        let mut frame = original.clone();
        agc.process_frame(&mut frame);

        assert_eq!(frame, original);
    }

    /// Silence is below the RMS floor, so neither samples nor gain change.
    #[test]
    fn agc_does_not_amplify_silence() {
        let mut agc = AutoGainControl::new();
        let mut frame = vec![0i16; 960];
        agc.process_frame(&mut frame);
        assert!(frame.iter().all(|&s| s == 0));
        // Gain should remain at initial value.
        assert!((agc.current_gain - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn agc_amplifies_quiet_signal() {
        let mut agc = AutoGainControl::new();

        // Very quiet signal (RMS ~ 50).
        let mut frame: Vec<i16> = (0..960)
            .map(|i| {
                let t = i as f64 / 48000.0;
                (50.0 * (2.0 * std::f64::consts::PI * 440.0 * t).sin()) as i16
            })
            .collect();

        // Process several frames to let the gain ramp up (slow release).
        for _ in 0..50 {
            let mut f = frame.clone();
            agc.process_frame(&mut f);
            frame = f;
        }

        // Gain should have increased past 1.0.
        assert!(
            agc.current_gain > 1.05,
            "expected gain > 1.05 for quiet signal, got {}",
            agc.current_gain
        );
    }

    #[test]
    fn agc_attenuates_loud_signal() {
        let mut agc = AutoGainControl::new();

        // Loud signal (RMS ~ 20000).
        let frame: Vec<i16> = (0..960)
            .map(|i| {
                let t = i as f64 / 48000.0;
                (28000.0 * (2.0 * std::f64::consts::PI * 440.0 * t).sin()) as i16
            })
            .collect();

        // Process several frames (fast attack should react quickly).
        for _ in 0..20 {
            let mut f = frame.clone();
            agc.process_frame(&mut f);
        }

        // Gain should have decreased below 1.0.
        assert!(
            agc.current_gain < 1.0,
            "expected gain < 1.0 for loud signal, got {}",
            agc.current_gain
        );
    }

    /// Even with a ramped-up gain, the hard limiter caps output at ±31000.
    #[test]
    fn agc_output_within_limits() {
        let mut agc = AutoGainControl::new();
        // Force a high gain by processing many quiet frames first.
        for _ in 0..100 {
            let mut f: Vec<i16> = vec![100; 960];
            agc.process_frame(&mut f);
        }

        // Now send a louder frame — output should still be within ±31000.
        let mut frame: Vec<i16> = vec![20000; 960];
        agc.process_frame(&mut frame);
        assert!(
            frame.iter().all(|&s| s.abs() <= 31000),
            "output samples must be within ±31000"
        );
    }

    /// Unity gain corresponds to 0 dB.
    #[test]
    fn agc_gain_db_at_unity() {
        let agc = AutoGainControl::new();
        let db = agc.current_gain_db();
        assert!(
            db.abs() < 0.01,
            "expected ~0 dB at unity gain, got {db}"
        );
    }
}
|
||||
183
crates/wzp-codec/src/denoise.rs
Normal file
183
crates/wzp-codec/src/denoise.rs
Normal file
@@ -0,0 +1,183 @@
|
||||
//! ML-based noise suppression using nnnoiseless (pure-Rust RNNoise port).
|
||||
//!
|
||||
//! RNNoise operates on 480-sample frames at 48 kHz (10 ms). Our codec pipeline
|
||||
//! uses 960-sample frames (20 ms), so each call processes two halves.
|
||||
|
||||
use nnnoiseless::DenoiseState;
|
||||
|
||||
/// Wraps [`DenoiseState`] to provide noise suppression on 960-sample (20 ms) PCM
/// frames at 48 kHz.
///
/// NOTE(review): the name keeps the existing (misspelled) public identifier
/// `NoiseSupressor` for API compatibility — it is re-exported elsewhere.
pub struct NoiseSupressor {
    /// RNNoise denoiser state from the `nnnoiseless` crate.
    state: Box<DenoiseState<'static>>,
    /// When `false`, `process` is a no-op passthrough.
    enabled: bool,
}
|
||||
|
||||
impl NoiseSupressor {
|
||||
/// Create a new noise suppressor (enabled by default).
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
state: DenoiseState::new(),
|
||||
enabled: true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Process a 960-sample frame of 48 kHz mono PCM **in place**.
|
||||
///
|
||||
/// nnnoiseless expects f32 samples in the range roughly [-32768, 32767].
|
||||
/// We convert i16 → f32, process two 480-sample halves, then convert back.
|
||||
pub fn process(&mut self, pcm: &mut [i16]) {
|
||||
if !self.enabled {
|
||||
return;
|
||||
}
|
||||
|
||||
debug_assert!(
|
||||
pcm.len() >= 960,
|
||||
"NoiseSupressor::process expects at least 960 samples, got {}",
|
||||
pcm.len()
|
||||
);
|
||||
|
||||
// Process in two 480-sample halves.
|
||||
for half in 0..2 {
|
||||
let offset = half * 480;
|
||||
let end = offset + 480;
|
||||
if end > pcm.len() {
|
||||
break;
|
||||
}
|
||||
|
||||
// i16 → f32
|
||||
let mut float_buf = [0.0f32; 480];
|
||||
for (i, &sample) in pcm[offset..end].iter().enumerate() {
|
||||
float_buf[i] = sample as f32;
|
||||
}
|
||||
|
||||
// nnnoiseless processes in-place, returns VAD probability (unused here).
|
||||
let mut output = [0.0f32; 480];
|
||||
let _vad = self.state.process_frame(&mut output, &float_buf);
|
||||
|
||||
// f32 → i16 with clamping
|
||||
for (i, &val) in output.iter().enumerate() {
|
||||
let clamped = val.max(-32768.0).min(32767.0);
|
||||
pcm[offset + i] = clamped as i16;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Enable or disable noise suppression.
|
||||
pub fn set_enabled(&mut self, enabled: bool) {
|
||||
self.enabled = enabled;
|
||||
}
|
||||
|
||||
/// Returns `true` if noise suppression is currently enabled.
|
||||
pub fn is_enabled(&self) -> bool {
|
||||
self.enabled
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for NoiseSupressor {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Construction succeeds and the suppressor starts enabled.
    #[test]
    fn denoiser_creates() {
        let ns = NoiseSupressor::new();
        assert!(ns.is_enabled());
    }

    /// In-place processing must not change buffer length.
    #[test]
    fn denoiser_processes_frame() {
        let mut ns = NoiseSupressor::new();
        let mut pcm = vec![0i16; 960];
        // Fill with a simple pattern so we have something to process.
        for (i, s) in pcm.iter_mut().enumerate() {
            *s = ((i % 100) as i16).wrapping_mul(100);
        }
        let original_len = pcm.len();
        ns.process(&mut pcm);
        assert_eq!(pcm.len(), original_len, "output length must match input length");
    }

    #[test]
    fn denoiser_reduces_noise() {
        let mut ns = NoiseSupressor::new();

        // Generate a 440 Hz sine tone + white noise at 48 kHz.
        // We need multiple frames for the RNN to converge.
        let sample_rate = 48000.0f64;
        let freq = 440.0f64;
        let amplitude = 10000.0f64;
        let noise_amplitude = 3000.0f64;

        // Use a simple PRNG for reproducibility.
        let mut rng_state: u32 = 12345;
        let mut next_noise = || -> f64 {
            // xorshift32
            rng_state ^= rng_state << 13;
            rng_state ^= rng_state >> 17;
            rng_state ^= rng_state << 5;
            // Map to [-1, 1]
            (rng_state as f64 / u32::MAX as f64) * 2.0 - 1.0
        };

        // Feed several frames to let the RNN warm up, then measure the last one.
        let num_warmup_frames = 20;
        let mut last_input = vec![0i16; 960];
        let mut last_output = vec![0i16; 960];

        for frame_idx in 0..=num_warmup_frames {
            let mut pcm = vec![0i16; 960];
            for (i, s) in pcm.iter_mut().enumerate() {
                // Keep the tone phase continuous across frames.
                let t = (frame_idx * 960 + i) as f64 / sample_rate;
                let sine = amplitude * (2.0 * std::f64::consts::PI * freq * t).sin();
                let noise = noise_amplitude * next_noise();
                *s = (sine + noise).max(-32768.0).min(32767.0) as i16;
            }

            if frame_idx == num_warmup_frames {
                last_input = pcm.clone();
            }

            ns.process(&mut pcm);

            if frame_idx == num_warmup_frames {
                last_output = pcm;
            }
        }

        // Compute RMS of input and output.
        let rms = |buf: &[i16]| -> f64 {
            let sum: f64 = buf.iter().map(|&s| (s as f64) * (s as f64)).sum();
            (sum / buf.len() as f64).sqrt()
        };

        let input_rms = rms(&last_input);
        let output_rms = rms(&last_output);

        // The denoiser should not amplify the signal beyond input.
        // More importantly, the output should have measurably lower noise.
        // We verify the output RMS is less than the input RMS (noise was reduced).
        assert!(
            output_rms < input_rms,
            "expected output RMS ({output_rms:.1}) < input RMS ({input_rms:.1}); \
             denoiser should reduce noise"
        );
    }

    /// A disabled suppressor must be a byte-exact passthrough.
    #[test]
    fn denoiser_passthrough_when_disabled() {
        let mut ns = NoiseSupressor::new();
        ns.set_enabled(false);
        assert!(!ns.is_enabled());

        let original: Vec<i16> = (0..960).map(|i| (i * 10) as i16).collect();
        let mut pcm = original.clone();
        ns.process(&mut pcm);

        assert_eq!(pcm, original, "disabled denoiser must not alter input");
    }
}
|
||||
@@ -10,13 +10,21 @@
|
||||
//! trait-object encoders/decoders that handle adaptive switching internally.
|
||||
|
||||
pub mod adaptive;
|
||||
pub mod aec;
|
||||
pub mod agc;
|
||||
pub mod codec2_dec;
|
||||
pub mod codec2_enc;
|
||||
pub mod denoise;
|
||||
pub mod opus_dec;
|
||||
pub mod opus_enc;
|
||||
pub mod resample;
|
||||
pub mod silence;
|
||||
|
||||
pub use adaptive::{AdaptiveDecoder, AdaptiveEncoder};
|
||||
pub use aec::EchoCanceller;
|
||||
pub use agc::AutoGainControl;
|
||||
pub use denoise::NoiseSupressor;
|
||||
pub use silence::{ComfortNoise, SilenceDetector};
|
||||
pub use wzp_proto::{AudioDecoder, AudioEncoder, CodecId, QualityProfile};
|
||||
|
||||
/// Create an adaptive encoder starting at the given quality profile.
|
||||
|
||||
@@ -40,6 +40,11 @@ impl OpusEncoder {
|
||||
.set_signal(Signal::Voice)
|
||||
.map_err(|e| CodecError::EncodeFailed(format!("set signal: {e}")))?;
|
||||
|
||||
// Default complexity 7 — good quality/CPU trade-off for VoIP
|
||||
enc.inner
|
||||
.set_complexity(7)
|
||||
.map_err(|e| CodecError::EncodeFailed(format!("set complexity: {e}")))?;
|
||||
|
||||
Ok(enc)
|
||||
}
|
||||
|
||||
@@ -56,6 +61,21 @@ impl OpusEncoder {
|
||||
/// Number of PCM samples in one frame at the fixed 48 kHz rate
/// (e.g. 20 ms → 960 samples).
pub fn frame_samples(&self) -> usize {
    (48_000 * self.frame_duration_ms as usize) / 1000
}
|
||||
|
||||
/// Set the encoder complexity (0-10). Higher values produce better quality
|
||||
/// at the cost of more CPU. Default is 7.
|
||||
pub fn set_complexity(&mut self, complexity: i32) {
|
||||
let c = (complexity as u8).min(10);
|
||||
let _ = self.inner.set_complexity(c);
|
||||
}
|
||||
|
||||
/// Hint the encoder about expected packet loss percentage (0-100).
|
||||
///
|
||||
/// Higher values cause the encoder to use more redundancy to survive
|
||||
/// packet loss, at the expense of slightly higher bitrate.
|
||||
pub fn set_expected_loss(&mut self, loss_pct: u8) {
|
||||
let _ = self.inner.set_packet_loss_perc(loss_pct.min(100));
|
||||
}
|
||||
}
|
||||
|
||||
impl AudioEncoder for OpusEncoder {
|
||||
|
||||
@@ -1,55 +1,258 @@
|
||||
//! Windowed-sinc FIR resampler for 48 kHz <-> 8 kHz conversion.
//!
//! Provides both stateless free functions (backward-compatible) and stateful
//! `Downsampler48to8` / `Upsampler8to48` structs that maintain overlap history
//! between frames for glitch-free streaming.
|
||||
|
||||
/// Downsample from 48 kHz to 8 kHz (6:1 decimation with averaging).
|
||||
use std::f64::consts::PI;
|
||||
|
||||
// ─── FIR kernel parameters ─────────────────────────────────────────────────

/// Number of FIR taps in the anti-alias / interpolation filter.
const FIR_TAPS: usize = 48;
/// Kaiser window beta parameter — controls sidelobe attenuation.
const KAISER_BETA: f64 = 8.0;
/// Cutoff frequency in Hz for the low-pass filter (just below the 4 kHz
/// Nyquist limit of the 8 kHz rate).
const CUTOFF_HZ: f64 = 3800.0;
/// Working sample rate in Hz.
const SAMPLE_RATE: f64 = 48000.0;
/// Decimation / interpolation ratio between 48 kHz and 8 kHz.
const RATIO: usize = 6;

// ─── Kaiser window helpers ─────────────────────────────────────────────────

/// Zeroth-order modified Bessel function of the first kind, I₀(x).
///
/// Evaluated with the standard power series Σ ((x/2)^k / k!)², which
/// converges rapidly for the moderate arguments used in Kaiser design.
fn bessel_i0(x: f64) -> f64 {
    let half = x / 2.0;
    let mut acc = 1.0f64;
    let mut term = 1.0f64;
    for k in 1..=25 {
        term *= (half / k as f64) * (half / k as f64);
        acc += term;
        // Stop once further terms no longer change the sum materially.
        if term < 1e-12 * acc {
            break;
        }
    }
    acc
}

/// Build a windowed-sinc low-pass FIR kernel.
///
/// Returns `FIR_TAPS` coefficients normalised so that the DC gain is exactly 1.0.
fn build_fir_kernel() -> [f64; FIR_TAPS] {
    let m = (FIR_TAPS - 1) as f64;
    let fc = CUTOFF_HZ / SAMPLE_RATE; // normalised cutoff (0..0.5)
    let window_norm = bessel_i0(KAISER_BETA);

    let mut kernel = [0.0f64; FIR_TAPS];
    for (i, coeff) in kernel.iter_mut().enumerate() {
        // Ideal low-pass impulse response (sinc), centred on the kernel.
        let n = i as f64 - m / 2.0;
        let sinc = if n.abs() < 1e-12 {
            2.0 * fc
        } else {
            (2.0 * PI * fc * n).sin() / (PI * n)
        };

        // Kaiser window sample for this tap; t spans [-1, 1].
        let t = 2.0 * i as f64 / m - 1.0;
        let kaiser = bessel_i0(KAISER_BETA * (1.0 - t * t).max(0.0).sqrt()) / window_norm;

        *coeff = sinc * kaiser;
    }

    // Normalise to unity DC gain.
    let dc: f64 = kernel.iter().sum();
    if dc.abs() > 1e-15 {
        for c in kernel.iter_mut() {
            *c /= dc;
        }
    }

    kernel
}
|
||||
|
||||
// ─── Stateful Downsampler 48→8 ─────────────────────────────────────────────

/// Stateful polyphase FIR downsampler from 48 kHz to 8 kHz.
///
/// Maintains `FIR_TAPS - 1` samples of history between successive calls to
/// `process()` for seamless frame boundaries.
pub struct Downsampler48to8 {
    /// Windowed-sinc low-pass coefficients shared by all phases.
    kernel: [f64; FIR_TAPS],
    /// Last `FIR_TAPS - 1` input samples from the previous call.
    history: Vec<f64>,
}

impl Downsampler48to8 {
    /// Create a downsampler with a freshly built kernel and zeroed history.
    pub fn new() -> Self {
        Self {
            kernel: build_fir_kernel(),
            history: vec![0.0; FIR_TAPS - 1],
        }
    }

    /// Downsample a block of 48 kHz samples to 8 kHz.
    ///
    /// The input length should be a multiple of 6; any trailing samples that
    /// don't form a complete output sample are consumed into the history.
    pub fn process(&mut self, input: &[i16]) -> Vec<i16> {
        let hist_len = self.history.len(); // FIR_TAPS - 1
        let total_len = hist_len + input.len();

        // Build a working buffer: history ++ input (as f64).
        let mut work = Vec::with_capacity(total_len);
        work.extend_from_slice(&self.history);
        work.extend(input.iter().map(|&s| s as f64));

        let out_len = input.len() / RATIO;
        let mut output = Vec::with_capacity(out_len);

        for i in 0..out_len {
            // The centre of the filter for output sample i sits at
            // position hist_len + i*RATIO in the work buffer (aligning
            // with the first new input sample at decimation phase 0).
            let centre = hist_len + i * RATIO;
            let start = centre + 1 - FIR_TAPS; // may be 0 for the first few

            // Dot product of the kernel with the FIR_TAPS samples ending at
            // `centre`; the idx guard skips taps past the buffer end.
            let mut acc = 0.0f64;
            for k in 0..FIR_TAPS {
                let idx = start + k;
                if idx < work.len() {
                    acc += work[idx] * self.kernel[k];
                }
            }
            // Round and saturate back into the i16 range.
            output.push(acc.round().clamp(-32768.0, 32767.0) as i16);
        }

        // Update history: keep the last (FIR_TAPS - 1) samples from work.
        if work.len() >= hist_len {
            self.history
                .copy_from_slice(&work[work.len() - hist_len..]);
        } else {
            // Input was shorter than history — shift.
            // NOTE(review): `work` always starts with the full history, so
            // work.len() >= hist_len always holds and this branch appears
            // unreachable — confirm before relying on it.
            let shift = hist_len - work.len();
            self.history.copy_within(shift.., 0);
            for (i, &v) in work.iter().enumerate() {
                self.history[hist_len - work.len() + i] = v;
            }
        }

        output
    }
}

impl Default for Downsampler48to8 {
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
// ─── Stateful Upsampler 8→48 ───────────────────────────────────────────────

/// Stateful FIR upsampler from 8 kHz to 48 kHz.
///
/// Inserts zeros between input samples (zero-stuffing), then applies the
/// low-pass FIR to remove imaging, with gain compensation of `RATIO`.
pub struct Upsampler8to48 {
    /// Windowed-sinc low-pass coefficients (same kernel as the downsampler).
    kernel: [f64; FIR_TAPS],
    /// Last `FIR_TAPS - 1` zero-stuffed samples from the previous call.
    history: Vec<f64>,
}

impl Upsampler8to48 {
    /// Create an upsampler with a freshly built kernel and zeroed history.
    pub fn new() -> Self {
        Self {
            kernel: build_fir_kernel(),
            history: vec![0.0; FIR_TAPS - 1],
        }
    }

    /// Upsample a block of 8 kHz samples to 48 kHz.
    pub fn process(&mut self, input: &[i16]) -> Vec<i16> {
        let hist_len = self.history.len(); // FIR_TAPS - 1

        // Zero-stuff: insert RATIO-1 zeros between each input sample.
        let stuffed_len = input.len() * RATIO;
        let total_len = hist_len + stuffed_len;

        let mut work = Vec::with_capacity(total_len);
        work.extend_from_slice(&self.history);
        for &s in input {
            work.push(s as f64);
            for _ in 1..RATIO {
                work.push(0.0);
            }
        }

        let out_len = stuffed_len;
        let mut output = Vec::with_capacity(out_len);

        // The gain factor compensates for the zeros introduced by stuffing.
        let gain = RATIO as f64;

        for i in 0..out_len {
            let centre = hist_len + i;
            let start = centre + 1 - FIR_TAPS;

            // Dot product of the kernel with the FIR_TAPS stuffed samples
            // ending at `centre`; the idx guard skips taps past the end.
            let mut acc = 0.0f64;
            for k in 0..FIR_TAPS {
                let idx = start + k;
                if idx < work.len() {
                    acc += work[idx] * self.kernel[k];
                }
            }
            acc *= gain;
            // Round and saturate back into the i16 range.
            output.push(acc.round().clamp(-32768.0, 32767.0) as i16);
        }

        // Update history.
        if work.len() >= hist_len {
            self.history
                .copy_from_slice(&work[work.len() - hist_len..]);
        } else {
            // NOTE(review): `work` always starts with the full history, so
            // work.len() >= hist_len always holds and this branch appears
            // unreachable — confirm before relying on it.
            let shift = hist_len - work.len();
            self.history.copy_within(shift.., 0);
            for (i, &v) in work.iter().enumerate() {
                self.history[hist_len - work.len() + i] = v;
            }
        }

        output
    }
}

impl Default for Upsampler8to48 {
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
// ─── Backward-compatible free functions ─────────────────────────────────────
|
||||
|
||||
/// Downsample from 48 kHz to 8 kHz (6:1 decimation with FIR anti-alias filter).
|
||||
///
|
||||
/// This is a convenience wrapper that creates a temporary [`Downsampler48to8`].
|
||||
/// For streaming use, prefer the stateful struct to avoid edge artefacts between
|
||||
/// frames.
|
||||
pub fn resample_48k_to_8k(input: &[i16]) -> Vec<i16> {
|
||||
let mut ds = Downsampler48to8::new();
|
||||
ds.process(input)
|
||||
}
|
||||
|
||||
/// Upsample from 8 kHz to 48 kHz (1:6 interpolation with FIR imaging filter).
|
||||
///
|
||||
/// This is a convenience wrapper that creates a temporary [`Upsampler8to48`].
|
||||
/// For streaming use, prefer the stateful struct to avoid edge artefacts between
|
||||
/// frames.
|
||||
pub fn resample_8k_to_48k(input: &[i16]) -> Vec<i16> {
|
||||
let mut us = Upsampler8to48::new();
|
||||
us.process(input)
|
||||
}
|
||||
|
||||
// ─── Tests ──────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -66,12 +269,28 @@ mod tests {
|
||||
|
||||
#[test]
fn dc_signal_preserved() {
    // A constant (DC) signal should survive resampling approximately; the
    // FIR introduces edge transients, so only the middle half is checked.
    let input = vec![1000i16; 960];

    let down = resample_48k_to_8k(&input);
    let lo = down.len() / 4;
    let hi = 3 * down.len() / 4;
    for &s in &down[lo..hi] {
        assert!(
            (s - 1000).abs() < 50,
            "DC downsampled sample {s} too far from 1000"
        );
    }

    let up = resample_8k_to_48k(&down);
    let lo_up = up.len() / 4;
    let hi_up = 3 * up.len() / 4;
    for &s in &up[lo_up..hi_up] {
        assert!(
            (s - 1000).abs() < 100,
            "DC upsampled sample {s} too far from 1000"
        );
    }
}
|
||||
|
||||
#[test]
|
||||
@@ -79,4 +298,40 @@ mod tests {
|
||||
assert!(resample_48k_to_8k(&[]).is_empty());
|
||||
assert!(resample_8k_to_48k(&[]).is_empty());
|
||||
}
|
||||
|
||||
/// Streaming downsampler keeps the 6:1 length ratio across calls.
#[test]
fn stateful_downsampler_produces_correct_length() {
    let mut ds = Downsampler48to8::new();
    let out = ds.process(&vec![0i16; 960]);
    assert_eq!(out.len(), 160);
    let out2 = ds.process(&vec![0i16; 960]);
    assert_eq!(out2.len(), 160);
}

/// Streaming upsampler keeps the 1:6 length ratio across calls.
#[test]
fn stateful_upsampler_produces_correct_length() {
    let mut us = Upsampler8to48::new();
    let out = us.process(&vec![0i16; 160]);
    assert_eq!(out.len(), 960);
    let out2 = us.process(&vec![0i16; 160]);
    assert_eq!(out2.len(), 960);
}

/// The kernel is explicitly normalised, so its coefficients sum to 1.
#[test]
fn fir_kernel_has_unity_dc_gain() {
    let kernel = build_fir_kernel();
    let sum: f64 = kernel.iter().sum();
    assert!(
        (sum - 1.0).abs() < 1e-10,
        "FIR kernel DC gain should be 1.0, got {sum}"
    );
}

/// Spot-check the Bessel series against tabulated values.
#[test]
fn bessel_i0_known_values() {
    // I₀(0) = 1
    assert!((bessel_i0(0.0) - 1.0).abs() < 1e-12);
    // I₀(1) ≈ 1.2660658
    assert!((bessel_i0(1.0) - 1.2660658).abs() < 1e-5);
}
|
||||
}
|
||||
|
||||
191
crates/wzp-codec/src/silence.rs
Normal file
191
crates/wzp-codec/src/silence.rs
Normal file
@@ -0,0 +1,191 @@
|
||||
//! Silence suppression and comfort noise generation.
|
||||
//!
|
||||
//! During silent periods (~50% of a typical call), full encoded frames waste
|
||||
//! bandwidth. [`SilenceDetector`] detects silent audio based on RMS energy,
|
||||
//! and [`ComfortNoise`] generates low-level background noise to fill gaps on
|
||||
//! the decoder side.
|
||||
|
||||
use rand::Rng;
|
||||
|
||||
/// Detects silence in PCM audio using RMS energy with a hangover period.
///
/// The hangover prevents clipping the onset of speech: after silence is first
/// detected, the detector continues reporting "not silent" for `hangover_frames`
/// additional frames before transitioning to suppression.
///
/// State is advanced by calling `is_silent` once per frame.
pub struct SilenceDetector {
    /// RMS threshold below which audio is considered silent (for i16 samples).
    threshold_rms: f64,
    /// Number of frames to keep sending after silence starts (prevents speech clipping).
    hangover_frames: u32,
    /// Count of consecutive frames whose RMS is below the threshold.
    silent_frames: u32,
    /// Whether suppression is currently active.
    is_suppressing: bool,
}
|
||||
|
||||
impl SilenceDetector {
|
||||
/// Create a new silence detector.
|
||||
///
|
||||
/// * `threshold_rms` — RMS energy below which a frame is silent (default: 100.0 for i16).
|
||||
/// * `hangover_frames` — frames to keep sending after silence onset (default: 5 = 100ms at 20ms frames).
|
||||
pub fn new(threshold_rms: f64, hangover_frames: u32) -> Self {
|
||||
Self {
|
||||
threshold_rms,
|
||||
hangover_frames,
|
||||
silent_frames: 0,
|
||||
is_suppressing: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the RMS (root mean square) energy of a PCM buffer.
|
||||
pub fn rms(pcm: &[i16]) -> f64 {
|
||||
if pcm.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
let sum_sq: f64 = pcm.iter().map(|&s| (s as f64) * (s as f64)).sum();
|
||||
(sum_sq / pcm.len() as f64).sqrt()
|
||||
}
|
||||
|
||||
/// Returns `true` if the frame should be suppressed (i.e. is silence past
|
||||
/// the hangover period).
|
||||
///
|
||||
/// Call once per frame. The detector tracks consecutive silent frames
|
||||
/// internally and only reports suppression after the hangover expires.
|
||||
pub fn is_silent(&mut self, pcm: &[i16]) -> bool {
|
||||
let energy = Self::rms(pcm);
|
||||
|
||||
if energy < self.threshold_rms {
|
||||
self.silent_frames = self.silent_frames.saturating_add(1);
|
||||
|
||||
if self.silent_frames > self.hangover_frames {
|
||||
self.is_suppressing = true;
|
||||
}
|
||||
} else {
|
||||
// Speech detected — reset.
|
||||
self.silent_frames = 0;
|
||||
self.is_suppressing = false;
|
||||
}
|
||||
|
||||
self.is_suppressing
|
||||
}
|
||||
|
||||
/// Whether the detector is currently in the suppressing state.
|
||||
pub fn suppressing(&self) -> bool {
|
||||
self.is_suppressing
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates low-level comfort noise to fill silent periods.
|
||||
///
|
||||
/// When the decoder receives a comfort-noise descriptor (or detects a gap
|
||||
/// caused by silence suppression), it uses this to produce a natural-sounding
|
||||
/// background hiss instead of dead silence.
|
||||
pub struct ComfortNoise {
|
||||
/// Peak amplitude of the generated noise (default: 50).
|
||||
level: i16,
|
||||
}
|
||||
|
||||
impl ComfortNoise {
|
||||
/// Create a comfort noise generator with the given amplitude level.
|
||||
pub fn new(level: i16) -> Self {
|
||||
Self { level }
|
||||
}
|
||||
|
||||
/// Fill `pcm` with low-level random noise in the range `[-level, level]`.
|
||||
pub fn generate(&self, pcm: &mut [i16]) {
|
||||
let mut rng = rand::thread_rng();
|
||||
for sample in pcm.iter_mut() {
|
||||
*sample = rng.gen_range(-self.level..=self.level);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Build one 20 ms frame (960 samples @48 kHz) of a sine wave.
    fn sine_frame(freq_hz: f64, amplitude: f64) -> Vec<i16> {
        (0..960)
            .map(|i| {
                let t = i as f64 / 48000.0;
                (amplitude * (2.0 * std::f64::consts::PI * freq_hz * t).sin()) as i16
            })
            .collect()
    }

    #[test]
    fn silence_detector_detects_silence() {
        let mut det = SilenceDetector::new(100.0, 5);
        let silence = vec![0i16; 960];

        // The 5-frame hangover window must pass before suppression starts.
        for _ in 0..5 {
            assert!(!det.is_silent(&silence));
        }
        // Past the hangover: every further silent frame is suppressed.
        assert!(det.is_silent(&silence));
        assert!(det.is_silent(&silence));
    }

    #[test]
    fn silence_detector_detects_speech() {
        let mut det = SilenceDetector::new(100.0, 5);
        // A 1 kHz tone at decent amplitude must never be classified as silence.
        let speech = sine_frame(1000.0, 10000.0);

        for _ in 0..20 {
            assert!(!det.is_silent(&speech));
        }
    }

    #[test]
    fn silence_detector_hangover() {
        let mut det = SilenceDetector::new(100.0, 3);
        let silence = vec![0i16; 960];
        let speech = sine_frame(440.0, 5000.0);

        // Push enough silent frames to exhaust the hangover and enter suppression.
        for _ in 0..4 {
            det.is_silent(&silence);
        }
        assert!(det.is_silent(&silence), "should be suppressing after hangover");

        // Speech must cancel suppression immediately.
        assert!(!det.is_silent(&speech));
        assert!(!det.is_silent(&speech));
    }

    #[test]
    fn comfort_noise_generates_nonzero() {
        let mut pcm = vec![0i16; 960];
        ComfortNoise::new(50).generate(&mut pcm);

        // At least some samples should be non-zero...
        assert!(pcm.iter().any(|&s| s != 0), "CN output should not be all zeros");
        // ...and all must stay within [-50, 50].
        assert!(pcm.iter().all(|&s| s.abs() <= 50), "CN samples out of range");
    }

    #[test]
    fn rms_calculation() {
        // All zeros → RMS 0.
        assert_eq!(SilenceDetector::rms(&[0i16; 100]), 0.0);

        // Constant input: RMS equals the absolute value of the constant.
        let flat = SilenceDetector::rms(&vec![100i16; 100]);
        assert!((flat - 100.0).abs() < 0.01, "RMS of constant 100 should be 100, got {flat}");

        // Known pattern: [3, 4] → sqrt((9 + 16) / 2) ≈ 3.5355.
        let pair = SilenceDetector::rms(&[3, 4]);
        assert!((pair - 3.5355).abs() < 0.01, "RMS of [3,4] should be ~3.5355, got {pair}");

        // Empty buffer → 0.
        assert_eq!(SilenceDetector::rms(&[]), 0.0);
    }
}
|
||||
@@ -15,5 +15,18 @@ hkdf = { workspace = true }
|
||||
sha2 = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
bip39 = "2"
|
||||
hex = "0.4"
|
||||
|
||||
# featherChat identity — the source of truth for Seed, IdentityKeyPair, Fingerprint
|
||||
warzone-protocol = { path = "../../deps/featherchat/warzone/crates/warzone-protocol" }
|
||||
|
||||
[dev-dependencies]
|
||||
ed25519-dalek = { workspace = true }
|
||||
warzone-protocol = { path = "../../deps/featherchat/warzone/crates/warzone-protocol" }
|
||||
wzp-proto = { workspace = true }
|
||||
wzp-client = { path = "../wzp-client" }
|
||||
wzp-relay = { path = "../wzp-relay" }
|
||||
serde_json = "1"
|
||||
serde = { workspace = true }
|
||||
bincode = "1"
|
||||
|
||||
281
crates/wzp-crypto/src/identity.rs
Normal file
281
crates/wzp-crypto/src/identity.rs
Normal file
@@ -0,0 +1,281 @@
|
||||
//! featherChat-compatible identity module.
|
||||
//!
|
||||
//! Mirrors `warzone-protocol/src/identity.rs` and `warzone-protocol/src/mnemonic.rs`
|
||||
//! from featherChat. Same seed → same keys → same fingerprint in both codebases.
|
||||
//!
|
||||
//! Source of truth: deps/featherchat/warzone/crates/warzone-protocol/src/identity.rs
|
||||
|
||||
use ed25519_dalek::{SigningKey, VerifyingKey};
|
||||
use hkdf::Hkdf;
|
||||
use sha2::{Digest, Sha256};
|
||||
use x25519_dalek::StaticSecret;
|
||||
|
||||
/// The root secret — 32 bytes from which all keys are derived.
/// Displayed to users as a BIP39 mnemonic (24 words).
///
/// Mirrors: `warzone-protocol::identity::Seed`
///
/// Note: no `Debug`/`Clone` derives — presumably deliberate to avoid
/// printing or duplicating key material; the bytes are zeroed when the
/// value is dropped (see the `Drop` impl in this module).
pub struct Seed(pub [u8; 32]);
|
||||
|
||||
impl Seed {
    /// Generate a new random seed.
    ///
    /// Uses the OS CSPRNG (`OsRng`), so the result is suitable as long-term
    /// identity material.
    pub fn generate() -> Self {
        let mut bytes = [0u8; 32];
        rand::RngCore::fill_bytes(&mut rand::rngs::OsRng, &mut bytes);
        Seed(bytes)
    }

    /// Create seed from raw bytes.
    pub fn from_bytes(bytes: [u8; 32]) -> Self {
        Seed(bytes)
    }

    /// Create seed from hex string (64 hex chars).
    ///
    /// Returns an error string if the input is not valid hex or does not
    /// decode to exactly 32 bytes.
    pub fn from_hex(hex_str: &str) -> Result<Self, String> {
        let bytes = hex::decode(hex_str).map_err(|e| format!("invalid hex: {e}"))?;
        if bytes.len() != 32 {
            return Err(format!("expected 32 bytes, got {}", bytes.len()));
        }
        let mut seed = [0u8; 32];
        seed.copy_from_slice(&bytes);
        // NOTE(review): the intermediate `bytes` Vec is dropped without being
        // zeroized — confirm whether that matters for this threat model.
        Ok(Seed(seed))
    }

    /// Derive the full identity keypair from this seed.
    ///
    /// Uses identical HKDF derivation as featherChat:
    /// - Ed25519: `HKDF(seed, salt=None, info="warzone-ed25519")`
    /// - X25519: `HKDF(seed, salt=None, info="warzone-x25519")`
    ///
    /// The intermediate key bytes are overwritten with zeros once the key
    /// objects have taken their own copies.
    pub fn derive_identity(&self) -> IdentityKeyPair {
        // Salt is deliberately None to match featherChat's derivation exactly.
        let hk = Hkdf::<Sha256>::new(None, &self.0);

        let mut ed_bytes = [0u8; 32];
        hk.expand(b"warzone-ed25519", &mut ed_bytes)
            .expect("HKDF expand for Ed25519");
        let signing = SigningKey::from_bytes(&ed_bytes);
        ed_bytes.fill(0);

        let mut x_bytes = [0u8; 32];
        hk.expand(b"warzone-x25519", &mut x_bytes)
            .expect("HKDF expand for X25519");
        let encryption = StaticSecret::from(x_bytes);
        x_bytes.fill(0);

        IdentityKeyPair {
            signing,
            encryption,
        }
    }

    /// Convert to BIP39 mnemonic (24 words).
    ///
    /// Mirrors: `warzone-protocol::mnemonic::seed_to_mnemonic`
    pub fn to_mnemonic(&self) -> String {
        // 32 bytes of entropy always maps to a valid 24-word phrase.
        let mnemonic =
            bip39::Mnemonic::from_entropy(&self.0).expect("32 bytes is valid BIP39 entropy");
        mnemonic.to_string()
    }

    /// Recover seed from BIP39 mnemonic (24 words).
    ///
    /// Mirrors: `warzone-protocol::mnemonic::mnemonic_to_seed`
    ///
    /// Returns an error string when the phrase is not valid BIP39 or does not
    /// carry exactly 32 bytes of entropy (i.e. is not a 24-word phrase).
    pub fn from_mnemonic(words: &str) -> Result<Self, String> {
        let mnemonic: bip39::Mnemonic = words.parse().map_err(|e| format!("invalid mnemonic: {e}"))?;
        let entropy = mnemonic.to_entropy();
        if entropy.len() != 32 {
            return Err(format!("expected 32 bytes entropy, got {}", entropy.len()));
        }
        let mut seed = [0u8; 32];
        seed.copy_from_slice(&entropy);
        Ok(Seed(seed))
    }
}
|
||||
|
||||
impl Drop for Seed {
|
||||
fn drop(&mut self) {
|
||||
self.0.fill(0); // zeroize on drop
|
||||
}
|
||||
}
|
||||
|
||||
/// The full identity keypair derived from a seed.
|
||||
///
|
||||
/// Mirrors: `warzone-protocol::identity::IdentityKeyPair`
|
||||
pub struct IdentityKeyPair {
|
||||
pub signing: SigningKey,
|
||||
pub encryption: StaticSecret,
|
||||
}
|
||||
|
||||
impl IdentityKeyPair {
|
||||
/// Get the public identity (safe to share).
|
||||
pub fn public_identity(&self) -> PublicIdentity {
|
||||
let verifying = self.signing.verifying_key();
|
||||
let encryption_pub = x25519_dalek::PublicKey::from(&self.encryption);
|
||||
let fingerprint = Fingerprint::from_verifying_key(&verifying);
|
||||
|
||||
PublicIdentity {
|
||||
signing: verifying,
|
||||
encryption: encryption_pub,
|
||||
fingerprint,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Truncated SHA-256 hash of the Ed25519 public key (16 bytes).
|
||||
/// Displayed as `xxxx:xxxx:xxxx:xxxx:xxxx:xxxx:xxxx:xxxx`.
|
||||
///
|
||||
/// Mirrors: `warzone-protocol::types::Fingerprint`
|
||||
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub struct Fingerprint(pub [u8; 16]);
|
||||
|
||||
impl Fingerprint {
|
||||
pub fn from_verifying_key(key: &VerifyingKey) -> Self {
|
||||
let hash = Sha256::digest(key.as_bytes());
|
||||
let mut fp = [0u8; 16];
|
||||
fp.copy_from_slice(&hash[..16]);
|
||||
Fingerprint(fp)
|
||||
}
|
||||
|
||||
/// Parse from hex string (with or without colons).
|
||||
pub fn from_hex(s: &str) -> Result<Self, String> {
|
||||
let clean: String = s.chars().filter(|c| c.is_ascii_hexdigit()).collect();
|
||||
let bytes = hex::decode(&clean).map_err(|e| format!("invalid hex: {e}"))?;
|
||||
if bytes.len() < 16 {
|
||||
return Err("fingerprint too short".to_string());
|
||||
}
|
||||
let mut fp = [0u8; 16];
|
||||
fp.copy_from_slice(&bytes[..16]);
|
||||
Ok(Fingerprint(fp))
|
||||
}
|
||||
|
||||
/// As raw bytes.
|
||||
pub fn as_bytes(&self) -> &[u8; 16] {
|
||||
&self.0
|
||||
}
|
||||
|
||||
/// As hex string without colons.
|
||||
pub fn to_hex(&self) -> String {
|
||||
hex::encode(self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Fingerprint {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"{:04x}:{:04x}:{:04x}:{:04x}:{:04x}:{:04x}:{:04x}:{:04x}",
|
||||
u16::from_be_bytes([self.0[0], self.0[1]]),
|
||||
u16::from_be_bytes([self.0[2], self.0[3]]),
|
||||
u16::from_be_bytes([self.0[4], self.0[5]]),
|
||||
u16::from_be_bytes([self.0[6], self.0[7]]),
|
||||
u16::from_be_bytes([self.0[8], self.0[9]]),
|
||||
u16::from_be_bytes([self.0[10], self.0[11]]),
|
||||
u16::from_be_bytes([self.0[12], self.0[13]]),
|
||||
u16::from_be_bytes([self.0[14], self.0[15]]),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for Fingerprint {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "Fingerprint({})", self)
|
||||
}
|
||||
}
|
||||
|
||||
/// The public portion of an identity — safe to share with anyone.
pub struct PublicIdentity {
    // Ed25519 verifying key (public counterpart of the signing key).
    pub signing: VerifyingKey,
    // X25519 public key used for key agreement.
    pub encryption: x25519_dalek::PublicKey,
    // Truncated SHA-256 of the signing key; the user-visible identity.
    pub fingerprint: Fingerprint,
}
|
||||
|
||||
/// Hash a human-readable room/group name into an opaque hex string.
|
||||
/// Used as QUIC SNI to prevent leaking group names to network observers.
|
||||
///
|
||||
/// `hash_room_name("my-group")` → 32 hex chars (16 bytes of SHA-256).
|
||||
///
|
||||
/// Mirrors the convention in featherChat WZP-FC-5:
|
||||
/// `SHA-256("featherchat-group:" + group_name)[:16]`
|
||||
pub fn hash_room_name(group_name: &str) -> String {
|
||||
use sha2::{Digest, Sha256};
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(b"featherchat-group:");
|
||||
hasher.update(group_name.as_bytes());
|
||||
let hash = hasher.finalize();
|
||||
hex::encode(&hash[..16])
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn deterministic_derivation() {
        // Deriving twice from the same seed must yield the same signing key.
        let seed = Seed::from_bytes([42u8; 32]);
        let first = seed.derive_identity();
        let second = seed.derive_identity();
        assert_eq!(
            first.signing.verifying_key().as_bytes(),
            second.signing.verifying_key().as_bytes(),
        );
    }

    #[test]
    fn mnemonic_roundtrip() {
        let seed = Seed::generate();
        let words = seed.to_mnemonic();
        // 32 bytes of entropy encode as a 24-word BIP39 phrase.
        assert_eq!(words.split_whitespace().count(), 24);
        assert_eq!(Seed::from_mnemonic(&words).unwrap().0, seed.0);
    }

    #[test]
    fn hex_roundtrip() {
        let seed = Seed::generate();
        let encoded = hex::encode(seed.0);
        assert_eq!(Seed::from_hex(&encoded).unwrap().0, seed.0);
    }

    #[test]
    fn fingerprint_format() {
        // Format: xxxx:xxxx:xxxx:xxxx:xxxx:xxxx:xxxx:xxxx (39 chars, 7 colons).
        let rendered = Seed::generate()
            .derive_identity()
            .public_identity()
            .fingerprint
            .to_string();
        assert_eq!(rendered.len(), 39);
        assert_eq!(rendered.chars().filter(|c| *c == ':').count(), 7);
    }

    #[test]
    fn hash_room_name_deterministic() {
        let h1 = hash_room_name("my-group");
        assert_eq!(h1, hash_room_name("my-group"));
        assert_eq!(h1.len(), 32); // 16 bytes = 32 hex chars
        assert!(h1.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn hash_room_name_different_inputs() {
        assert_ne!(hash_room_name("alpha"), hash_room_name("beta"));
    }

    #[test]
    fn matches_handshake_derivation() {
        use wzp_proto::KeyExchange;
        // The identity module and the KeyExchange trait implementation must
        // agree on key and fingerprint derivation.
        let seed = [99u8; 32];
        let id = Seed::from_bytes(seed).derive_identity();
        let kx = crate::WarzoneKeyExchange::from_identity_seed(&seed);

        assert_eq!(
            id.signing.verifying_key().as_bytes(),
            &kx.identity_public_key(),
        );
        assert_eq!(
            id.public_identity().fingerprint.as_bytes(),
            &kx.fingerprint(),
        );
    }
}
|
||||
@@ -9,12 +9,14 @@
|
||||
|
||||
// Submodules of the wzp-crypto crate.
pub mod anti_replay;
pub mod handshake;
pub mod identity;
pub mod nonce;
pub mod rekey;
pub mod session;

// Flat re-exports so callers can write `wzp_crypto::X` directly.
pub use anti_replay::AntiReplayWindow;
pub use handshake::WarzoneKeyExchange;
pub use identity::{hash_room_name, Fingerprint, IdentityKeyPair, PublicIdentity, Seed};
pub use nonce::{build_nonce, Direction};
pub use rekey::RekeyManager;
pub use session::ChaChaSession;
|
||||
|
||||
571
crates/wzp-crypto/tests/featherchat_compat.rs
Normal file
571
crates/wzp-crypto/tests/featherchat_compat.rs
Normal file
@@ -0,0 +1,571 @@
|
||||
//! Cross-project compatibility tests between WZP and featherChat.
|
||||
//!
|
||||
//! Verifies:
|
||||
//! 1. Identity: same seed → same keys → same fingerprints (WZP-FC-8)
|
||||
//! 2. CallSignal: WZP SignalMessage serializes into FC CallSignal.payload correctly
|
||||
//! 3. Auth: WZP auth module request/response matches FC's /v1/auth/validate contract
|
||||
//! 4. Mnemonic: BIP39 interop between both implementations
|
||||
|
||||
use wzp_proto::KeyExchange;
|
||||
|
||||
// ─── Identity Compatibility (WZP-FC-8) ──────────────────────────────────────
|
||||
|
||||
#[test]
fn same_seed_same_ed25519_key() {
    // Both stacks must derive the identical Ed25519 key from one seed.
    let seed = [42u8; 32];

    let wzp_pub =
        wzp_crypto::WarzoneKeyExchange::from_identity_seed(&seed).identity_public_key();

    let fc_pub = warzone_protocol::identity::Seed::from_bytes(seed)
        .derive_identity()
        .signing
        .verifying_key();

    assert_eq!(&wzp_pub, fc_pub.as_bytes(), "Ed25519 keys must match");
}

#[test]
fn same_seed_same_fingerprint() {
    // Same seed must produce the same 16-byte fingerprint on both sides.
    let seed = [99u8; 32];

    let wzp_fp = wzp_crypto::WarzoneKeyExchange::from_identity_seed(&seed).fingerprint();
    let fc_fp = warzone_protocol::identity::Seed::from_bytes(seed)
        .derive_identity()
        .public_identity()
        .fingerprint
        .0;

    assert_eq!(wzp_fp, fc_fp, "Fingerprints must match");
}

#[test]
fn wzp_identity_module_matches_featherchat() {
    // The full public identity (signing, encryption, fingerprint) must agree.
    let seed = [0xAB; 32];

    let ours = wzp_crypto::Seed::from_bytes(seed)
        .derive_identity()
        .public_identity();
    let theirs = warzone_protocol::identity::Seed::from_bytes(seed)
        .derive_identity()
        .public_identity();

    assert_eq!(ours.signing.as_bytes(), theirs.signing.as_bytes());
    assert_eq!(ours.encryption.as_bytes(), theirs.encryption.as_bytes());
    assert_eq!(ours.fingerprint.0, theirs.fingerprint.0);
    assert_eq!(ours.fingerprint.to_string(), theirs.fingerprint.to_string());
}

#[test]
fn random_seed_identity_match() {
    // Same property as above, but on a fresh random seed.
    let fc_seed = warzone_protocol::identity::Seed::generate();
    let raw = fc_seed.0;

    assert_eq!(
        wzp_crypto::WarzoneKeyExchange::from_identity_seed(&raw).fingerprint(),
        fc_seed.derive_identity().public_identity().fingerprint.0,
    );
}

#[test]
fn hkdf_derive_matches() {
    // Driving featherChat's raw HKDF helper must land on the same public key.
    let seed = [0x55; 32];

    let fc_ed = warzone_protocol::crypto::hkdf_derive(&seed, b"", b"warzone-ed25519", 32);
    let fc_pub = ed25519_dalek::SigningKey::from_bytes(&fc_ed.try_into().unwrap())
        .verifying_key();

    let wzp_pub =
        wzp_crypto::WarzoneKeyExchange::from_identity_seed(&seed).identity_public_key();

    assert_eq!(&wzp_pub, fc_pub.as_bytes());
}
|
||||
|
||||
// ─── BIP39 Mnemonic Interop ─────────────────────────────────────────────────
|
||||
|
||||
#[test]
fn mnemonic_roundtrip_fc_to_wzp() {
    // A phrase produced by featherChat must be recoverable by WZP.
    let seed = [0x77; 32];
    let phrase = warzone_protocol::identity::Seed::from_bytes(seed).to_mnemonic();
    assert_eq!(wzp_crypto::Seed::from_mnemonic(&phrase).unwrap().0, seed);
}

#[test]
fn mnemonic_roundtrip_wzp_to_fc() {
    // And the reverse direction.
    let seed = [0x33; 32];
    let phrase = wzp_crypto::Seed::from_bytes(seed).to_mnemonic();
    assert_eq!(
        warzone_protocol::identity::Seed::from_mnemonic(&phrase).unwrap().0,
        seed
    );
}

#[test]
fn mnemonic_strings_identical() {
    // Both implementations must render byte-identical phrases.
    let seed = [0xDE; 32];
    assert_eq!(
        warzone_protocol::identity::Seed::from_bytes(seed).to_mnemonic(),
        wzp_crypto::Seed::from_bytes(seed).to_mnemonic(),
    );
}
|
||||
|
||||
// ─── CallSignal Payload Interop ─────────────────────────────────────────────
|
||||
|
||||
#[test]
fn wzp_signal_serializes_into_fc_callsignal_payload() {
    // WZP creates a CallOffer SignalMessage.
    let offer = wzp_proto::SignalMessage::CallOffer {
        identity_pub: [1u8; 32],
        ephemeral_pub: [2u8; 32],
        signature: vec![3u8; 64],
        supported_profiles: vec![wzp_proto::QualityProfile::GOOD],
    };

    // Encode as featherChat CallSignal payload.
    let payload = wzp_client::featherchat::encode_call_payload(
        &offer,
        Some("relay.example.com:4433"),
        Some("myroom"),
    );

    // Verify it's valid JSON with the expected envelope fields.
    let parsed: serde_json::Value = serde_json::from_str(&payload).unwrap();
    assert!(parsed.get("signal").is_some());
    assert_eq!(parsed["relay_addr"], "relay.example.com:4433");
    assert_eq!(parsed["room"], "myroom");

    // featherChat would put this in WireMessage::CallSignal { payload, ... }.
    // `payload` is moved here — the previous `.clone()` was redundant since
    // the variable is not used again afterwards.
    let fc_msg = warzone_protocol::message::WireMessage::CallSignal {
        id: "call-123".to_string(),
        sender_fingerprint: "abcd1234".to_string(),
        signal_type: warzone_protocol::message::CallSignalType::Offer,
        payload,
        target: "peer-fingerprint".to_string(),
    };

    // Verify it serializes with bincode (FC's wire format)...
    let encoded = bincode::serialize(&fc_msg).unwrap();
    assert!(!encoded.is_empty());

    // ...and deserializes back.
    let decoded: warzone_protocol::message::WireMessage = bincode::deserialize(&encoded).unwrap();
    if let warzone_protocol::message::WireMessage::CallSignal {
        id, payload: p, signal_type, ..
    } = decoded
    {
        assert_eq!(id, "call-123");
        assert!(matches!(signal_type, warzone_protocol::message::CallSignalType::Offer));

        // Decode the WZP payload back out of the FC envelope.
        let wzp_payload = wzp_client::featherchat::decode_call_payload(&p).unwrap();
        assert_eq!(wzp_payload.relay_addr.unwrap(), "relay.example.com:4433");
        assert!(matches!(wzp_payload.signal, wzp_proto::SignalMessage::CallOffer { .. }));
    } else {
        panic!("expected CallSignal");
    }
}
|
||||
|
||||
#[test]
fn wzp_answer_round_trips_through_fc_callsignal() {
    // A CallAnswer must survive: WZP encode → FC CallSignal → bincode → decode.
    let answer = wzp_proto::SignalMessage::CallAnswer {
        identity_pub: [10u8; 32],
        ephemeral_pub: [20u8; 32],
        signature: vec![30u8; 64],
        chosen_profile: wzp_proto::QualityProfile::DEGRADED,
    };

    let payload = wzp_client::featherchat::encode_call_payload(&answer, None, None);

    let fc_msg = warzone_protocol::message::WireMessage::CallSignal {
        id: "call-456".to_string(),
        sender_fingerprint: "efgh5678".to_string(),
        signal_type: warzone_protocol::message::CallSignalType::Answer,
        payload,
        target: "caller-fp".to_string(),
    };

    let bytes = bincode::serialize(&fc_msg).unwrap();
    let decoded: warzone_protocol::message::WireMessage = bincode::deserialize(&bytes).unwrap();

    if let warzone_protocol::message::WireMessage::CallSignal { payload, .. } = decoded {
        let wzp = wzp_client::featherchat::decode_call_payload(&payload).unwrap();
        if let wzp_proto::SignalMessage::CallAnswer { chosen_profile, .. } = wzp.signal {
            assert_eq!(chosen_profile.codec, wzp_proto::CodecId::Opus6k);
        } else {
            panic!("expected CallAnswer");
        }
    } else {
        // Previously this branch was silently ignored, letting the test pass
        // vacuously when the wrong variant came back.
        panic!("expected CallSignal");
    }
}
|
||||
|
||||
#[test]
fn wzp_hangup_round_trips_through_fc_callsignal() {
    // A Hangup must survive the same WZP → FC → bincode round trip.
    let hangup = wzp_proto::SignalMessage::Hangup {
        reason: wzp_proto::HangupReason::Normal,
    };

    let payload = wzp_client::featherchat::encode_call_payload(&hangup, None, None);
    let signal_type = wzp_client::featherchat::signal_to_call_type(&hangup);
    assert!(matches!(signal_type, wzp_client::featherchat::CallSignalType::Hangup));

    let fc_msg = warzone_protocol::message::WireMessage::CallSignal {
        id: "call-789".to_string(),
        sender_fingerprint: "xyz".to_string(),
        signal_type: warzone_protocol::message::CallSignalType::Hangup,
        payload,
        target: "peer".to_string(),
    };

    let bytes = bincode::serialize(&fc_msg).unwrap();
    let decoded: warzone_protocol::message::WireMessage = bincode::deserialize(&bytes).unwrap();

    if let warzone_protocol::message::WireMessage::CallSignal { payload, .. } = decoded {
        let wzp = wzp_client::featherchat::decode_call_payload(&payload).unwrap();
        assert!(matches!(wzp.signal, wzp_proto::SignalMessage::Hangup { .. }));
    } else {
        // Previously this branch was silently ignored, letting the test pass
        // vacuously when the wrong variant came back.
        panic!("expected CallSignal");
    }
}
|
||||
|
||||
// ─── Auth Token Contract ────────────────────────────────────────────────────
|
||||
|
||||
#[test]
fn auth_validate_request_matches_fc_contract() {
    // WZP sends: { "token": "..." } — the same shape as FC's ValidateRequest.
    #[derive(serde::Deserialize)]
    struct FcValidateRequest {
        token: String,
    }

    let json_str = serde_json::json!({ "token": "test-token-123" }).to_string();
    let fc_req: FcValidateRequest = serde_json::from_str(&json_str).unwrap();
    assert_eq!(fc_req.token, "test-token-123");
}

#[test]
fn auth_validate_response_matches_wzp_expectations() {
    // FC returns: { "valid": true, "fingerprint": "...", "alias": "..." };
    // WZP must deserialize it into wzp_relay::auth::ValidateResponse.
    let fc_response = serde_json::json!({
        "valid": true,
        "fingerprint": "a3f8:1b2c:3d4e:5f60:7182:93a4:b5c6:d7e8",
        "alias": "manwe",
        "eth_address": null
    });

    let wzp_resp: wzp_relay::auth::ValidateResponse =
        serde_json::from_value(fc_response).unwrap();
    assert!(wzp_resp.valid);
    assert_eq!(
        wzp_resp.fingerprint.unwrap(),
        "a3f8:1b2c:3d4e:5f60:7182:93a4:b5c6:d7e8"
    );
    assert_eq!(wzp_resp.alias.unwrap(), "manwe");
}

#[test]
fn auth_invalid_response_matches() {
    // A bare { "valid": false } must parse with the optional fields absent.
    let wzp_resp: wzp_relay::auth::ValidateResponse =
        serde_json::from_value(serde_json::json!({ "valid": false })).unwrap();
    assert!(!wzp_resp.valid);
    assert!(wzp_resp.fingerprint.is_none());
}
|
||||
|
||||
// ─── Signal Type Mapping ────────────────────────────────────────────────────
|
||||
|
||||
#[test]
fn all_signal_types_map_correctly() {
    use wzp_client::featherchat::{signal_to_call_type, CallSignalType};

    // Each WZP signal variant must map onto the FC CallSignalType of the
    // same name (compared via the Debug rendering).
    let offer = wzp_proto::SignalMessage::CallOffer {
        identity_pub: [0; 32],
        ephemeral_pub: [0; 32],
        signature: vec![],
        supported_profiles: vec![],
    };
    let answer = wzp_proto::SignalMessage::CallAnswer {
        identity_pub: [0; 32],
        ephemeral_pub: [0; 32],
        signature: vec![],
        chosen_profile: wzp_proto::QualityProfile::GOOD,
    };
    let ice = wzp_proto::SignalMessage::IceCandidate {
        candidate: "candidate:1".to_string(),
    };
    let hangup = wzp_proto::SignalMessage::Hangup {
        reason: wzp_proto::HangupReason::Normal,
    };

    let cases: Vec<(wzp_proto::SignalMessage, &str)> = vec![
        (offer, "Offer"),
        (answer, "Answer"),
        (ice, "IceCandidate"),
        (hangup, "Hangup"),
    ];

    for (signal, expected_name) in cases {
        let mapped = format!("{:?}", signal_to_call_type(&signal));
        assert_eq!(mapped, expected_name, "signal type mapping for {expected_name}");
    }
}
|
||||
|
||||
// ─── Room Hashing + Access Control ─────────────────────────────────────────
|
||||
|
||||
#[test]
fn hash_room_name_deterministic() {
    // Hashing is a pure function of the name.
    assert_eq!(
        wzp_crypto::hash_room_name("ops-channel"),
        wzp_crypto::hash_room_name("ops-channel"),
        "same input must produce same hash"
    );
}

#[test]
fn hash_room_name_is_32_hex_chars() {
    let h = wzp_crypto::hash_room_name("test-room");
    assert_eq!(h.len(), 32, "hash must be 32 hex chars (16 bytes)");
    let all_hex = h.chars().all(|c| c.is_ascii_hexdigit());
    assert!(all_hex, "hash must contain only hex characters, got: {h}");
}

#[test]
fn hash_room_name_different_inputs() {
    let hashes = ["alpha", "beta", "alpha-2"].map(wzp_crypto::hash_room_name);
    assert_ne!(hashes[0], hashes[1], "different names must produce different hashes");
    assert_ne!(hashes[0], hashes[2]);
    assert_ne!(hashes[1], hashes[2]);
}

#[test]
fn hash_room_name_matches_fc_convention() {
    // Recompute SHA-256("featherchat-group:" + name)[:16] by hand with the
    // sha2 crate and compare against the library function.
    use sha2::{Digest, Sha256};

    let name = "warzone-squad";
    let digest = Sha256::new()
        .chain_update(b"featherchat-group:")
        .chain_update(name.as_bytes())
        .finalize();
    let expected = hex::encode(&digest[..16]);

    assert_eq!(
        wzp_crypto::hash_room_name(name),
        expected,
        "hash_room_name must equal SHA-256('featherchat-group:' + name)[:16]"
    );
}
|
||||
|
||||
#[test]
fn room_acl_open_mode() {
    // Open mode: everyone is authorized regardless of fingerprint presence.
    let mgr = wzp_relay::room::RoomManager::new();
    for fp in [None, Some("random-fp")] {
        assert!(mgr.is_authorized("any-room", fp));
    }
    assert!(mgr.is_authorized("another-room", Some("abc:def")));
}

#[test]
fn room_acl_enforced() {
    // ACL enabled but no fingerprint provided => denied.
    let mgr = wzp_relay::room::RoomManager::with_acl();
    assert!(
        !mgr.is_authorized("room1", None),
        "ACL mode must reject connections without a fingerprint"
    );
}

#[test]
fn room_acl_allows_listed() {
    let mut mgr = wzp_relay::room::RoomManager::with_acl();
    for fp in ["alice-fp", "bob-fp"] {
        mgr.allow("secure-room", fp);
    }
    for fp in ["alice-fp", "bob-fp"] {
        assert!(mgr.is_authorized("secure-room", Some(fp)));
    }
}

#[test]
fn room_acl_denies_unlisted() {
    let mut mgr = wzp_relay::room::RoomManager::with_acl();
    mgr.allow("secure-room", "alice-fp");

    for intruder in ["eve-fp", "mallory-fp"] {
        assert!(
            !mgr.is_authorized("secure-room", Some(intruder)),
            "unlisted fingerprints must be denied"
        );
    }
    // No fingerprint at all => also denied.
    assert!(
        !mgr.is_authorized("secure-room", None),
        "no fingerprint must be denied in ACL mode"
    );
}
|
||||
|
||||
// ─── Web Bridge Auth + Proto Standalone + S-9 ──────────────────────────────
|
||||
|
||||
/// WZP-S-6: featherChat may include `eth_address` in ValidateResponse.
/// WZP's ValidateResponse must handle it gracefully (serde ignores unknown fields).
#[test]
fn auth_response_with_eth_address() {
    // Small helper: every case below is "parse this JSON and it must succeed".
    let parse = |v: serde_json::Value| -> wzp_relay::auth::ValidateResponse {
        serde_json::from_value(v).unwrap()
    };

    // Case 1: eth_address present and non-null — must be ignored gracefully.
    let resp = parse(serde_json::json!({
        "valid": true,
        "fingerprint": "a1b2:c3d4:e5f6:7890:abcd:ef01:2345:6789",
        "alias": "vitalik",
        "eth_address": "0x1234567890abcdef1234567890abcdef12345678"
    }));
    assert!(resp.valid);
    assert_eq!(
        resp.fingerprint.unwrap(),
        "a1b2:c3d4:e5f6:7890:abcd:ef01:2345:6789"
    );
    assert_eq!(resp.alias.unwrap(), "vitalik");

    // Case 2: eth_address explicitly null.
    let resp2 = parse(serde_json::json!({
        "valid": true,
        "fingerprint": "dead:beef:cafe:babe:1234:5678:9abc:def0",
        "alias": "anon",
        "eth_address": null
    }));
    assert!(resp2.valid);
    assert_eq!(
        resp2.fingerprint.unwrap(),
        "dead:beef:cafe:babe:1234:5678:9abc:def0"
    );

    // Case 3: eth_address absent entirely.
    let resp3 = parse(serde_json::json!({
        "valid": false
    }));
    assert!(!resp3.valid);
}
|
||||
|
||||
/// WZP-S-7: SignalMessage::AuthToken { token } exists and round-trips via serde.
#[test]
fn wzp_proto_has_auth_token_variant() {
    const TOKEN: &str = "fc-bearer-token-xyz";

    let msg = wzp_proto::SignalMessage::AuthToken {
        token: TOKEN.to_string(),
    };

    // The JSON form must name the variant and carry the token verbatim.
    let json = serde_json::to_string(&msg).unwrap();
    for needle in ["AuthToken", TOKEN] {
        assert!(json.contains(needle));
    }

    // And it must decode back to the same variant with the same payload.
    let decoded: wzp_proto::SignalMessage = serde_json::from_str(&json).unwrap();
    match decoded {
        wzp_proto::SignalMessage::AuthToken { token } => assert_eq!(token, TOKEN),
        decoded => panic!("expected AuthToken variant, got: {decoded:?}"),
    }
}
|
||||
|
||||
/// WZP-S-6: WZP CallSignalType has all variants matching featherChat's set.
#[test]
fn all_fc_call_signal_types_representable() {
    use wzp_client::featherchat::CallSignalType;

    // The full featherChat call-signal set, paired with the expected Debug name.
    let variants: Vec<(CallSignalType, &str)> = vec![
        (CallSignalType::Offer, "Offer"),
        (CallSignalType::Answer, "Answer"),
        (CallSignalType::IceCandidate, "IceCandidate"),
        (CallSignalType::Hangup, "Hangup"),
        (CallSignalType::Reject, "Reject"),
        (CallSignalType::Ringing, "Ringing"),
        (CallSignalType::Busy, "Busy"),
    ];

    assert_eq!(variants.len(), 7, "featherChat defines exactly 7 call signal types");

    for (variant, expected_name) in &variants {
        // Debug spelling must match featherChat's name exactly…
        assert_eq!(&format!("{variant:?}"), expected_name);

        // …and survive a JSON round-trip unchanged.
        let json = serde_json::to_string(variant).unwrap();
        let round_tripped: CallSignalType = serde_json::from_str(&json).unwrap();
        assert_eq!(format!("{round_tripped:?}"), *expected_name);
    }
}
|
||||
|
||||
/// WZP-S-9: hashed room name used as QUIC SNI must be valid — lowercase hex only.
#[test]
fn hash_room_name_used_as_sni_is_valid() {
    // Deliberately awkward room names: spaces, brackets, accents, emoji,
    // and a very long name — all must hash to a clean SNI label.
    let long_name = "x".repeat(1000);
    let test_rooms = [
        "general",
        "Voice Room #1",
        "café-lounge",
        "a]b[c{d}e",
        "\u{1f480}\u{1f525}",
        long_name.as_str(),
    ];

    for room in &test_rooms {
        let hashed = wzp_crypto::hash_room_name(room);

        assert!(!hashed.is_empty(), "hash of '{room}' must not be empty");

        // Every character must be lowercase hex — the only charset accepted here.
        for ch in hashed.chars() {
            let lowercase_hex = ch.is_ascii_hexdigit() && !ch.is_ascii_uppercase();
            assert!(
                lowercase_hex,
                "hash of '{room}' contains invalid SNI char: '{ch}' (full: {hashed})"
            );
        }

        // 16 truncated SHA-256 bytes hex-encode to exactly 32 characters.
        assert_eq!(
            hashed.len(),
            32,
            "hash should be 32 hex chars (16 bytes), got {} for '{room}'",
            hashed.len()
        );
    }
}
|
||||
|
||||
/// WZP-S-7: wzp-proto Cargo.toml must be standalone — no `.workspace = true` inheritance.
#[test]
fn wzp_proto_cargo_toml_is_standalone() {
    // The manifest path differs depending on whether tests run from the
    // workspace root or from inside a crate directory — try both.
    let contents = ["crates/wzp-proto/Cargo.toml", "../wzp-proto/Cargo.toml"]
        .iter()
        .find_map(|p| std::fs::read_to_string(p).ok())
        .expect("could not read crates/wzp-proto/Cargo.toml from any expected path");

    // Workspace inheritance would make the crate unusable outside this repo.
    assert!(
        !contents.contains(".workspace = true"),
        "wzp-proto Cargo.toml must not use workspace inheritance (.workspace = true), \
        found in:\n{contents}"
    );

    // Sanity check that we actually read the right manifest.
    assert!(
        contents.contains("name = \"wzp-proto\""),
        "expected package name 'wzp-proto' in Cargo.toml"
    );
}
|
||||
@@ -1,17 +1,22 @@
|
||||
[package]
|
||||
name = "wzp-proto"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
rust-version.workspace = true
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
license = "MIT OR Apache-2.0"
|
||||
rust-version = "1.85"
|
||||
description = "WarzonePhone protocol types, traits, and core logic"
|
||||
|
||||
# This crate is designed to be importable standalone — no workspace inheritance.
|
||||
# featherChat and other projects can depend on it directly via git:
|
||||
# wzp-proto = { git = "ssh://git@git.manko.yoga:222/manawenuz/wz-phone.git", path = "crates/wzp-proto" }
|
||||
|
||||
[dependencies]
|
||||
bytes = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
bytes = "1"
|
||||
thiserror = "2"
|
||||
async-trait = "0.1"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
tracing = "0.1"
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = { workspace = true }
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
serde_json = "1"
|
||||
|
||||
454
crates/wzp-proto/src/bandwidth.rs
Normal file
454
crates/wzp-proto/src/bandwidth.rs
Normal file
@@ -0,0 +1,454 @@
|
||||
//! GCC-style bandwidth estimation and congestion control.
|
||||
//!
|
||||
//! Tracks available bandwidth using delay-based and loss-based signals,
|
||||
//! then adjusts the sending bitrate to avoid congestion. The estimator
|
||||
//! uses multiplicative decrease (15%) on congestion and additive increase
|
||||
//! (5%) during underuse, following the general shape of Google Congestion
|
||||
//! Control (GCC).
|
||||
|
||||
use std::collections::VecDeque;
|
||||
use std::time::Instant;
|
||||
|
||||
use crate::packet::QualityReport;
|
||||
use crate::QualityProfile;
|
||||
|
||||
/// Network congestion state derived from delay and loss signals.
///
/// Produced by the delay-based detector and used by `BandwidthEstimator`
/// to decide whether to raise, hold, or lower the sending bitrate.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CongestionState {
    /// Network is fine, can increase bandwidth.
    Underuse,
    /// Normal operation — hold the current bandwidth.
    Normal,
    /// Congestion detected, should decrease bandwidth.
    Overuse,
}
|
||||
|
||||
/// Detects congestion from increasing RTT using an exponential moving average.
///
/// Maintains a baseline RTT (minimum observed) and compares the smoothed RTT
/// against it. If `rtt_ema > baseline * threshold_ratio`, congestion is detected.
/// The baseline slowly drifts upward to handle route changes.
struct DelayBasedDetector {
    /// Baseline RTT (minimum observed), in milliseconds.
    baseline_rtt_ms: f64,
    /// EMA of recent RTT, in milliseconds.
    rtt_ema: f64,
    /// EMA smoothing factor (weight of the newest sample).
    alpha: f64,
    /// Threshold: if rtt_ema > baseline * threshold_ratio, congestion detected.
    threshold_ratio: f64,
    /// Current state, re-derived on every `update`.
    state: CongestionState,
    /// Whether we have received any RTT sample yet (first sample seeds the EMA/baseline).
    initialized: bool,
    /// Drift factor: baseline slowly increases each update to track route changes.
    baseline_drift: f64,
}
|
||||
|
||||
impl DelayBasedDetector {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
baseline_rtt_ms: f64::MAX,
|
||||
rtt_ema: 0.0,
|
||||
alpha: 0.3,
|
||||
threshold_ratio: 1.5,
|
||||
state: CongestionState::Normal,
|
||||
initialized: false,
|
||||
baseline_drift: 0.001,
|
||||
}
|
||||
}
|
||||
|
||||
/// Update the detector with a new RTT sample.
|
||||
fn update(&mut self, rtt_ms: f64) {
|
||||
if !self.initialized {
|
||||
self.baseline_rtt_ms = rtt_ms;
|
||||
self.rtt_ema = rtt_ms;
|
||||
self.initialized = true;
|
||||
self.state = CongestionState::Normal;
|
||||
return;
|
||||
}
|
||||
|
||||
// Track minimum RTT as baseline.
|
||||
if rtt_ms < self.baseline_rtt_ms {
|
||||
self.baseline_rtt_ms = rtt_ms;
|
||||
} else {
|
||||
// Slowly drift baseline upward to handle route changes.
|
||||
self.baseline_rtt_ms += self.baseline_drift * (rtt_ms - self.baseline_rtt_ms);
|
||||
}
|
||||
|
||||
// Update EMA.
|
||||
self.rtt_ema = self.alpha * rtt_ms + (1.0 - self.alpha) * self.rtt_ema;
|
||||
|
||||
// Determine state.
|
||||
let overuse_threshold = self.baseline_rtt_ms * self.threshold_ratio;
|
||||
let underuse_threshold = self.baseline_rtt_ms * 1.1;
|
||||
|
||||
if self.rtt_ema > overuse_threshold {
|
||||
self.state = CongestionState::Overuse;
|
||||
} else if self.rtt_ema < underuse_threshold {
|
||||
self.state = CongestionState::Underuse;
|
||||
} else {
|
||||
self.state = CongestionState::Normal;
|
||||
}
|
||||
}
|
||||
|
||||
fn state(&self) -> CongestionState {
|
||||
self.state
|
||||
}
|
||||
}
|
||||
|
||||
/// Detects congestion from packet loss using a sliding window average.
struct LossBasedDetector {
    /// Recent loss percentages (0–100 scale, sliding window of samples).
    loss_window: VecDeque<f64>,
    /// Maximum window size in samples.
    window_size: usize,
    /// Loss threshold for congestion (default 5%); average above this = congested.
    threshold_pct: f64,
}
|
||||
|
||||
impl LossBasedDetector {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
loss_window: VecDeque::with_capacity(10),
|
||||
window_size: 10,
|
||||
threshold_pct: 5.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a loss percentage sample to the window.
|
||||
fn update(&mut self, loss_pct: f64) {
|
||||
if self.loss_window.len() >= self.window_size {
|
||||
self.loss_window.pop_front();
|
||||
}
|
||||
self.loss_window.push_back(loss_pct);
|
||||
}
|
||||
|
||||
/// Returns true if the average loss in the window exceeds the threshold.
|
||||
fn is_congested(&self) -> bool {
|
||||
if self.loss_window.is_empty() {
|
||||
return false;
|
||||
}
|
||||
let avg = self.loss_window.iter().sum::<f64>() / self.loss_window.len() as f64;
|
||||
avg > self.threshold_pct
|
||||
}
|
||||
}
|
||||
|
||||
// ─── BandwidthEstimator ─────────────────────────────────────────────────────
|
||||
|
||||
/// GCC-style bandwidth estimator that tracks available bandwidth using
/// delay-based and loss-based congestion signals.
///
/// # Algorithm
///
/// - **Overuse** (delay or loss): multiplicative decrease by 15%.
/// - **Underuse** (delay) with no loss congestion: additive increase by 5%.
/// - **Normal**: hold steady.
/// - Result is always clamped to `[min_bw_kbps, max_bw_kbps]`.
pub struct BandwidthEstimator {
    /// Current estimated bandwidth in kbps.
    estimated_bw_kbps: f64,
    /// Minimum bandwidth floor (don't go below this).
    min_bw_kbps: f64,
    /// Maximum bandwidth ceiling.
    max_bw_kbps: f64,
    /// Delay-based detector state.
    delay_detector: DelayBasedDetector,
    /// Loss-based detector state.
    loss_detector: LossBasedDetector,
    /// Last update timestamp.
    /// NOTE(review): written on every `update` but never read in this module —
    /// confirm whether external code needs it or it can be removed.
    last_update: Option<Instant>,
}
|
||||
|
||||
/// Multiplicative decrease factor applied on congestion (15% reduction).
|
||||
const DECREASE_FACTOR: f64 = 0.85;
|
||||
/// Additive increase factor applied during underuse (5% of current estimate).
|
||||
const INCREASE_FACTOR: f64 = 0.05;
|
||||
|
||||
impl BandwidthEstimator {
|
||||
/// Create a new bandwidth estimator.
|
||||
///
|
||||
/// - `initial_bw_kbps`: starting bandwidth estimate.
|
||||
/// - `min`: minimum bandwidth floor in kbps.
|
||||
/// - `max`: maximum bandwidth ceiling in kbps.
|
||||
pub fn new(initial_bw_kbps: f64, min: f64, max: f64) -> Self {
|
||||
Self {
|
||||
estimated_bw_kbps: initial_bw_kbps,
|
||||
min_bw_kbps: min,
|
||||
max_bw_kbps: max,
|
||||
delay_detector: DelayBasedDetector::new(),
|
||||
loss_detector: LossBasedDetector::new(),
|
||||
last_update: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Update the estimator with new network observations.
|
||||
///
|
||||
/// Returns the new estimated bandwidth in kbps.
|
||||
///
|
||||
/// - If delay overuse OR loss congested: decrease by 15% (multiplicative decrease).
|
||||
/// - If delay underuse AND not loss congested: increase by 5% (additive increase).
|
||||
/// - If normal: hold steady.
|
||||
/// - Result is clamped to `[min, max]`.
|
||||
pub fn update(&mut self, rtt_ms: f64, loss_pct: f64, _jitter_ms: f64) -> f64 {
|
||||
self.delay_detector.update(rtt_ms);
|
||||
self.loss_detector.update(loss_pct);
|
||||
self.last_update = Some(Instant::now());
|
||||
|
||||
let delay_state = self.delay_detector.state();
|
||||
let loss_congested = self.loss_detector.is_congested();
|
||||
|
||||
if delay_state == CongestionState::Overuse || loss_congested {
|
||||
// Multiplicative decrease.
|
||||
self.estimated_bw_kbps *= DECREASE_FACTOR;
|
||||
} else if delay_state == CongestionState::Underuse && !loss_congested {
|
||||
// Additive increase.
|
||||
self.estimated_bw_kbps += self.estimated_bw_kbps * INCREASE_FACTOR;
|
||||
}
|
||||
// Normal: hold steady — no change.
|
||||
|
||||
// Clamp to [min, max].
|
||||
self.estimated_bw_kbps = self
|
||||
.estimated_bw_kbps
|
||||
.clamp(self.min_bw_kbps, self.max_bw_kbps);
|
||||
|
||||
self.estimated_bw_kbps
|
||||
}
|
||||
|
||||
/// Current estimated bandwidth in kbps.
|
||||
pub fn estimated_kbps(&self) -> f64 {
|
||||
self.estimated_bw_kbps
|
||||
}
|
||||
|
||||
/// Current congestion state (derived from delay detector).
|
||||
pub fn congestion_state(&self) -> CongestionState {
|
||||
self.delay_detector.state()
|
||||
}
|
||||
|
||||
/// Convenience method: update from a `QualityReport`.
|
||||
///
|
||||
/// Extracts RTT, loss, and jitter from the report and feeds them into
|
||||
/// the estimator.
|
||||
pub fn from_quality_report(&mut self, report: &QualityReport) -> f64 {
|
||||
let rtt_ms = report.rtt_ms() as f64;
|
||||
let loss_pct = report.loss_percent() as f64;
|
||||
let jitter_ms = report.jitter_ms as f64;
|
||||
self.update(rtt_ms, loss_pct, jitter_ms)
|
||||
}
|
||||
|
||||
/// Recommend a `QualityProfile` based on the current bandwidth estimate.
|
||||
///
|
||||
/// - bw >= 25 kbps -> GOOD (Opus 24k + 20% FEC = ~28.8 kbps total)
|
||||
/// - bw >= 8 kbps -> DEGRADED (Opus 6k + 50% FEC = ~9.0 kbps)
|
||||
/// - bw < 8 kbps -> CATASTROPHIC (Codec2 1.2k + 100% FEC = ~2.4 kbps)
|
||||
pub fn recommended_profile(&self) -> QualityProfile {
|
||||
if self.estimated_bw_kbps >= 25.0 {
|
||||
QualityProfile::GOOD
|
||||
} else if self.estimated_bw_kbps >= 8.0 {
|
||||
QualityProfile::DEGRADED
|
||||
} else {
|
||||
QualityProfile::CATASTROPHIC
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    // Unit tests for the GCC-style estimator. The numeric feed sequences below
    // are chosen to steer the internal EMA/baseline state, so the exact values
    // and iteration counts matter.
    use super::*;

    #[test]
    fn initial_bandwidth() {
        // The constructor must report the initial estimate unchanged.
        let bwe = BandwidthEstimator::new(50.0, 2.0, 100.0);
        assert!((bwe.estimated_kbps() - 50.0).abs() < f64::EPSILON);
    }

    #[test]
    fn stable_network_holds_bandwidth() {
        let mut bwe = BandwidthEstimator::new(50.0, 2.0, 100.0);
        // Feed stable, low RTT and 0% loss — after initial sample sets baseline,
        // subsequent identical RTT should be underuse (rtt_ema < baseline * 1.1),
        // causing slow increases. The bandwidth should stay near initial or grow slightly.
        let initial = bwe.estimated_kbps();
        for _ in 0..20 {
            bwe.update(30.0, 0.0, 5.0);
        }
        // Should not have decreased significantly.
        assert!(
            bwe.estimated_kbps() >= initial,
            "bandwidth should not decrease on stable network: got {} vs initial {}",
            bwe.estimated_kbps(),
            initial
        );
    }

    #[test]
    fn high_rtt_decreases_bandwidth() {
        let mut bwe = BandwidthEstimator::new(50.0, 2.0, 100.0);
        // Establish a low baseline.
        for _ in 0..5 {
            bwe.update(20.0, 0.0, 2.0);
        }
        let before = bwe.estimated_kbps();

        // Now feed high RTT to trigger overuse.
        for _ in 0..10 {
            bwe.update(200.0, 0.0, 10.0);
        }
        assert!(
            bwe.estimated_kbps() < before,
            "bandwidth should decrease on high RTT: got {} vs before {}",
            bwe.estimated_kbps(),
            before
        );
    }

    #[test]
    fn high_loss_decreases_bandwidth() {
        let mut bwe = BandwidthEstimator::new(50.0, 2.0, 100.0);
        let before = bwe.estimated_kbps();

        // Feed 10% loss repeatedly (above the 5% threshold).
        for _ in 0..15 {
            bwe.update(20.0, 10.0, 2.0);
        }
        assert!(
            bwe.estimated_kbps() < before,
            "bandwidth should decrease on high loss: got {} vs before {}",
            bwe.estimated_kbps(),
            before
        );
    }

    #[test]
    fn recovery_increases_bandwidth() {
        let mut bwe = BandwidthEstimator::new(50.0, 2.0, 100.0);

        // Drive bandwidth down with high RTT.
        // (The first loop establishes a low baseline so the second is overuse.)
        for _ in 0..5 {
            bwe.update(20.0, 0.0, 2.0);
        }
        for _ in 0..20 {
            bwe.update(200.0, 0.0, 10.0);
        }
        let low_bw = bwe.estimated_kbps();
        assert!(low_bw < 50.0, "should have decreased");

        // Now feed good conditions — low RTT should be underuse, causing increase.
        // Reset the baseline by feeding very low RTT.
        for _ in 0..30 {
            bwe.update(10.0, 0.0, 1.0);
        }
        assert!(
            bwe.estimated_kbps() > low_bw,
            "bandwidth should recover: got {} vs low {}",
            bwe.estimated_kbps(),
            low_bw
        );
    }

    #[test]
    fn bandwidth_clamped_to_min() {
        let mut bwe = BandwidthEstimator::new(10.0, 5.0, 100.0);
        // Keep feeding congestion to drive bandwidth down.
        for _ in 0..5 {
            bwe.update(20.0, 0.0, 2.0);
        }
        for _ in 0..100 {
            bwe.update(500.0, 50.0, 100.0);
        }
        // The floor must be hit exactly (clamp returns min, not an approximation).
        assert!(
            (bwe.estimated_kbps() - 5.0).abs() < f64::EPSILON,
            "bandwidth should be clamped to min: got {}",
            bwe.estimated_kbps()
        );
    }

    #[test]
    fn bandwidth_clamped_to_max() {
        let mut bwe = BandwidthEstimator::new(90.0, 2.0, 100.0);
        // Keep feeding great conditions to drive bandwidth up.
        for _ in 0..200 {
            bwe.update(5.0, 0.0, 1.0);
        }
        assert!(
            bwe.estimated_kbps() <= 100.0,
            "bandwidth should be clamped to max: got {}",
            bwe.estimated_kbps()
        );
    }

    #[test]
    fn recommended_profile_thresholds() {
        // At boundary: >= 25 kbps => GOOD
        let bwe_good = BandwidthEstimator::new(25.0, 2.0, 100.0);
        assert_eq!(bwe_good.recommended_profile(), QualityProfile::GOOD);

        // Just below 25 => DEGRADED
        let bwe_degraded = BandwidthEstimator::new(24.9, 2.0, 100.0);
        assert_eq!(bwe_degraded.recommended_profile(), QualityProfile::DEGRADED);

        // At boundary: >= 8 kbps => DEGRADED
        let bwe_degraded2 = BandwidthEstimator::new(8.0, 2.0, 100.0);
        assert_eq!(
            bwe_degraded2.recommended_profile(),
            QualityProfile::DEGRADED
        );

        // Below 8 => CATASTROPHIC
        let bwe_cat = BandwidthEstimator::new(7.9, 2.0, 100.0);
        assert_eq!(
            bwe_cat.recommended_profile(),
            QualityProfile::CATASTROPHIC
        );

        // High bandwidth
        let bwe_high = BandwidthEstimator::new(80.0, 2.0, 100.0);
        assert_eq!(bwe_high.recommended_profile(), QualityProfile::GOOD);
    }

    #[test]
    fn from_quality_report_integration() {
        let mut bwe = BandwidthEstimator::new(50.0, 2.0, 100.0);

        // Build a QualityReport with moderate loss and RTT.
        let report = QualityReport {
            loss_pct: (10.0_f32 / 100.0 * 255.0) as u8, // ~10% loss
            rtt_4ms: 25, // 100ms RTT
            jitter_ms: 10,
            bitrate_cap_kbps: 200,
        };

        let new_bw = bwe.from_quality_report(&report);
        // Should return a valid bandwidth value.
        assert!(new_bw > 0.0);
        assert!(new_bw <= 100.0);
        // The estimator should have been updated.
        assert!((bwe.estimated_kbps() - new_bw).abs() < f64::EPSILON);
    }

    // ── Additional detector unit tests ──────────────────────────────────

    #[test]
    fn delay_detector_starts_normal() {
        let det = DelayBasedDetector::new();
        assert_eq!(det.state(), CongestionState::Normal);
    }

    #[test]
    fn loss_detector_below_threshold() {
        let mut det = LossBasedDetector::new();
        for _ in 0..10 {
            det.update(2.0); // 2% loss, well below 5% threshold
        }
        assert!(!det.is_congested());
    }

    #[test]
    fn loss_detector_above_threshold() {
        let mut det = LossBasedDetector::new();
        for _ in 0..10 {
            det.update(8.0); // 8% loss, above 5% threshold
        }
        assert!(det.is_congested());
    }
}
|
||||
@@ -16,6 +16,8 @@ pub enum CodecId {
|
||||
Codec2_3200 = 3,
|
||||
/// Codec2 at 1200bps (catastrophic conditions)
|
||||
Codec2_1200 = 4,
|
||||
/// Comfort noise descriptor (silence suppression)
|
||||
ComfortNoise = 5,
|
||||
}
|
||||
|
||||
impl CodecId {
|
||||
@@ -27,6 +29,7 @@ impl CodecId {
|
||||
Self::Opus6k => 6_000,
|
||||
Self::Codec2_3200 => 3_200,
|
||||
Self::Codec2_1200 => 1_200,
|
||||
Self::ComfortNoise => 0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,6 +41,7 @@ impl CodecId {
|
||||
Self::Opus6k => 40,
|
||||
Self::Codec2_3200 => 20,
|
||||
Self::Codec2_1200 => 40,
|
||||
Self::ComfortNoise => 20,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,6 +50,7 @@ impl CodecId {
|
||||
match self {
|
||||
Self::Opus24k | Self::Opus16k | Self::Opus6k => 48_000,
|
||||
Self::Codec2_3200 | Self::Codec2_1200 => 8_000,
|
||||
Self::ComfortNoise => 48_000,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,6 +62,7 @@ impl CodecId {
|
||||
2 => Some(Self::Opus6k),
|
||||
3 => Some(Self::Codec2_3200),
|
||||
4 => Some(Self::Codec2_1200),
|
||||
5 => Some(Self::ComfortNoise),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,161 @@
|
||||
use std::collections::BTreeMap;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use crate::packet::MediaPacket;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Adaptive playout delay (NetEq-inspired)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Adaptive playout delay estimator based on observed inter-arrival jitter.
///
/// Inspired by WebRTC NetEq and IAX2 adaptive jitter buffering. Tracks an
/// exponential moving average (EMA) of inter-packet arrival jitter and
/// converts it to a target buffer depth in packets.
pub struct AdaptivePlayoutDelay {
    /// Current target delay in packets (equivalent to target_depth).
    target_delay: usize,
    /// Minimum allowed delay, in packets.
    min_delay: usize,
    /// Maximum allowed delay, in packets.
    max_delay: usize,
    /// Exponential moving average of inter-packet arrival jitter (ms).
    jitter_ema: f64,
    /// EMA smoothing factor for jitter increases (fast reaction).
    alpha_up: f64,
    /// EMA smoothing factor for jitter decreases (slow decay).
    alpha_down: f64,
    /// Last packet arrival timestamp (for computing inter-arrival jitter).
    /// Only deltas between consecutive values are used, so the epoch is arbitrary.
    last_arrival_ms: Option<u64>,
    /// Last packet expected timestamp (same epoch caveat as `last_arrival_ms`).
    last_expected_ms: Option<u64>,
    /// Safety margin added to jitter-derived target (in packets).
    safety_margin: f64,
    /// Instant when a jitter spike was detected (handoff detection).
    spike_detected_at: Option<Instant>,
    /// Duration to hold max_delay after a spike is detected.
    spike_cooldown: Duration,
    /// Multiplier of jitter_ema that constitutes a spike.
    spike_threshold_multiplier: f64,
}
|
||||
|
||||
/// Frame duration in milliseconds (20ms Opus/Codec2 frames).
|
||||
const FRAME_DURATION_MS: f64 = 20.0;
|
||||
/// Default safety margin in packets.
|
||||
const DEFAULT_SAFETY_MARGIN: f64 = 2.0;
|
||||
/// Default EMA smoothing factor (used for both up/down in non-mobile mode).
|
||||
const DEFAULT_ALPHA: f64 = 0.05;
|
||||
|
||||
impl AdaptivePlayoutDelay {
    /// Create a new adaptive playout delay estimator.
    ///
    /// Starts at `min_delay` with an empty jitter history and non-mobile
    /// (symmetric-alpha) smoothing parameters.
    ///
    /// - `min_delay`: minimum target delay in packets
    /// - `max_delay`: maximum target delay in packets
    pub fn new(min_delay: usize, max_delay: usize) -> Self {
        Self {
            target_delay: min_delay,
            min_delay,
            max_delay,
            jitter_ema: 0.0,
            alpha_up: DEFAULT_ALPHA,
            alpha_down: DEFAULT_ALPHA,
            last_arrival_ms: None,
            last_expected_ms: None,
            safety_margin: DEFAULT_SAFETY_MARGIN,
            spike_detected_at: None,
            spike_cooldown: Duration::from_secs(2),
            spike_threshold_multiplier: 3.0,
        }
    }

    /// Update with a new packet arrival. Returns the new target delay.
    ///
    /// The first call only records the timestamps (there is no previous
    /// packet to compute a delta against) and returns the current target.
    ///
    /// - `arrival_ms`: when the packet actually arrived (wall clock)
    /// - `expected_ms`: when it should have arrived (based on sequence * 20ms)
    pub fn update(&mut self, arrival_ms: u64, expected_ms: u64) -> usize {
        if let (Some(last_arrival), Some(last_expected)) =
            (self.last_arrival_ms, self.last_expected_ms)
        {
            // Jitter = |deviation of the actual inter-arrival gap from the
            // nominal (expected) gap|, per the classic RTP jitter definition.
            let actual_delta = arrival_ms as f64 - last_arrival as f64;
            let expected_delta = expected_ms as f64 - last_expected as f64;
            let jitter = (actual_delta - expected_delta).abs();

            // Spike detection: check before EMA update
            // (comparing the raw sample against the *previous* smoothed value;
            // updating the EMA first would dilute the spike).
            if self.jitter_ema > 0.0
                && jitter > self.jitter_ema * self.spike_threshold_multiplier
            {
                self.spike_detected_at = Some(Instant::now());
            }

            // Asymmetric EMA update: react fast to rising jitter, decay slowly.
            let alpha = if jitter > self.jitter_ema {
                self.alpha_up
            } else {
                self.alpha_down
            };
            self.jitter_ema = alpha * jitter + (1.0 - alpha) * self.jitter_ema;

            // Check if spike cooldown has expired
            if let Some(spike_time) = self.spike_detected_at {
                if spike_time.elapsed() >= self.spike_cooldown {
                    self.spike_detected_at = None;
                }
            }

            // If within spike cooldown, return max_delay
            // (hold the deepest buffer until the network settles again).
            if self.spike_detected_at.is_some() {
                self.target_delay = self.max_delay;
            } else {
                // Convert jitter estimate to target delay in packets
                // (one packet per 20ms frame, plus the safety margin).
                let raw_target =
                    (self.jitter_ema / FRAME_DURATION_MS).ceil() + self.safety_margin;
                self.target_delay =
                    (raw_target as usize).clamp(self.min_delay, self.max_delay);
            }
        }

        self.last_arrival_ms = Some(arrival_ms);
        self.last_expected_ms = Some(expected_ms);
        self.target_delay
    }

    /// Get current target delay in packets.
    pub fn target_delay(&self) -> usize {
        self.target_delay
    }

    /// Get current jitter estimate in ms.
    pub fn jitter_estimate_ms(&self) -> f64 {
        self.jitter_ema
    }

    /// Enable or disable mobile mode, adjusting parameters for cellular networks.
    ///
    /// Mobile mode uses:
    /// - Asymmetric alpha (fast up=0.3, slow down=0.02) for quicker spike detection
    /// - Higher safety margin (3.0 packets) to absorb handoff jitter
    /// - Spike detection with 2-second cooldown at 3x threshold
    ///
    /// NOTE(review): the spike threshold and cooldown are currently identical
    /// in both branches — only margin and alphas actually differ between modes.
    pub fn set_mobile_mode(&mut self, enabled: bool) {
        if enabled {
            self.safety_margin = 3.0;
            self.alpha_up = 0.3;
            self.alpha_down = 0.02;
            self.spike_threshold_multiplier = 3.0;
            self.spike_cooldown = Duration::from_secs(2);
        } else {
            self.safety_margin = DEFAULT_SAFETY_MARGIN;
            self.alpha_up = DEFAULT_ALPHA;
            self.alpha_down = DEFAULT_ALPHA;
            self.spike_threshold_multiplier = 3.0;
            self.spike_cooldown = Duration::from_secs(2);
        }
    }
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Jitter buffer
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Adaptive jitter buffer that reorders packets by sequence number.
|
||||
///
|
||||
/// Designed for the lossy relay link with up to 5 seconds of buffering depth.
|
||||
@@ -21,6 +175,8 @@ pub struct JitterBuffer {
|
||||
initialized: bool,
|
||||
/// Statistics.
|
||||
stats: JitterStats,
|
||||
/// Optional adaptive playout delay estimator.
|
||||
adaptive: Option<AdaptivePlayoutDelay>,
|
||||
}
|
||||
|
||||
/// Jitter buffer statistics.
|
||||
@@ -32,6 +188,14 @@ pub struct JitterStats {
|
||||
pub packets_late: u64,
|
||||
pub packets_duplicate: u64,
|
||||
pub current_depth: usize,
|
||||
/// Total frames decoded by the consumer (tracked externally via `record_decode`).
|
||||
pub total_decoded: u64,
|
||||
/// Number of times the consumer tried to decode but the buffer was empty/not-ready.
|
||||
pub underruns: u64,
|
||||
/// Number of packets dropped because the buffer exceeded max depth.
|
||||
pub overruns: u64,
|
||||
/// High water mark — maximum buffer depth observed.
|
||||
pub max_depth_seen: usize,
|
||||
}
|
||||
|
||||
/// Result of attempting to get the next packet for playout.
|
||||
@@ -60,6 +224,27 @@ impl JitterBuffer {
|
||||
min_depth,
|
||||
initialized: false,
|
||||
stats: JitterStats::default(),
|
||||
adaptive: None,
|
||||
}
|
||||
}
|
||||
|
||||
    /// Create a jitter buffer with adaptive playout delay.
    ///
    /// The target depth will be automatically adjusted based on observed
    /// inter-arrival jitter (NetEq-inspired algorithm).
    ///
    /// - `min_delay`: minimum target delay in packets
    /// - `max_delay`: maximum target delay in packets (also used as max_depth)
    pub fn new_adaptive(min_delay: usize, max_delay: usize) -> Self {
        Self {
            buffer: BTreeMap::new(),
            next_playout_seq: 0,
            // The adaptive target can never exceed max_delay, so it doubles
            // as the hard buffer-depth cap.
            max_depth: max_delay,
            // Start shallow; the estimator grows the target as jitter appears.
            target_depth: min_delay,
            min_depth: min_delay,
            initialized: false,
            stats: JitterStats::default(),
            adaptive: Some(AdaptivePlayoutDelay::new(min_delay, max_delay)),
        }
    }
|
||||
|
||||
@@ -99,12 +284,35 @@ impl JitterBuffer {
|
||||
self.next_playout_seq = seq;
|
||||
}
|
||||
|
||||
// Update adaptive playout delay if enabled.
|
||||
// Use the packet's timestamp as expected_ms and compute a simple wall-clock
|
||||
// proxy from the header timestamp (arrival_ms is approximated as timestamp
|
||||
// + observed jitter, but since we don't have real wall-clock here we use
|
||||
// the receive order with the header timestamp as the expected baseline).
|
||||
if let Some(ref mut adaptive) = self.adaptive {
|
||||
// expected_ms derived from sequence-implied timing: seq * frame_duration
|
||||
let expected_ms = packet.header.timestamp as u64;
|
||||
// For arrival_ms, use the actual receive timestamp. In the absence of
|
||||
// a wall-clock parameter, we use std::time for a monotonic approximation.
|
||||
// However, to keep the API simple, we compute arrival from the packet
|
||||
// stats: the Nth received packet "arrives" at N * frame_duration as a
|
||||
// baseline, and real network jitter shows in the deviation.
|
||||
// NOTE: In production, the caller should pass real wall-clock time.
|
||||
// For now, we use the header timestamp as-is (callers with adaptive
|
||||
// mode should feed arrival time via push_with_arrival).
|
||||
let arrival_ms = expected_ms; // no-op for basic push; use push_with_arrival
|
||||
adaptive.update(arrival_ms, expected_ms);
|
||||
self.target_depth = adaptive.target_delay();
|
||||
self.min_depth = self.min_depth.min(self.target_depth);
|
||||
}
|
||||
|
||||
self.buffer.insert(seq, packet);
|
||||
|
||||
// Evict oldest if over max depth
|
||||
while self.buffer.len() > self.max_depth {
|
||||
if let Some((&oldest_seq, _)) = self.buffer.first_key_value() {
|
||||
self.buffer.remove(&oldest_seq);
|
||||
self.stats.overruns += 1;
|
||||
// Advance playout seq past evicted packet
|
||||
if seq_before(self.next_playout_seq, oldest_seq.wrapping_add(1)) {
|
||||
self.next_playout_seq = oldest_seq.wrapping_add(1);
|
||||
@@ -114,6 +322,9 @@ impl JitterBuffer {
|
||||
}
|
||||
|
||||
self.stats.current_depth = self.buffer.len();
|
||||
if self.stats.current_depth > self.stats.max_depth_seen {
|
||||
self.stats.max_depth_seen = self.stats.current_depth;
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the next packet for playout.
|
||||
@@ -163,6 +374,91 @@ impl JitterBuffer {
|
||||
self.stats = JitterStats::default();
|
||||
}
|
||||
|
||||
/// Record that the consumer attempted to decode but the buffer was empty/not-ready.
///
/// Bookkeeping only — increments `JitterStats::underruns`; buffer contents
/// and playout state are untouched.
pub fn record_underrun(&mut self) {
    self.stats.underruns += 1;
}
|
||||
|
||||
/// Record a successful frame decode by the consumer.
///
/// Bookkeeping only — increments `JitterStats::total_decoded`.
pub fn record_decode(&mut self) {
    self.stats.total_decoded += 1;
}
|
||||
|
||||
/// Reset statistics counters (preserves buffer contents and playout state).
|
||||
pub fn reset_stats(&mut self) {
|
||||
self.stats = JitterStats {
|
||||
current_depth: self.buffer.len(),
|
||||
..JitterStats::default()
|
||||
};
|
||||
}
|
||||
|
||||
/// Push a received packet with an explicit wall-clock arrival time.
///
/// This is the preferred entry point when adaptive playout delay is enabled,
/// since the estimator needs real arrival timestamps.
///
/// Processing order: duplicate check → late check → adaptive update →
/// insert → evict over-depth → refresh depth gauges. Late and duplicate
/// packets are counted and dropped without touching the estimator.
///
/// NOTE(review): unlike the plain `push` path, this does not shrink
/// `min_depth` toward the adaptive target — confirm whether that asymmetry
/// is intentional. Also assumes `header.timestamp` is in milliseconds —
/// TODO confirm.
pub fn push_with_arrival(&mut self, packet: MediaPacket, arrival_ms: u64) {
    let expected_ms = packet.header.timestamp as u64;
    let seq = packet.header.seq;
    self.stats.packets_received += 1;

    // First packet ever: anchor the playout sequence to it.
    if !self.initialized {
        self.next_playout_seq = seq;
        self.initialized = true;
    }

    // Check for duplicates (a late duplicate counts as duplicate, not late).
    if self.buffer.contains_key(&seq) {
        self.stats.packets_duplicate += 1;
        return;
    }

    // Check if packet is too old (already played out). Only applies once
    // playout has actually started.
    if self.stats.packets_played > 0 && seq_before(seq, self.next_playout_seq) {
        self.stats.packets_late += 1;
        return;
    }

    // If we haven't started playout yet, adjust next_playout_seq back to the
    // earliest known sequence so nothing buffered pre-playout is skipped.
    if self.stats.packets_played == 0 && seq_before(seq, self.next_playout_seq) {
        self.next_playout_seq = seq;
    }

    // Update adaptive playout delay if enabled; the estimator compares real
    // arrival time against the header-implied expected time.
    if let Some(ref mut adaptive) = self.adaptive {
        adaptive.update(arrival_ms, expected_ms);
        self.target_depth = adaptive.target_delay();
    }

    self.buffer.insert(seq, packet);

    // Evict oldest if over max depth. Loop terminates: len > max_depth
    // implies the buffer is non-empty, and each pass removes one entry.
    while self.buffer.len() > self.max_depth {
        if let Some((&oldest_seq, _)) = self.buffer.first_key_value() {
            self.buffer.remove(&oldest_seq);
            self.stats.overruns += 1;
            // Advance playout past the evicted packet; count it as lost
            // since the consumer can never play it now.
            if seq_before(self.next_playout_seq, oldest_seq.wrapping_add(1)) {
                self.next_playout_seq = oldest_seq.wrapping_add(1);
                self.stats.packets_lost += 1;
            }
        }
    }

    // Refresh the depth gauge and its high-water mark.
    self.stats.current_depth = self.buffer.len();
    if self.stats.current_depth > self.stats.max_depth_seen {
        self.stats.max_depth_seen = self.stats.current_depth;
    }
}
|
||||
|
||||
/// Get a reference to the adaptive playout delay estimator, if enabled.
///
/// Returns `None` when the buffer was built with the fixed-depth
/// constructor rather than `new_adaptive`.
pub fn adaptive_delay(&self) -> Option<&AdaptivePlayoutDelay> {
    self.adaptive.as_ref()
}
|
||||
|
||||
/// Get a mutable reference to the adaptive playout delay estimator.
///
/// Useful for runtime reconfiguration (e.g. `set_mobile_mode`). Returns
/// `None` when adaptive mode is not enabled.
pub fn adaptive_delay_mut(&mut self) -> Option<&mut AdaptivePlayoutDelay> {
    self.adaptive.as_mut()
}
|
||||
|
||||
/// Adjust target depth based on observed jitter.
|
||||
pub fn set_target_depth(&mut self, depth: usize) {
|
||||
self.target_depth = depth.min(self.max_depth);
|
||||
@@ -304,4 +600,217 @@ mod tests {
|
||||
other => panic!("expected packet 0, got {:?}", other),
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// AdaptivePlayoutDelay tests
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
#[test]
fn adaptive_delay_stable() {
    // Perfectly regular 20 ms arrivals: the estimator should observe zero
    // jitter and keep the target pinned at its configured minimum.
    let mut estimator = AdaptivePlayoutDelay::new(3, 50);

    for tick in (0u64..200).map(|i| i * 20) {
        estimator.update(tick, tick);
    }

    // ceil(0 / 20) + 2 = 2, which clamps up to min_delay = 3.
    assert_eq!(estimator.target_delay(), 3);
    let jitter = estimator.jitter_estimate_ms();
    assert!(
        jitter < 1.0,
        "jitter estimate should be near zero, got {}",
        jitter
    );
}
|
||||
|
||||
#[test]
fn adaptive_delay_increases_on_jitter() {
    // Packets alternate between arriving 10 ms early and 10 ms late
    // relative to a 20 ms cadence.
    let mut estimator = AdaptivePlayoutDelay::new(3, 50);

    for i in 0u64..200 {
        let expected = i * 20;
        let offset: i64 = match i % 2 {
            0 => 10,
            _ => -10,
        };
        let arrival = (expected as i64 + offset).max(0) as u64;
        estimator.update(arrival, expected);
    }

    // A ±10 ms swing means a 20 ms inter-arrival delta, so the EMA settles
    // near 20 ms and the target stays at or above the minimum.
    let target = estimator.target_delay();
    assert!(
        target >= 3,
        "target should increase with jitter, got {}",
        target
    );
    let jitter = estimator.jitter_estimate_ms();
    assert!(
        jitter > 5.0,
        "jitter estimate should be significant, got {}",
        jitter
    );
}
|
||||
|
||||
#[test]
fn adaptive_delay_decreases_on_recovery() {
    let mut estimator = AdaptivePlayoutDelay::new(3, 50);

    // Phase 1: heavy jitter — alternate ±30 ms around the 20 ms cadence.
    for i in 0u64..200 {
        let expected = i * 20;
        let offset: i64 = if i % 2 == 0 { 30 } else { -30 };
        let arrival = (expected as i64 + offset).max(0) as u64;
        estimator.update(arrival, expected);
    }
    let noisy_target = estimator.target_delay();
    let noisy_jitter = estimator.jitter_estimate_ms();

    // Phase 2: perfectly regular arrivals — the EMA should decay back down.
    for i in 200u64..600 {
        let tick = i * 20;
        estimator.update(tick, tick);
    }
    let calm_target = estimator.target_delay();
    let calm_jitter = estimator.jitter_estimate_ms();

    assert!(
        calm_target <= noisy_target,
        "target should decrease after recovery: {} -> {}",
        noisy_target,
        calm_target
    );
    assert!(
        calm_jitter < noisy_jitter,
        "jitter estimate should decrease: {} -> {}",
        noisy_jitter,
        calm_jitter
    );
}
|
||||
|
||||
#[test]
fn adaptive_delay_clamped() {
    // Even wild ±500 ms swings must keep the target inside [min, max].
    let mut estimator = AdaptivePlayoutDelay::new(3, 10);

    for i in 0u64..500 {
        let expected = i * 20;
        let offset: i64 = if i % 2 == 0 { 500 } else { -500 };
        let arrival = (expected as i64 + offset).max(0) as u64;
        estimator.update(arrival, expected);
    }

    let target = estimator.target_delay();
    assert!(
        target <= 10,
        "target should not exceed max_delay=10, got {}",
        target
    );
    assert!(
        target >= 3,
        "target should not go below min_delay=3, got {}",
        target
    );
}
|
||||
|
||||
#[test]
fn adaptive_jitter_estimate() {
    let mut apd = AdaptivePlayoutDelay::new(3, 50);

    // Before any packet: estimate starts at exactly zero.
    assert_eq!(apd.jitter_estimate_ms(), 0.0);

    // After one packet, still zero — a single sample yields no delta.
    apd.update(0, 0);
    assert_eq!(apd.jitter_estimate_ms(), 0.0);

    // Second packet arrives 5 ms late: the first real jitter sample.
    apd.update(25, 20); // arrived 5ms late
    assert!(
        apd.jitter_estimate_ms() > 0.0,
        "jitter estimate should be positive after jittery packet"
    );
    // The EMA blends the 5 ms sample toward the prior 0, so it can never
    // overshoot the sample itself.
    assert!(
        apd.jitter_estimate_ms() <= 5.0,
        "first jitter sample of 5ms with alpha=0.05 should not exceed 5ms, got {}",
        apd.jitter_estimate_ms()
    );

    // Feed many packets with a *constant* 15 ms lateness. A constant offset
    // produces identical inter-arrival deltas and therefore carries NO
    // jitter — this phase only lets the estimate decay; it asserts nothing.
    for i in 2u64..500 {
        let expected_ms = i * 20;
        let arrival_ms = expected_ms + 15; // consistently 15ms late
        apd.update(arrival_ms, expected_ms);
    }
    // A real convergence check needs *variable* lateness, so run a second
    // estimator with alternating 0 ms / 15 ms offsets instead.
    let mut apd2 = AdaptivePlayoutDelay::new(3, 50);
    for i in 0u64..500 {
        let expected_ms = i * 20;
        // Alternate 0ms and 15ms late.
        let extra = if i % 2 == 0 { 0 } else { 15 };
        let arrival_ms = expected_ms + extra;
        apd2.update(arrival_ms, expected_ms);
    }
    let est = apd2.jitter_estimate_ms();
    assert!(
        est > 5.0 && est < 20.0,
        "jitter estimate should converge near 15ms with alternating 0/15ms offsets, got {}",
        est
    );
}
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// JitterBuffer with adaptive mode tests
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
#[test]
fn jitter_buffer_adaptive_constructor() {
    // The adaptive constructor must wire in an estimator whose initial
    // target equals the requested minimum delay.
    let buffer = JitterBuffer::new_adaptive(5, 250);
    let estimator = buffer.adaptive_delay();
    assert!(estimator.is_some());
    assert_eq!(estimator.unwrap().target_delay(), 5);
}
|
||||
|
||||
#[test]
fn jitter_buffer_adaptive_push_with_arrival() {
    let mut buffer = JitterBuffer::new_adaptive(3, 50);

    // Perfectly paced packets: one every 20 ms, zero jitter.
    for seq in 0u16..20 {
        buffer.push_with_arrival(make_packet(seq), u64::from(seq) * 20);
    }

    // No jitter observed, so the target stays at the configured minimum.
    let estimator = buffer.adaptive_delay().unwrap();
    assert_eq!(estimator.target_delay(), 3);
}
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// Mobile mode tests
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
#[test]
fn mobile_mode_increases_safety_margin() {
    let mut estimator = AdaptivePlayoutDelay::new(3, 50);

    // Entering mobile mode bumps the safety margin and makes the EMA react
    // quickly upward but only slowly downward.
    estimator.set_mobile_mode(true);
    assert_eq!(estimator.safety_margin, 3.0);
    assert_eq!(estimator.alpha_up, 0.3);
    assert_eq!(estimator.alpha_down, 0.02);

    // Leaving mobile mode restores the symmetric defaults.
    estimator.set_mobile_mode(false);
    assert_eq!(estimator.safety_margin, DEFAULT_SAFETY_MARGIN);
    assert_eq!(estimator.alpha_up, DEFAULT_ALPHA);
    assert_eq!(estimator.alpha_down, DEFAULT_ALPHA);
}
|
||||
|
||||
#[test]
fn mobile_mode_accessible_via_jitter_buffer() {
    // Mobile mode can be toggled through the buffer's mutable accessor and
    // observed through the shared one.
    let mut buffer = JitterBuffer::new_adaptive(3, 50);
    let estimator = buffer.adaptive_delay_mut().unwrap();
    estimator.set_mobile_mode(true);
    assert_eq!(buffer.adaptive_delay().unwrap().safety_margin, 3.0);
}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
//! - Identity = 32-byte seed → HKDF → Ed25519 (signing) + X25519 (encryption)
|
||||
//! - Fingerprint = SHA-256(Ed25519 pub)[:16]
|
||||
|
||||
pub mod bandwidth;
|
||||
pub mod codec_id;
|
||||
pub mod error;
|
||||
pub mod jitter;
|
||||
@@ -23,7 +24,11 @@ pub mod traits;
|
||||
// Re-export key types at crate root for convenience.
|
||||
pub use codec_id::{CodecId, QualityProfile};
|
||||
pub use error::*;
|
||||
pub use packet::{HangupReason, MediaHeader, MediaPacket, QualityReport, SignalMessage};
|
||||
pub use quality::{AdaptiveQualityController, Tier};
|
||||
pub use packet::{
|
||||
HangupReason, MediaHeader, MediaPacket, MiniFrameContext, MiniHeader, QualityReport,
|
||||
RoomParticipant, SignalMessage, TrunkEntry, TrunkFrame, FRAME_TYPE_FULL, FRAME_TYPE_MINI,
|
||||
};
|
||||
pub use bandwidth::{BandwidthEstimator, CongestionState};
|
||||
pub use quality::{AdaptiveQualityController, NetworkContext, Tier};
|
||||
pub use session::{Session, SessionEvent, SessionState};
|
||||
pub use traits::*;
|
||||
|
||||
@@ -46,6 +46,23 @@ impl MediaHeader {
|
||||
/// Header size in bytes on the wire.
|
||||
pub const WIRE_SIZE: usize = 12;
|
||||
|
||||
/// Create a default header for raw PCM relay (used by WebSocket bridge).
///
/// Every counter and flag starts zeroed/false; the caller is expected to
/// fill in `seq`/`timestamp` per frame.
///
/// NOTE(review): `codec_id` is `CodecId::Opus24k` even though this is
/// documented as a raw-PCM header — presumably the bridge overwrites or
/// ignores it; confirm against the WebSocket bridge code.
pub fn default_pcm() -> Self {
    Self {
        version: 0,
        is_repair: false,
        codec_id: CodecId::Opus24k,
        has_quality_report: false,
        fec_ratio_encoded: 0,
        seq: 0,
        timestamp: 0,
        fec_block: 0,
        fec_symbol: 0,
        reserved: 0,
        csrc_count: 0,
    }
}
|
||||
|
||||
/// Encode the FEC ratio float (0.0-2.0+) to a 7-bit value (0-127).
|
||||
pub fn encode_fec_ratio(ratio: f32) -> u8 {
|
||||
// Map 0.0-2.0 to 0-127, clamping at 127
|
||||
@@ -191,6 +208,9 @@ pub struct MediaPacket {
|
||||
pub quality_report: Option<QualityReport>,
|
||||
}
|
||||
|
||||
/// Maximum number of mini-frames between full headers (1 second at 50 fps).
///
/// `encode_compact` forces a full 12-byte header once this many frames have
/// elapsed since the last one, bounding how long a receiver that missed the
/// baseline must wait before it can resynchronize.
pub const MINI_FRAME_FULL_INTERVAL: u32 = 50;
|
||||
|
||||
impl MediaPacket {
|
||||
/// Serialize the entire packet to bytes.
|
||||
pub fn to_bytes(&self) -> Bytes {
|
||||
@@ -239,6 +259,276 @@ impl MediaPacket {
|
||||
quality_report,
|
||||
})
|
||||
}
|
||||
|
||||
/// Serialize with mini-frame compression.
|
||||
///
|
||||
/// Uses the `MiniFrameContext` to decide whether to emit a compact 4-byte
|
||||
/// mini-header or a full 12-byte header. A full header is forced on the
|
||||
/// first frame and every `MINI_FRAME_FULL_INTERVAL` frames thereafter.
|
||||
pub fn encode_compact(
|
||||
&self,
|
||||
ctx: &mut MiniFrameContext,
|
||||
frames_since_full: &mut u32,
|
||||
) -> Bytes {
|
||||
if *frames_since_full > 0 && *frames_since_full < MINI_FRAME_FULL_INTERVAL {
|
||||
// --- mini frame ---
|
||||
let ts_delta = self
|
||||
.header
|
||||
.timestamp
|
||||
.wrapping_sub(ctx.last_header.unwrap().timestamp)
|
||||
as u16;
|
||||
let mini = MiniHeader {
|
||||
timestamp_delta_ms: ts_delta,
|
||||
payload_len: self.payload.len() as u16,
|
||||
};
|
||||
let total = 1 + MiniHeader::WIRE_SIZE + self.payload.len();
|
||||
let mut buf = BytesMut::with_capacity(total);
|
||||
buf.put_u8(FRAME_TYPE_MINI);
|
||||
mini.write_to(&mut buf);
|
||||
buf.put(self.payload.clone());
|
||||
// Advance the context so the next mini-frame delta is relative
|
||||
// to this frame, mirroring what expand() does on the decoder side.
|
||||
ctx.update(&self.header);
|
||||
*frames_since_full += 1;
|
||||
buf.freeze()
|
||||
} else {
|
||||
// --- full frame ---
|
||||
let qr_size = if self.quality_report.is_some() {
|
||||
QualityReport::WIRE_SIZE
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let total = 1 + MediaHeader::WIRE_SIZE + self.payload.len() + qr_size;
|
||||
let mut buf = BytesMut::with_capacity(total);
|
||||
buf.put_u8(FRAME_TYPE_FULL);
|
||||
self.header.write_to(&mut buf);
|
||||
buf.put(self.payload.clone());
|
||||
if let Some(ref qr) = self.quality_report {
|
||||
qr.write_to(&mut buf);
|
||||
}
|
||||
ctx.update(&self.header);
|
||||
*frames_since_full = 1; // next frame will be the 1st after full
|
||||
buf.freeze()
|
||||
}
|
||||
}
|
||||
|
||||
/// Decode from compact wire format (auto-detects full vs mini).
///
/// Returns `None` on malformed input or if a mini-frame arrives before any
/// full header baseline has been established.
///
/// On success the context baseline advances: full frames install their
/// header via `ctx.update`, mini frames advance it inside `ctx.expand`.
/// Bytes trailing the declared mini payload are ignored.
pub fn decode_compact(buf: &[u8], ctx: &mut MiniFrameContext) -> Option<Self> {
    if buf.is_empty() {
        return None;
    }
    // The first byte tags the frame flavor; header bytes start right after.
    let frame_type = buf[0];
    let rest = &buf[1..];

    match frame_type {
        FRAME_TYPE_FULL => {
            // Full 12-byte header: parse as a normal packet, then record it
            // as the new baseline for subsequent mini-frames.
            let pkt = Self::from_bytes(Bytes::copy_from_slice(rest))?;
            ctx.update(&pkt.header);
            Some(pkt)
        }
        FRAME_TYPE_MINI => {
            if rest.len() < MiniHeader::WIRE_SIZE {
                return None;
            }
            let mut cursor = rest;
            let mini = MiniHeader::read_from(&mut cursor)?;
            // Offsets are relative to the *full* buffer (tag byte included).
            let payload_start = 1 + MiniHeader::WIRE_SIZE;
            let payload_end = payload_start + mini.payload_len as usize;
            if buf.len() < payload_end {
                return None;
            }
            let payload = Bytes::copy_from_slice(&buf[payload_start..payload_end]);
            // `expand` yields None when no full-header baseline exists yet.
            let header = ctx.expand(&mini)?;
            Some(Self {
                header,
                payload,
                quality_report: None,
            })
        }
        _ => None,
    }
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Trunking — multiplex multiple session packets into one QUIC datagram
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// A single entry inside a [`TrunkFrame`].
#[derive(Clone, Debug)]
pub struct TrunkEntry {
    /// 2-byte session identifier (up to 65 536 sessions).
    pub session_id: [u8; 2],
    /// Encoded MediaPacket payload (already compressed). Must fit in the
    /// u16 length field on the wire (<= 65 535 bytes).
    pub payload: Bytes,
}
|
||||
|
||||
impl TrunkEntry {
    /// Per-entry wire overhead: 2 (session_id) + 2 (len), matching the
    /// `[session_id:2] [len:u16]` prefix each entry carries on the wire.
    pub const OVERHEAD: usize = 4;
}
|
||||
|
||||
/// A trunked frame carrying multiple session packets in one datagram.
|
||||
///
|
||||
/// Wire format:
|
||||
/// ```text
|
||||
/// [count:u16] [entry1] [entry2] ...
|
||||
/// ```
|
||||
/// Each entry:
|
||||
/// ```text
|
||||
/// [session_id:2] [len:u16] [payload:len]
|
||||
/// ```
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct TrunkFrame {
|
||||
pub packets: Vec<TrunkEntry>,
|
||||
}
|
||||
|
||||
impl TrunkFrame {
|
||||
/// Create an empty trunk frame.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
packets: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Append a session packet to the frame.
|
||||
pub fn push(&mut self, session_id: [u8; 2], payload: Bytes) {
|
||||
self.packets.push(TrunkEntry {
|
||||
session_id,
|
||||
payload,
|
||||
});
|
||||
}
|
||||
|
||||
/// Number of entries in the frame.
|
||||
pub fn len(&self) -> usize {
|
||||
self.packets.len()
|
||||
}
|
||||
|
||||
/// Whether the frame is empty.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.packets.is_empty()
|
||||
}
|
||||
|
||||
/// Total wire size of the encoded frame.
|
||||
pub fn wire_size(&self) -> usize {
|
||||
// 2 bytes for count + each entry
|
||||
2 + self
|
||||
.packets
|
||||
.iter()
|
||||
.map(|e| TrunkEntry::OVERHEAD + e.payload.len())
|
||||
.sum::<usize>()
|
||||
}
|
||||
|
||||
/// Encode to wire bytes.
|
||||
pub fn encode(&self) -> Bytes {
|
||||
let mut buf = BytesMut::with_capacity(self.wire_size());
|
||||
buf.put_u16(self.packets.len() as u16);
|
||||
for entry in &self.packets {
|
||||
buf.put_slice(&entry.session_id);
|
||||
buf.put_u16(entry.payload.len() as u16);
|
||||
buf.put(entry.payload.clone());
|
||||
}
|
||||
buf.freeze()
|
||||
}
|
||||
|
||||
/// Decode from wire bytes. Returns `None` on malformed input.
|
||||
pub fn decode(buf: &[u8]) -> Option<Self> {
|
||||
if buf.len() < 2 {
|
||||
return None;
|
||||
}
|
||||
let mut cursor = &buf[..];
|
||||
let count = cursor.get_u16() as usize;
|
||||
let mut packets = Vec::with_capacity(count);
|
||||
for _ in 0..count {
|
||||
if cursor.remaining() < TrunkEntry::OVERHEAD {
|
||||
return None;
|
||||
}
|
||||
let mut session_id = [0u8; 2];
|
||||
session_id[0] = cursor.get_u8();
|
||||
session_id[1] = cursor.get_u8();
|
||||
let len = cursor.get_u16() as usize;
|
||||
if cursor.remaining() < len {
|
||||
return None;
|
||||
}
|
||||
let payload = Bytes::copy_from_slice(&cursor[..len]);
|
||||
cursor.advance(len);
|
||||
packets.push(TrunkEntry {
|
||||
session_id,
|
||||
payload,
|
||||
});
|
||||
}
|
||||
Some(Self { packets })
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Mini-frames — compact header for steady-state media packets
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Frame type tag: full MediaHeader follows (also resets the mini-frame
/// baseline on the decoder side).
pub const FRAME_TYPE_FULL: u8 = 0x00;
/// Frame type tag: MiniHeader follows (requires prior baseline; decoders
/// reject mini-frames seen before any full header).
pub const FRAME_TYPE_MINI: u8 = 0x01;
|
||||
|
||||
/// Compact 4-byte header used after a full MediaHeader baseline has been
/// established. Only the timestamp delta and payload length are transmitted;
/// all other fields are inherited from the last full header.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct MiniHeader {
    /// Milliseconds elapsed since the last header's timestamp.
    pub timestamp_delta_ms: u16,
    /// Length of the payload that follows this header.
    pub payload_len: u16,
}
|
||||
|
||||
impl MiniHeader {
    /// Header size in bytes on the wire (two big-endian u16 fields).
    pub const WIRE_SIZE: usize = 4;

    /// Serialize to a 4-byte buffer. Fields are written big-endian, per the
    /// `bytes` crate's `put_u16`.
    pub fn write_to(&self, buf: &mut impl BufMut) {
        buf.put_u16(self.timestamp_delta_ms);
        buf.put_u16(self.payload_len);
    }

    /// Deserialize from a buffer. Returns `None` if insufficient data; on
    /// success the buffer cursor has advanced by `WIRE_SIZE`.
    pub fn read_from(buf: &mut impl Buf) -> Option<Self> {
        if buf.remaining() < Self::WIRE_SIZE {
            return None;
        }
        Some(Self {
            timestamp_delta_ms: buf.get_u16(),
            payload_len: buf.get_u16(),
        })
    }
}
|
||||
|
||||
/// Stateful context that expands [`MiniHeader`]s back into full
/// [`MediaHeader`]s by tracking the last baseline header.
///
/// `Default` starts with no baseline, so the first frame decoded through it
/// must carry a full header.
#[derive(Clone, Debug, Default)]
pub struct MiniFrameContext {
    // None until the first full header is seen; advanced by both `update`
    // and `expand`.
    last_header: Option<MediaHeader>,
}
|
||||
|
||||
impl MiniFrameContext {
    /// Record a full header as the new baseline for subsequent mini-frames.
    pub fn update(&mut self, header: &MediaHeader) {
        self.last_header = Some(*header);
    }

    /// Expand a mini-header into a full [`MediaHeader`] using the stored
    /// baseline. Returns `None` if no baseline has been set yet.
    ///
    /// The expanded header becomes the new baseline, so consecutive
    /// mini-frames chain their deltas. NOTE(review): the sequence number is
    /// assumed to advance by exactly 1 per mini-frame (`base.seq + 1`); a
    /// sender that skips sequence numbers between full headers would desync
    /// the decoder — confirm the encoder guarantees this.
    pub fn expand(&mut self, mini: &MiniHeader) -> Option<MediaHeader> {
        let base = self.last_header.as_ref()?;
        let mut expanded = *base;
        expanded.seq = base.seq.wrapping_add(1);
        expanded.timestamp = base.timestamp.wrapping_add(mini.timestamp_delta_ms as u32);
        self.last_header = Some(expanded);
        Some(expanded)
    }
}
|
||||
|
||||
/// Signaling messages sent over the reliable QUIC stream.
|
||||
@@ -258,6 +548,9 @@ pub enum SignalMessage {
|
||||
signature: Vec<u8>,
|
||||
/// Supported quality profiles.
|
||||
supported_profiles: Vec<crate::QualityProfile>,
|
||||
/// Optional display name set by the caller.
|
||||
#[serde(default)]
|
||||
alias: Option<String>,
|
||||
},
|
||||
|
||||
/// Call acceptance (analogous to Warzone's WireMessage::CallAnswer).
|
||||
@@ -297,6 +590,86 @@ pub enum SignalMessage {
|
||||
|
||||
/// End the call.
|
||||
Hangup { reason: HangupReason },
|
||||
|
||||
/// featherChat bearer token for relay authentication.
|
||||
/// Sent as the first signal message when --auth-url is configured.
|
||||
AuthToken { token: String },
|
||||
|
||||
/// Put the call on hold (stop sending media, keep session alive).
|
||||
Hold,
|
||||
/// Resume a held call.
|
||||
Unhold,
|
||||
/// Mute request from the remote side (server-initiated mute, like IAX2 QUELCH).
|
||||
Mute,
|
||||
/// Unmute request from the remote side (like IAX2 UNQUELCH).
|
||||
Unmute,
|
||||
/// Transfer the call to another peer.
|
||||
Transfer {
|
||||
target_fingerprint: String,
|
||||
/// Optional relay address for the transfer target.
|
||||
relay_addr: Option<String>,
|
||||
},
|
||||
/// Acknowledge a transfer request.
|
||||
TransferAck,
|
||||
|
||||
/// Presence update from a peer relay (gossip protocol).
|
||||
/// Sent periodically over probe connections to share which fingerprints
|
||||
/// are connected to the sending relay.
|
||||
PresenceUpdate {
|
||||
/// Fingerprints currently connected to the sending relay.
|
||||
fingerprints: Vec<String>,
|
||||
/// Address of the sending relay (e.g., "192.168.1.10:4433").
|
||||
relay_addr: String,
|
||||
},
|
||||
|
||||
/// Ask a peer relay to look up a fingerprint in its registry.
|
||||
RouteQuery {
|
||||
fingerprint: String,
|
||||
ttl: u8,
|
||||
},
|
||||
/// Response to a route query.
|
||||
RouteResponse {
|
||||
fingerprint: String,
|
||||
found: bool,
|
||||
relay_chain: Vec<String>,
|
||||
},
|
||||
|
||||
/// Request to set up a forwarding session for a specific fingerprint.
|
||||
/// Sent over a relay link (`_relay` SNI) to ask the peer relay to
|
||||
/// create a room and forward media for the given session.
|
||||
SessionForward {
|
||||
session_id: String,
|
||||
target_fingerprint: String,
|
||||
source_relay: String,
|
||||
},
|
||||
/// Confirm that the forwarding session has been set up on the peer relay.
|
||||
/// The `room_name` tells the source relay which room to address media to.
|
||||
SessionForwardAck {
|
||||
session_id: String,
|
||||
room_name: String,
|
||||
},
|
||||
|
||||
/// Room membership update — sent by relay to all participants when someone joins or leaves.
|
||||
RoomUpdate {
|
||||
/// Current participant count.
|
||||
count: u32,
|
||||
/// List of participants currently in the room.
|
||||
participants: Vec<RoomParticipant>,
|
||||
},
|
||||
|
||||
/// Set or update the client's display name.
|
||||
/// Sent by client after joining; relay updates the participant entry and
|
||||
/// re-broadcasts a RoomUpdate to all participants.
|
||||
SetAlias { alias: String },
|
||||
}
|
||||
|
||||
/// A participant entry in a RoomUpdate message.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RoomParticipant {
    /// Identity fingerprint (hex string, stable across reconnects if seed is persisted).
    pub fingerprint: String,
    /// Optional display name set by the client (via `SetAlias` or the
    /// initial call message).
    pub alias: Option<String>,
}
|
||||
|
||||
/// Reasons for ending a call.
|
||||
@@ -410,6 +783,112 @@ mod tests {
|
||||
assert_eq!(packet.quality_report, decoded.quality_report);
|
||||
}
|
||||
|
||||
#[test]
fn hold_unhold_serialize() {
    // Unit variants must survive a JSON round-trip.
    let roundtrip = |msg: &SignalMessage| -> SignalMessage {
        let json = serde_json::to_string(msg).unwrap();
        serde_json::from_str(&json).unwrap()
    };

    assert!(matches!(roundtrip(&SignalMessage::Hold), SignalMessage::Hold));
    assert!(matches!(
        roundtrip(&SignalMessage::Unhold),
        SignalMessage::Unhold
    ));
}
|
||||
|
||||
#[test]
fn mute_unmute_serialize() {
    // The QUELCH/UNQUELCH-style variants must survive a JSON round-trip.
    let roundtrip = |msg: &SignalMessage| -> SignalMessage {
        let json = serde_json::to_string(msg).unwrap();
        serde_json::from_str(&json).unwrap()
    };

    assert!(matches!(roundtrip(&SignalMessage::Mute), SignalMessage::Mute));
    assert!(matches!(
        roundtrip(&SignalMessage::Unmute),
        SignalMessage::Unmute
    ));
}
|
||||
|
||||
#[test]
fn transfer_serialize() {
    let roundtrip = |msg: &SignalMessage| -> SignalMessage {
        serde_json::from_str(&serde_json::to_string(msg).unwrap()).unwrap()
    };

    // With an explicit relay address.
    let decoded = roundtrip(&SignalMessage::Transfer {
        target_fingerprint: "abc123".to_string(),
        relay_addr: Some("relay.example.com:4433".to_string()),
    });
    match decoded {
        SignalMessage::Transfer {
            target_fingerprint,
            relay_addr,
        } => {
            assert_eq!(target_fingerprint, "abc123");
            assert_eq!(relay_addr.unwrap(), "relay.example.com:4433");
        }
        _ => panic!("expected Transfer variant"),
    }

    // Without a relay address (relay_addr = None).
    let decoded = roundtrip(&SignalMessage::Transfer {
        target_fingerprint: "def456".to_string(),
        relay_addr: None,
    });
    match decoded {
        SignalMessage::Transfer {
            target_fingerprint,
            relay_addr,
        } => {
            assert_eq!(target_fingerprint, "def456");
            assert!(relay_addr.is_none());
        }
        _ => panic!("expected Transfer variant"),
    }
}
|
||||
|
||||
#[test]
fn transfer_ack_serialize() {
    // TransferAck is a unit variant; a round-trip must yield it back.
    let json = serde_json::to_string(&SignalMessage::TransferAck).unwrap();
    let decoded: SignalMessage = serde_json::from_str(&json).unwrap();
    assert!(matches!(decoded, SignalMessage::TransferAck));
}
|
||||
|
||||
#[test]
fn presence_update_signal_roundtrip() {
    let roundtrip = |msg: &SignalMessage| -> SignalMessage {
        serde_json::from_str(&serde_json::to_string(msg).unwrap()).unwrap()
    };

    // Populated fingerprint list.
    let decoded = roundtrip(&SignalMessage::PresenceUpdate {
        fingerprints: vec!["aabb".to_string(), "ccdd".to_string()],
        relay_addr: "10.0.0.1:4433".to_string(),
    });
    match decoded {
        SignalMessage::PresenceUpdate { fingerprints, relay_addr } => {
            assert_eq!(fingerprints.len(), 2);
            assert!(fingerprints.contains(&"aabb".to_string()));
            assert!(fingerprints.contains(&"ccdd".to_string()));
            assert_eq!(relay_addr, "10.0.0.1:4433");
        }
        _ => panic!("expected PresenceUpdate variant"),
    }

    // Empty fingerprint list must also round-trip.
    let decoded = roundtrip(&SignalMessage::PresenceUpdate {
        fingerprints: vec![],
        relay_addr: "10.0.0.2:4433".to_string(),
    });
    match decoded {
        SignalMessage::PresenceUpdate { fingerprints, relay_addr } => {
            assert!(fingerprints.is_empty());
            assert_eq!(relay_addr, "10.0.0.2:4433");
        }
        _ => panic!("expected PresenceUpdate variant"),
    }
}
|
||||
|
||||
#[test]
|
||||
fn fec_ratio_encode_decode() {
|
||||
let ratio = 0.5;
|
||||
@@ -421,4 +900,247 @@ mod tests {
|
||||
let encoded_max = MediaHeader::encode_fec_ratio(ratio_max);
|
||||
assert_eq!(encoded_max, 127);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// TrunkFrame tests
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn trunk_frame_encode_decode() {
|
||||
let mut frame = TrunkFrame::new();
|
||||
frame.push([0, 1], Bytes::from_static(b"hello"));
|
||||
frame.push([0, 2], Bytes::from_static(b"world!"));
|
||||
frame.push([1, 0], Bytes::from_static(b"x"));
|
||||
assert_eq!(frame.len(), 3);
|
||||
|
||||
let encoded = frame.encode();
|
||||
let decoded = TrunkFrame::decode(&encoded).expect("decode failed");
|
||||
assert_eq!(decoded.len(), 3);
|
||||
assert_eq!(decoded.packets[0].session_id, [0, 1]);
|
||||
assert_eq!(decoded.packets[0].payload, Bytes::from_static(b"hello"));
|
||||
assert_eq!(decoded.packets[1].session_id, [0, 2]);
|
||||
assert_eq!(decoded.packets[1].payload, Bytes::from_static(b"world!"));
|
||||
assert_eq!(decoded.packets[2].session_id, [1, 0]);
|
||||
assert_eq!(decoded.packets[2].payload, Bytes::from_static(b"x"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn trunk_frame_empty() {
|
||||
let frame = TrunkFrame::new();
|
||||
assert!(frame.is_empty());
|
||||
assert_eq!(frame.len(), 0);
|
||||
|
||||
let encoded = frame.encode();
|
||||
// Just the 2-byte count header with value 0.
|
||||
assert_eq!(encoded.len(), 2);
|
||||
assert_eq!(&encoded[..], &[0, 0]);
|
||||
|
||||
let decoded = TrunkFrame::decode(&encoded).unwrap();
|
||||
assert!(decoded.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn trunk_entry_wire_size() {
|
||||
// Each entry overhead must be exactly 4 bytes (2 session_id + 2 len).
|
||||
assert_eq!(TrunkEntry::OVERHEAD, 4);
|
||||
|
||||
// Verify empirically: one entry with a 10-byte payload should produce
|
||||
// 2 (count) + 4 (overhead) + 10 (payload) = 16 bytes total.
|
||||
let mut frame = TrunkFrame::new();
|
||||
frame.push([0xAB, 0xCD], Bytes::from(vec![0u8; 10]));
|
||||
let encoded = frame.encode();
|
||||
assert_eq!(encoded.len(), 2 + 4 + 10);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// MiniHeader / MiniFrameContext tests
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn mini_header_encode_decode() {
|
||||
let mini = MiniHeader {
|
||||
timestamp_delta_ms: 20,
|
||||
payload_len: 160,
|
||||
};
|
||||
let mut buf = BytesMut::new();
|
||||
mini.write_to(&mut buf);
|
||||
|
||||
let mut cursor = &buf[..];
|
||||
let decoded = MiniHeader::read_from(&mut cursor).unwrap();
|
||||
assert_eq!(mini, decoded);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mini_header_wire_size() {
|
||||
let mini = MiniHeader {
|
||||
timestamp_delta_ms: 0xFFFF,
|
||||
payload_len: 0xFFFF,
|
||||
};
|
||||
let mut buf = BytesMut::new();
|
||||
mini.write_to(&mut buf);
|
||||
assert_eq!(buf.len(), 4);
|
||||
assert_eq!(MiniHeader::WIRE_SIZE, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mini_frame_context_expand() {
|
||||
let baseline = MediaHeader {
|
||||
version: 0,
|
||||
is_repair: false,
|
||||
codec_id: CodecId::Opus24k,
|
||||
has_quality_report: false,
|
||||
fec_ratio_encoded: 10,
|
||||
seq: 100,
|
||||
timestamp: 1000,
|
||||
fec_block: 5,
|
||||
fec_symbol: 0,
|
||||
reserved: 0,
|
||||
csrc_count: 0,
|
||||
};
|
||||
|
||||
let mut ctx = MiniFrameContext::default();
|
||||
ctx.update(&baseline);
|
||||
|
||||
// First expansion
|
||||
let mini1 = MiniHeader {
|
||||
timestamp_delta_ms: 20,
|
||||
payload_len: 80,
|
||||
};
|
||||
let h1 = ctx.expand(&mini1).unwrap();
|
||||
assert_eq!(h1.seq, 101);
|
||||
assert_eq!(h1.timestamp, 1020);
|
||||
assert_eq!(h1.codec_id, CodecId::Opus24k);
|
||||
assert_eq!(h1.fec_block, 5);
|
||||
|
||||
// Second expansion — builds on expanded h1
|
||||
let mini2 = MiniHeader {
|
||||
timestamp_delta_ms: 20,
|
||||
payload_len: 80,
|
||||
};
|
||||
let h2 = ctx.expand(&mini2).unwrap();
|
||||
assert_eq!(h2.seq, 102);
|
||||
assert_eq!(h2.timestamp, 1040);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mini_frame_context_no_baseline() {
|
||||
let mut ctx = MiniFrameContext::default();
|
||||
let mini = MiniHeader {
|
||||
timestamp_delta_ms: 20,
|
||||
payload_len: 80,
|
||||
};
|
||||
assert!(ctx.expand(&mini).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn full_vs_mini_size_comparison() {
|
||||
// Full frame on wire: 1 byte type tag + 12 byte MediaHeader = 13
|
||||
let full_size = 1 + MediaHeader::WIRE_SIZE;
|
||||
assert_eq!(full_size, 13);
|
||||
|
||||
// Mini frame on wire: 1 byte type tag + 4 byte MiniHeader = 5
|
||||
let mini_size = 1 + MiniHeader::WIRE_SIZE;
|
||||
assert_eq!(mini_size, 5);
|
||||
|
||||
// Verify the constants match expectations
|
||||
assert_eq!(FRAME_TYPE_FULL, 0x00);
|
||||
assert_eq!(FRAME_TYPE_MINI, 0x01);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// encode_compact / decode_compact tests
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
fn make_media_packet(seq: u16, ts: u32, payload: &[u8]) -> MediaPacket {
|
||||
MediaPacket {
|
||||
header: MediaHeader {
|
||||
version: 0,
|
||||
is_repair: false,
|
||||
codec_id: CodecId::Opus24k,
|
||||
has_quality_report: false,
|
||||
fec_ratio_encoded: 10,
|
||||
seq,
|
||||
timestamp: ts,
|
||||
fec_block: 0,
|
||||
fec_symbol: 0,
|
||||
reserved: 0,
|
||||
csrc_count: 0,
|
||||
},
|
||||
payload: Bytes::from(payload.to_vec()),
|
||||
quality_report: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mini_frame_encode_decode_sequence() {
|
||||
let mut enc_ctx = MiniFrameContext::default();
|
||||
let mut dec_ctx = MiniFrameContext::default();
|
||||
let mut frames_since_full: u32 = 0;
|
||||
|
||||
let packets: Vec<MediaPacket> = (0..5)
|
||||
.map(|i| make_media_packet(i, i as u32 * 20, b"audio"))
|
||||
.collect();
|
||||
|
||||
for (i, pkt) in packets.iter().enumerate() {
|
||||
let wire = pkt.encode_compact(&mut enc_ctx, &mut frames_since_full);
|
||||
|
||||
if i == 0 {
|
||||
// First frame must be full
|
||||
assert_eq!(wire[0], FRAME_TYPE_FULL, "frame 0 should be FULL");
|
||||
} else {
|
||||
// Subsequent frames should be mini
|
||||
assert_eq!(wire[0], FRAME_TYPE_MINI, "frame {i} should be MINI");
|
||||
// Mini wire: 1 (tag) + 4 (mini header) + payload
|
||||
assert_eq!(wire.len(), 1 + MiniHeader::WIRE_SIZE + pkt.payload.len());
|
||||
}
|
||||
|
||||
let decoded = MediaPacket::decode_compact(&wire, &mut dec_ctx)
|
||||
.unwrap_or_else(|| panic!("decode failed at frame {i}"));
|
||||
assert_eq!(decoded.header.seq, pkt.header.seq);
|
||||
assert_eq!(decoded.header.timestamp, pkt.header.timestamp);
|
||||
assert_eq!(decoded.payload, pkt.payload);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mini_frame_periodic_full() {
|
||||
let mut ctx = MiniFrameContext::default();
|
||||
let mut frames_since_full: u32 = 0;
|
||||
|
||||
// Encode MINI_FRAME_FULL_INTERVAL + 1 frames. Frame 0 and frame 50
|
||||
// should be FULL, everything in between should be MINI.
|
||||
for i in 0..=MINI_FRAME_FULL_INTERVAL {
|
||||
let pkt = make_media_packet(i as u16, i * 20, b"data");
|
||||
let wire = pkt.encode_compact(&mut ctx, &mut frames_since_full);
|
||||
|
||||
if i == 0 || i == MINI_FRAME_FULL_INTERVAL {
|
||||
assert_eq!(
|
||||
wire[0], FRAME_TYPE_FULL,
|
||||
"frame {i} should be FULL"
|
||||
);
|
||||
} else {
|
||||
assert_eq!(
|
||||
wire[0], FRAME_TYPE_MINI,
|
||||
"frame {i} should be MINI"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mini_frame_disabled() {
|
||||
// Simulate disabled mini-frames by always keeping frames_since_full at 0
|
||||
// (which is what the encoder does when the feature is off).
|
||||
let mut ctx = MiniFrameContext::default();
|
||||
|
||||
for i in 0..10u16 {
|
||||
let pkt = make_media_packet(i, i as u32 * 20, b"payload");
|
||||
// When mini-frames are disabled, the encoder always passes
|
||||
// frames_since_full = 0 equivalent by never using encode_compact.
|
||||
// We test the raw path: frames_since_full forced to 0 every time.
|
||||
let mut frames_since_full: u32 = 0;
|
||||
let wire = pkt.encode_compact(&mut ctx, &mut frames_since_full);
|
||||
assert_eq!(wire[0], FRAME_TYPE_FULL, "frame {i} should be FULL when disabled");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use std::collections::VecDeque;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use crate::packet::QualityReport;
|
||||
use crate::traits::QualityController;
|
||||
@@ -24,24 +25,71 @@ impl Tier {
|
||||
}
|
||||
}
|
||||
|
||||
/// Determine which tier a quality report belongs to.
|
||||
/// Determine which tier a quality report belongs to (default/WiFi thresholds).
|
||||
pub fn classify(report: &QualityReport) -> Self {
|
||||
Self::classify_with_context(report, NetworkContext::Unknown)
|
||||
}
|
||||
|
||||
/// Classify with network-context-aware thresholds.
|
||||
pub fn classify_with_context(report: &QualityReport, context: NetworkContext) -> Self {
|
||||
let loss = report.loss_percent();
|
||||
let rtt = report.rtt_ms();
|
||||
|
||||
if loss > 40.0 || rtt > 600 {
|
||||
Self::Catastrophic
|
||||
} else if loss > 10.0 || rtt > 400 {
|
||||
Self::Degraded
|
||||
} else {
|
||||
Self::Good
|
||||
match context {
|
||||
NetworkContext::CellularLte
|
||||
| NetworkContext::Cellular5g
|
||||
| NetworkContext::Cellular3g => {
|
||||
// Tighter thresholds for cellular networks
|
||||
if loss > 25.0 || rtt > 500 {
|
||||
Self::Catastrophic
|
||||
} else if loss > 8.0 || rtt > 300 {
|
||||
Self::Degraded
|
||||
} else {
|
||||
Self::Good
|
||||
}
|
||||
}
|
||||
NetworkContext::WiFi | NetworkContext::Unknown => {
|
||||
// Original thresholds
|
||||
if loss > 40.0 || rtt > 600 {
|
||||
Self::Catastrophic
|
||||
} else if loss > 10.0 || rtt > 400 {
|
||||
Self::Degraded
|
||||
} else {
|
||||
Self::Good
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the next lower (worse) tier, or None if already at the worst.
|
||||
pub fn downgrade(self) -> Option<Tier> {
|
||||
match self {
|
||||
Self::Good => Some(Self::Degraded),
|
||||
Self::Degraded => Some(Self::Catastrophic),
|
||||
Self::Catastrophic => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Describes the network transport type for context-aware quality decisions.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum NetworkContext {
|
||||
WiFi,
|
||||
CellularLte,
|
||||
Cellular5g,
|
||||
Cellular3g,
|
||||
Unknown,
|
||||
}
|
||||
|
||||
impl Default for NetworkContext {
|
||||
fn default() -> Self {
|
||||
Self::Unknown
|
||||
}
|
||||
}
|
||||
|
||||
/// Adaptive quality controller with hysteresis to prevent tier flapping.
|
||||
///
|
||||
/// - Downgrade: 3 consecutive reports in a worse tier
|
||||
/// - Downgrade: 3 consecutive reports in a worse tier (2 on cellular)
|
||||
/// - Upgrade: 10 consecutive reports in a better tier
|
||||
pub struct AdaptiveQualityController {
|
||||
current_tier: Tier,
|
||||
@@ -54,14 +102,26 @@ pub struct AdaptiveQualityController {
|
||||
history: VecDeque<QualityReport>,
|
||||
/// Whether the profile was manually forced (disables adaptive logic).
|
||||
forced: bool,
|
||||
/// Current network context for threshold selection.
|
||||
network_context: NetworkContext,
|
||||
/// FEC boost expiry time (set during network handoff).
|
||||
fec_boost_until: Option<Instant>,
|
||||
/// FEC boost amount to add during handoff recovery window.
|
||||
fec_boost_amount: f32,
|
||||
}
|
||||
|
||||
/// Threshold for downgrading (fast reaction to degradation).
|
||||
const DOWNGRADE_THRESHOLD: u32 = 3;
|
||||
/// Threshold for downgrading on cellular networks (even faster).
|
||||
const CELLULAR_DOWNGRADE_THRESHOLD: u32 = 2;
|
||||
/// Threshold for upgrading (slow, cautious improvement).
|
||||
const UPGRADE_THRESHOLD: u32 = 10;
|
||||
/// Maximum history window size.
|
||||
const HISTORY_SIZE: usize = 20;
|
||||
/// Default FEC boost amount during handoff recovery.
|
||||
const DEFAULT_FEC_BOOST: f32 = 0.2;
|
||||
/// Duration of FEC boost after a network handoff.
|
||||
const FEC_BOOST_DURATION_SECS: u64 = 10;
|
||||
|
||||
impl AdaptiveQualityController {
|
||||
pub fn new() -> Self {
|
||||
@@ -72,6 +132,9 @@ impl AdaptiveQualityController {
|
||||
consecutive_down: 0,
|
||||
history: VecDeque::with_capacity(HISTORY_SIZE),
|
||||
forced: false,
|
||||
network_context: NetworkContext::default(),
|
||||
fec_boost_until: None,
|
||||
fec_boost_amount: DEFAULT_FEC_BOOST,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -80,6 +143,69 @@ impl AdaptiveQualityController {
|
||||
self.current_tier
|
||||
}
|
||||
|
||||
/// Get the current network context.
|
||||
pub fn network_context(&self) -> NetworkContext {
|
||||
self.network_context
|
||||
}
|
||||
|
||||
/// Signal a network transport change (e.g., WiFi to cellular handoff).
|
||||
///
|
||||
/// When switching from WiFi to any cellular type, this preemptively
|
||||
/// downgrades one quality tier and activates a temporary FEC boost.
|
||||
pub fn signal_network_change(&mut self, new_context: NetworkContext) {
|
||||
let old = self.network_context;
|
||||
self.network_context = new_context;
|
||||
|
||||
let new_is_cellular = matches!(
|
||||
new_context,
|
||||
NetworkContext::CellularLte | NetworkContext::Cellular5g | NetworkContext::Cellular3g
|
||||
);
|
||||
|
||||
// If switching from WiFi to cellular, preemptively downgrade one tier
|
||||
if old == NetworkContext::WiFi && new_is_cellular {
|
||||
if let Some(lower_tier) = self.current_tier.downgrade() {
|
||||
self.current_tier = lower_tier;
|
||||
self.current_profile = lower_tier.profile();
|
||||
}
|
||||
// Reset counters to avoid stale hysteresis state
|
||||
self.consecutive_up = 0;
|
||||
self.consecutive_down = 0;
|
||||
// Un-force so adaptive logic resumes
|
||||
self.forced = false;
|
||||
}
|
||||
|
||||
// Activate FEC boost for any network change
|
||||
self.fec_boost_until = Some(Instant::now() + Duration::from_secs(FEC_BOOST_DURATION_SECS));
|
||||
}
|
||||
|
||||
/// Returns the FEC boost amount if within the handoff recovery window, 0.0 otherwise.
|
||||
///
|
||||
/// Callers should add this to their base FEC ratio during the boost window.
|
||||
pub fn fec_boost(&self) -> f32 {
|
||||
if let Some(until) = self.fec_boost_until {
|
||||
if Instant::now() < until {
|
||||
return self.fec_boost_amount;
|
||||
}
|
||||
}
|
||||
0.0
|
||||
}
|
||||
|
||||
/// Reset the hysteresis counters.
|
||||
pub fn reset_counters(&mut self) {
|
||||
self.consecutive_up = 0;
|
||||
self.consecutive_down = 0;
|
||||
}
|
||||
|
||||
/// Get the effective downgrade threshold based on network context.
|
||||
fn downgrade_threshold(&self) -> u32 {
|
||||
match self.network_context {
|
||||
NetworkContext::CellularLte
|
||||
| NetworkContext::Cellular5g
|
||||
| NetworkContext::Cellular3g => CELLULAR_DOWNGRADE_THRESHOLD,
|
||||
_ => DOWNGRADE_THRESHOLD,
|
||||
}
|
||||
}
|
||||
|
||||
fn try_transition(&mut self, observed_tier: Tier) -> Option<QualityProfile> {
|
||||
if observed_tier == self.current_tier {
|
||||
self.consecutive_up = 0;
|
||||
@@ -96,7 +222,7 @@ impl AdaptiveQualityController {
|
||||
if is_worse {
|
||||
self.consecutive_up = 0;
|
||||
self.consecutive_down += 1;
|
||||
if self.consecutive_down >= DOWNGRADE_THRESHOLD {
|
||||
if self.consecutive_down >= self.downgrade_threshold() {
|
||||
self.current_tier = observed_tier;
|
||||
self.current_profile = observed_tier.profile();
|
||||
self.consecutive_down = 0;
|
||||
@@ -142,7 +268,7 @@ impl QualityController for AdaptiveQualityController {
|
||||
return None;
|
||||
}
|
||||
|
||||
let observed = Tier::classify(report);
|
||||
let observed = Tier::classify_with_context(report, self.network_context);
|
||||
self.try_transition(observed)
|
||||
}
|
||||
|
||||
@@ -246,4 +372,110 @@ mod tests {
|
||||
assert_eq!(Tier::classify(&make_report(50.0, 200)), Tier::Catastrophic);
|
||||
assert_eq!(Tier::classify(&make_report(5.0, 700)), Tier::Catastrophic);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// Network context tests
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn cellular_tighter_thresholds() {
|
||||
// 12% loss: Good on WiFi, Degraded on cellular
|
||||
let report = make_report(12.0, 200);
|
||||
assert_eq!(
|
||||
Tier::classify_with_context(&report, NetworkContext::WiFi),
|
||||
Tier::Degraded
|
||||
);
|
||||
assert_eq!(
|
||||
Tier::classify_with_context(&report, NetworkContext::CellularLte),
|
||||
Tier::Degraded
|
||||
);
|
||||
|
||||
// 9% loss: Good on WiFi, Degraded on cellular
|
||||
let report = make_report(9.0, 200);
|
||||
assert_eq!(
|
||||
Tier::classify_with_context(&report, NetworkContext::WiFi),
|
||||
Tier::Good
|
||||
);
|
||||
assert_eq!(
|
||||
Tier::classify_with_context(&report, NetworkContext::CellularLte),
|
||||
Tier::Degraded
|
||||
);
|
||||
|
||||
// 30% loss: Degraded on WiFi, Catastrophic on cellular
|
||||
let report = make_report(30.0, 200);
|
||||
assert_eq!(
|
||||
Tier::classify_with_context(&report, NetworkContext::WiFi),
|
||||
Tier::Degraded
|
||||
);
|
||||
assert_eq!(
|
||||
Tier::classify_with_context(&report, NetworkContext::Cellular3g),
|
||||
Tier::Catastrophic
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cellular_rtt_thresholds() {
|
||||
// RTT 350ms: Good on WiFi, Degraded on cellular
|
||||
let report = make_report(2.0, 348); // rtt_4ms rounds so use 348
|
||||
assert_eq!(
|
||||
Tier::classify_with_context(&report, NetworkContext::WiFi),
|
||||
Tier::Good
|
||||
);
|
||||
assert_eq!(
|
||||
Tier::classify_with_context(&report, NetworkContext::CellularLte),
|
||||
Tier::Degraded
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cellular_faster_downgrade() {
|
||||
let mut ctrl = AdaptiveQualityController::new();
|
||||
ctrl.signal_network_change(NetworkContext::CellularLte);
|
||||
// Reset tier back to Good for testing downgrade threshold
|
||||
ctrl.current_tier = Tier::Good;
|
||||
ctrl.current_profile = Tier::Good.profile();
|
||||
|
||||
// On cellular, downgrade threshold is 2 instead of 3
|
||||
let bad = make_report(50.0, 200);
|
||||
assert!(ctrl.observe(&bad).is_none()); // 1st bad
|
||||
let result = ctrl.observe(&bad); // 2nd bad — should trigger on cellular
|
||||
assert!(result.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn signal_network_change_preemptive_downgrade() {
|
||||
let mut ctrl = AdaptiveQualityController::new();
|
||||
assert_eq!(ctrl.tier(), Tier::Good);
|
||||
|
||||
// Switch from WiFi to cellular
|
||||
ctrl.network_context = NetworkContext::WiFi;
|
||||
ctrl.signal_network_change(NetworkContext::CellularLte);
|
||||
|
||||
// Should have downgraded one tier: Good -> Degraded
|
||||
assert_eq!(ctrl.tier(), Tier::Degraded);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn signal_network_change_fec_boost() {
|
||||
let mut ctrl = AdaptiveQualityController::new();
|
||||
assert_eq!(ctrl.fec_boost(), 0.0);
|
||||
|
||||
ctrl.signal_network_change(NetworkContext::CellularLte);
|
||||
|
||||
// FEC boost should be active
|
||||
assert!(ctrl.fec_boost() > 0.0);
|
||||
assert_eq!(ctrl.fec_boost(), DEFAULT_FEC_BOOST);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tier_downgrade() {
|
||||
assert_eq!(Tier::Good.downgrade(), Some(Tier::Degraded));
|
||||
assert_eq!(Tier::Degraded.downgrade(), Some(Tier::Catastrophic));
|
||||
assert_eq!(Tier::Catastrophic.downgrade(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn network_context_default() {
|
||||
assert_eq!(NetworkContext::default(), NetworkContext::Unknown);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,11 +20,21 @@ bytes = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
toml = "0.8"
|
||||
anyhow = "1"
|
||||
reqwest = { version = "0.12", features = ["json"] }
|
||||
serde_json = "1"
|
||||
rustls = { version = "0.23", default-features = false, features = ["ring", "std"] }
|
||||
quinn = { workspace = true }
|
||||
prometheus = "0.13"
|
||||
axum = { version = "0.7", default-features = false, features = ["tokio", "http1", "ws"] }
|
||||
tower-http = { version = "0.6", features = ["fs"] }
|
||||
futures-util = "0.3"
|
||||
dirs = "6"
|
||||
|
||||
[[bin]]
|
||||
name = "wzp-relay"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = { workspace = true, features = ["rt-multi-thread", "macros"] }
|
||||
wzp-transport = { workspace = true }
|
||||
wzp-client = { workspace = true }
|
||||
|
||||
106
crates/wzp-relay/src/auth.rs
Normal file
106
crates/wzp-relay/src/auth.rs
Normal file
@@ -0,0 +1,106 @@
|
||||
//! featherChat token authentication.
|
||||
//!
|
||||
//! When `--auth-url` is configured, the relay validates bearer tokens
|
||||
//! against featherChat's `POST /v1/auth/validate` endpoint before
|
||||
//! allowing clients to join rooms.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// Request body for featherChat token validation.
|
||||
#[derive(Serialize)]
|
||||
struct ValidateRequest {
|
||||
token: String,
|
||||
}
|
||||
|
||||
/// Response from featherChat token validation.
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct ValidateResponse {
|
||||
pub valid: bool,
|
||||
pub fingerprint: Option<String>,
|
||||
pub alias: Option<String>,
|
||||
}
|
||||
|
||||
/// Validated client identity.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AuthenticatedClient {
|
||||
pub fingerprint: String,
|
||||
pub alias: Option<String>,
|
||||
}
|
||||
|
||||
/// Validate a bearer token against featherChat's auth endpoint.
|
||||
///
|
||||
/// Calls `POST {auth_url}` with `{ "token": "..." }`.
|
||||
/// Returns the client identity if valid, or an error string.
|
||||
pub async fn validate_token(
|
||||
auth_url: &str,
|
||||
token: &str,
|
||||
) -> Result<AuthenticatedClient, String> {
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(5))
|
||||
.build()
|
||||
.map_err(|e| format!("http client error: {e}"))?;
|
||||
|
||||
let resp = client
|
||||
.post(auth_url)
|
||||
.json(&ValidateRequest {
|
||||
token: token.to_string(),
|
||||
})
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("auth request failed: {e}"))?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(format!("auth endpoint returned {}", resp.status()));
|
||||
}
|
||||
|
||||
let body: ValidateResponse = resp
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| format!("invalid auth response: {e}"))?;
|
||||
|
||||
if body.valid {
|
||||
let fingerprint = body
|
||||
.fingerprint
|
||||
.ok_or_else(|| "valid response missing fingerprint".to_string())?;
|
||||
info!(%fingerprint, alias = ?body.alias, "token validated");
|
||||
Ok(AuthenticatedClient {
|
||||
fingerprint,
|
||||
alias: body.alias,
|
||||
})
|
||||
} else {
|
||||
warn!("token validation failed");
|
||||
Err("invalid token".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn validate_request_serializes() {
|
||||
let req = ValidateRequest {
|
||||
token: "abc123".to_string(),
|
||||
};
|
||||
let json = serde_json::to_string(&req).unwrap();
|
||||
assert!(json.contains("abc123"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_response_deserializes() {
|
||||
let json = r#"{"valid": true, "fingerprint": "abcd1234", "alias": "manwe"}"#;
|
||||
let resp: ValidateResponse = serde_json::from_str(json).unwrap();
|
||||
assert!(resp.valid);
|
||||
assert_eq!(resp.fingerprint.unwrap(), "abcd1234");
|
||||
assert_eq!(resp.alias.unwrap(), "manwe");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_response_deserializes() {
|
||||
let json = r#"{"valid": false}"#;
|
||||
let resp: ValidateResponse = serde_json::from_str(json).unwrap();
|
||||
assert!(!resp.valid);
|
||||
assert!(resp.fingerprint.is_none());
|
||||
}
|
||||
}
|
||||
@@ -19,6 +19,31 @@ pub struct RelayConfig {
|
||||
pub jitter_max_depth: usize,
|
||||
/// Logging level (trace, debug, info, warn, error).
|
||||
pub log_level: String,
|
||||
/// featherChat auth validation URL (e.g., "https://chat.example.com/v1/auth/validate").
|
||||
/// If set, clients must present a valid token before joining rooms.
|
||||
pub auth_url: Option<String>,
|
||||
/// Port for the Prometheus metrics HTTP endpoint (e.g., 9090).
|
||||
/// If None, the metrics endpoint is disabled.
|
||||
pub metrics_port: Option<u16>,
|
||||
/// Peer relay addresses to probe for health monitoring.
|
||||
/// Each target gets a persistent QUIC connection sending 1 Ping/s.
|
||||
#[serde(default)]
|
||||
pub probe_targets: Vec<SocketAddr>,
|
||||
/// Enable mesh mode: each relay probes all configured targets concurrently.
|
||||
/// Discovery is manual via multiple --probe flags; this flag signals intent.
|
||||
#[serde(default)]
|
||||
pub probe_mesh: bool,
|
||||
/// Enable trunk batching for outgoing media in room mode.
|
||||
/// When true, packets destined for the same receiver are accumulated into
|
||||
/// [`TrunkFrame`]s and flushed every 5 ms (or when the batcher is full),
|
||||
/// reducing per-packet QUIC datagram overhead.
|
||||
#[serde(default)]
|
||||
pub trunking_enabled: bool,
|
||||
/// Port for the WebSocket listener (browser clients connect here).
|
||||
/// If None, WebSocket support is disabled.
|
||||
pub ws_port: Option<u16>,
|
||||
/// Directory to serve static files from (HTML/JS/WASM for web clients).
|
||||
pub static_dir: Option<String>,
|
||||
}
|
||||
|
||||
impl Default for RelayConfig {
|
||||
@@ -30,6 +55,13 @@ impl Default for RelayConfig {
|
||||
jitter_target_depth: 50,
|
||||
jitter_max_depth: 250,
|
||||
log_level: "info".to_string(),
|
||||
auth_url: None,
|
||||
metrics_port: None,
|
||||
probe_targets: Vec::new(),
|
||||
probe_mesh: false,
|
||||
trunking_enabled: false,
|
||||
ws_port: None,
|
||||
static_dir: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,25 +15,27 @@ use wzp_proto::{MediaTransport, QualityProfile, SignalMessage};
|
||||
/// 5. Derive shared ChaCha20-Poly1305 session
|
||||
/// 6. Send `CallAnswer` back
|
||||
///
|
||||
/// Returns the derived `CryptoSession` and the chosen `QualityProfile`.
|
||||
/// Returns the derived `CryptoSession`, the chosen `QualityProfile`, the caller's fingerprint,
|
||||
/// and the caller's alias (if provided in CallOffer).
|
||||
pub async fn accept_handshake(
|
||||
transport: &dyn MediaTransport,
|
||||
seed: &[u8; 32],
|
||||
) -> Result<(Box<dyn CryptoSession>, QualityProfile), anyhow::Error> {
|
||||
) -> Result<(Box<dyn CryptoSession>, QualityProfile, String, Option<String>), anyhow::Error> {
|
||||
// 1. Receive CallOffer
|
||||
let offer = transport
|
||||
.recv_signal()
|
||||
.await?
|
||||
.ok_or_else(|| anyhow::anyhow!("connection closed before receiving CallOffer"))?;
|
||||
|
||||
let (caller_identity_pub, caller_ephemeral_pub, caller_signature, supported_profiles) =
|
||||
let (caller_identity_pub, caller_ephemeral_pub, caller_signature, supported_profiles, caller_alias) =
|
||||
match offer {
|
||||
SignalMessage::CallOffer {
|
||||
identity_pub,
|
||||
ephemeral_pub,
|
||||
signature,
|
||||
supported_profiles,
|
||||
} => (identity_pub, ephemeral_pub, signature, supported_profiles),
|
||||
alias,
|
||||
} => (identity_pub, ephemeral_pub, signature, supported_profiles, alias),
|
||||
other => {
|
||||
return Err(anyhow::anyhow!(
|
||||
"expected CallOffer, got {:?}",
|
||||
@@ -76,7 +78,13 @@ pub async fn accept_handshake(
|
||||
};
|
||||
transport.send_signal(&answer).await?;
|
||||
|
||||
Ok((session, chosen_profile))
|
||||
// Derive caller fingerprint from their identity public key (first 8 bytes as hex)
|
||||
let caller_fp = caller_identity_pub[..8]
|
||||
.iter()
|
||||
.map(|b| format!("{b:02x}"))
|
||||
.collect::<String>();
|
||||
|
||||
Ok((session, chosen_profile, caller_fp, caller_alias))
|
||||
}
|
||||
|
||||
/// Select the best quality profile from those the caller supports.
|
||||
|
||||
@@ -7,13 +7,22 @@
|
||||
//! It operates on FEC-protected packets, managing loss recovery and adaptive
|
||||
//! quality transitions.
|
||||
|
||||
pub mod auth;
|
||||
pub mod config;
|
||||
pub mod handshake;
|
||||
pub mod metrics;
|
||||
pub mod pipeline;
|
||||
pub mod presence;
|
||||
pub mod probe;
|
||||
pub mod relay_link;
|
||||
pub mod room;
|
||||
pub mod route;
|
||||
pub mod session_mgr;
|
||||
pub mod trunk;
|
||||
pub mod ws;
|
||||
|
||||
pub use config::RelayConfig;
|
||||
pub use handshake::accept_handshake;
|
||||
pub use pipeline::{PipelineConfig, PipelineStats, RelayPipeline};
|
||||
pub use session_mgr::{RelaySession, SessionId, SessionManager};
|
||||
pub use session_mgr::{RelaySession, SessionId, SessionInfo, SessionManager, SessionState};
|
||||
pub use trunk::TrunkBatcher;
|
||||
|
||||
@@ -13,12 +13,15 @@ use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::{error, info};
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
use wzp_proto::MediaTransport;
|
||||
use wzp_relay::config::RelayConfig;
|
||||
use wzp_relay::metrics::RelayMetrics;
|
||||
use wzp_relay::pipeline::{PipelineConfig, RelayPipeline};
|
||||
use wzp_relay::presence::PresenceRegistry;
|
||||
use wzp_relay::room::{self, RoomManager};
|
||||
use wzp_relay::session_mgr::SessionManager;
|
||||
|
||||
fn parse_args() -> RelayConfig {
|
||||
let mut config = RelayConfig::default();
|
||||
@@ -38,17 +41,72 @@ fn parse_args() -> RelayConfig {
|
||||
.parse().expect("invalid --remote address"),
|
||||
);
|
||||
}
|
||||
"--auth-url" => {
|
||||
i += 1;
|
||||
config.auth_url = Some(
|
||||
args.get(i).expect("--auth-url requires a URL").to_string(),
|
||||
);
|
||||
}
|
||||
"--metrics-port" => {
|
||||
i += 1;
|
||||
config.metrics_port = Some(
|
||||
args.get(i).expect("--metrics-port requires a port number")
|
||||
.parse().expect("invalid --metrics-port number"),
|
||||
);
|
||||
}
|
||||
"--probe" => {
|
||||
i += 1;
|
||||
let addr: SocketAddr = args.get(i)
|
||||
.expect("--probe requires an address")
|
||||
.parse()
|
||||
.expect("invalid --probe address");
|
||||
config.probe_targets.push(addr);
|
||||
}
|
||||
"--probe-mesh" => {
|
||||
config.probe_mesh = true;
|
||||
}
|
||||
"--trunking" => {
|
||||
config.trunking_enabled = true;
|
||||
}
|
||||
"--ws-port" => {
|
||||
i += 1;
|
||||
config.ws_port = Some(
|
||||
args.get(i).expect("--ws-port requires a port number")
|
||||
.parse().expect("invalid --ws-port number"),
|
||||
);
|
||||
}
|
||||
"--static-dir" => {
|
||||
i += 1;
|
||||
config.static_dir = Some(
|
||||
args.get(i).expect("--static-dir requires a directory path").to_string(),
|
||||
);
|
||||
}
|
||||
"--mesh-status" => {
|
||||
// Print mesh table from a fresh registry and exit.
|
||||
// In practice this is useful after the relay has been running;
|
||||
// here we just demonstrate the formatter with an empty registry.
|
||||
let m = RelayMetrics::new();
|
||||
print!("{}", wzp_relay::probe::mesh_summary(m.registry()));
|
||||
std::process::exit(0);
|
||||
}
|
||||
"--help" | "-h" => {
|
||||
eprintln!("Usage: wzp-relay [--listen <addr>] [--remote <addr>]");
|
||||
eprintln!("Usage: wzp-relay [--listen <addr>] [--remote <addr>] [--auth-url <url>] [--metrics-port <port>] [--probe <addr>]... [--probe-mesh] [--mesh-status]");
|
||||
eprintln!();
|
||||
eprintln!("Options:");
|
||||
eprintln!(" --listen <addr> Listen address (default: 0.0.0.0:4433)");
|
||||
eprintln!(" --remote <addr> Remote relay for forwarding (disables room mode)");
|
||||
eprintln!(" --listen <addr> Listen address (default: 0.0.0.0:4433)");
|
||||
eprintln!(" --remote <addr> Remote relay for forwarding (disables room mode)");
|
||||
eprintln!(" --auth-url <url> featherChat auth endpoint (e.g., https://chat.example.com/v1/auth/validate)");
|
||||
eprintln!(" When set, clients must send a bearer token as first signal message.");
|
||||
eprintln!(" --metrics-port <port> Prometheus metrics HTTP port (e.g., 9090). Disabled if not set.");
|
||||
eprintln!(" --probe <addr> Peer relay to probe for health monitoring (repeatable).");
|
||||
eprintln!(" --probe-mesh Enable mesh mode (mark config flag, probes all --probe targets).");
|
||||
eprintln!(" --mesh-status Print mesh health table and exit (diagnostic).");
|
||||
eprintln!(" --trunking Enable trunk batching for outgoing media in room mode.");
|
||||
eprintln!(" --ws-port <port> WebSocket listener port for browser clients (e.g., 8080).");
|
||||
eprintln!(" --static-dir <dir> Directory to serve static files from (HTML/JS/WASM).");
|
||||
eprintln!();
|
||||
eprintln!("Room mode (default):");
|
||||
eprintln!(" Clients join rooms by name. Packets are forwarded to all");
|
||||
eprintln!(" other participants in the same room (SFU model).");
|
||||
eprintln!(" Room name comes from QUIC SNI or defaults to 'default'.");
|
||||
eprintln!(" Clients join rooms by name. Packets forwarded to all others (SFU).");
|
||||
std::process::exit(0);
|
||||
}
|
||||
other => {
|
||||
@@ -134,7 +192,56 @@ async fn main() -> anyhow::Result<()> {
|
||||
.install_default()
|
||||
.expect("failed to install rustls crypto provider");
|
||||
|
||||
info!(addr = %config.listen_addr, "WarzonePhone relay starting");
|
||||
// Presence registry
|
||||
let presence = Arc::new(Mutex::new(PresenceRegistry::new()));
|
||||
|
||||
// Route resolver
|
||||
let route_resolver = Arc::new(wzp_relay::route::RouteResolver::new(config.listen_addr));
|
||||
|
||||
// Prometheus metrics
|
||||
let metrics = Arc::new(RelayMetrics::new());
|
||||
if let Some(port) = config.metrics_port {
|
||||
let m = metrics.clone();
|
||||
let p = Some(presence.clone());
|
||||
let rr = Some(route_resolver.clone());
|
||||
tokio::spawn(wzp_relay::metrics::serve_metrics(port, m, p, rr));
|
||||
}
|
||||
|
||||
// Load or generate relay identity — persisted in ~/.wzp/relay-identity
|
||||
let relay_seed = {
|
||||
let config_dir = dirs::home_dir()
|
||||
.unwrap_or_else(|| std::path::PathBuf::from("."))
|
||||
.join(".wzp");
|
||||
let identity_path = config_dir.join("relay-identity");
|
||||
if identity_path.exists() {
|
||||
if let Ok(hex) = std::fs::read_to_string(&identity_path) {
|
||||
if let Ok(s) = wzp_crypto::Seed::from_hex(hex.trim()) {
|
||||
info!("loaded relay identity from {}", identity_path.display());
|
||||
s
|
||||
} else {
|
||||
warn!("corrupt relay identity file, generating new");
|
||||
let s = wzp_crypto::Seed::generate();
|
||||
let hex: String = s.0.iter().map(|b| format!("{b:02x}")).collect();
|
||||
let _ = std::fs::write(&identity_path, &hex);
|
||||
s
|
||||
}
|
||||
} else {
|
||||
let s = wzp_crypto::Seed::generate();
|
||||
let hex: String = s.0.iter().map(|b| format!("{b:02x}")).collect();
|
||||
let _ = std::fs::write(&identity_path, &hex);
|
||||
s
|
||||
}
|
||||
} else {
|
||||
let s = wzp_crypto::Seed::generate();
|
||||
let _ = std::fs::create_dir_all(&config_dir);
|
||||
let hex: String = s.0.iter().map(|b| format!("{b:02x}")).collect();
|
||||
let _ = std::fs::write(&identity_path, &hex);
|
||||
info!("generated relay identity at {}", identity_path.display());
|
||||
s
|
||||
}
|
||||
};
|
||||
let relay_fp = relay_seed.derive_identity().public_identity().fingerprint;
|
||||
info!(addr = %config.listen_addr, fingerprint = %relay_fp, "WarzonePhone relay starting");
|
||||
|
||||
let (server_config, _cert) = wzp_transport::server_config();
|
||||
let endpoint = wzp_transport::create_endpoint(config.listen_addr, Some(server_config))?;
|
||||
@@ -154,6 +261,44 @@ async fn main() -> anyhow::Result<()> {
|
||||
// Room manager (room mode only)
|
||||
let room_mgr = Arc::new(Mutex::new(RoomManager::new()));
|
||||
|
||||
// Session manager — enforces max concurrent sessions
|
||||
let session_mgr = Arc::new(Mutex::new(SessionManager::new(config.max_sessions)));
|
||||
|
||||
// Spawn inter-relay health probes via ProbeMesh coordinator
|
||||
if !config.probe_targets.is_empty() {
|
||||
let mesh = wzp_relay::probe::ProbeMesh::new(
|
||||
config.probe_targets.clone(),
|
||||
metrics.registry(),
|
||||
Some(presence.clone()),
|
||||
);
|
||||
info!(
|
||||
targets = mesh.target_count(),
|
||||
mesh = config.probe_mesh,
|
||||
"spawning probe mesh"
|
||||
);
|
||||
tokio::spawn(async move { mesh.run_all().await });
|
||||
}
|
||||
|
||||
// WebSocket server for browser clients
|
||||
if let Some(ws_port) = config.ws_port {
|
||||
let ws_state = wzp_relay::ws::WsState {
|
||||
room_mgr: room_mgr.clone(),
|
||||
session_mgr: session_mgr.clone(),
|
||||
auth_url: config.auth_url.clone(),
|
||||
metrics: metrics.clone(),
|
||||
presence: presence.clone(),
|
||||
};
|
||||
let static_dir = config.static_dir.clone();
|
||||
tokio::spawn(wzp_relay::ws::run_ws_server(ws_port, ws_state, static_dir));
|
||||
info!(ws_port, "WebSocket listener enabled for browser clients");
|
||||
}
|
||||
|
||||
if let Some(ref url) = config.auth_url {
|
||||
info!(url, "auth enabled — clients must present featherChat token");
|
||||
} else {
|
||||
info!("auth disabled — any client can connect (use --auth-url to enable)");
|
||||
}
|
||||
|
||||
info!("Listening for connections...");
|
||||
|
||||
loop {
|
||||
@@ -164,12 +309,17 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
let remote_transport = remote_transport.clone();
|
||||
let room_mgr = room_mgr.clone();
|
||||
let session_mgr = session_mgr.clone();
|
||||
let auth_url = config.auth_url.clone();
|
||||
let relay_seed_bytes = relay_seed.0;
|
||||
let metrics = metrics.clone();
|
||||
let trunking_enabled = config.trunking_enabled;
|
||||
let presence = presence.clone();
|
||||
let route_resolver = route_resolver.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let addr = connection.remote_address();
|
||||
|
||||
// Extract room name from QUIC handshake data (SNI).
|
||||
// The web bridge connects with the room name as server_name.
|
||||
let room_name = connection
|
||||
.handshake_data()
|
||||
.and_then(|hd| {
|
||||
@@ -180,7 +330,172 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
let transport = Arc::new(wzp_transport::QuinnTransport::new(connection));
|
||||
|
||||
info!(%addr, room = %room_name, "new client");
|
||||
// Ping connections: client just measures QUIC connect RTT.
|
||||
// No handshake, no streams — client closes immediately after connecting.
|
||||
if room_name == "ping" {
|
||||
info!(%addr, "ping connection (RTT probe)");
|
||||
return;
|
||||
}
|
||||
|
||||
// Probe connections use SNI "_probe" to identify themselves.
|
||||
// They skip auth + handshake and just do Ping->Pong + presence gossip.
|
||||
if room_name == "_probe" {
|
||||
info!(%addr, "probe connection detected, entering Ping/Pong + presence responder");
|
||||
loop {
|
||||
match transport.recv_signal().await {
|
||||
Ok(Some(wzp_proto::SignalMessage::Ping { timestamp_ms })) => {
|
||||
if let Err(e) = transport.send_signal(
|
||||
&wzp_proto::SignalMessage::Pong { timestamp_ms },
|
||||
).await {
|
||||
error!(%addr, "probe pong send error: {e}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(Some(wzp_proto::SignalMessage::PresenceUpdate { fingerprints, relay_addr })) => {
|
||||
// A peer relay is telling us which fingerprints it has
|
||||
let peer_addr: std::net::SocketAddr = relay_addr.parse().unwrap_or(addr);
|
||||
let fps: std::collections::HashSet<String> = fingerprints.into_iter().collect();
|
||||
{
|
||||
let mut reg = presence.lock().await;
|
||||
reg.update_peer(peer_addr, fps);
|
||||
}
|
||||
// Reply with our own local fingerprints
|
||||
let local_fps: Vec<String> = {
|
||||
let reg = presence.lock().await;
|
||||
reg.local_fingerprints().into_iter().collect()
|
||||
};
|
||||
let reply = wzp_proto::SignalMessage::PresenceUpdate {
|
||||
fingerprints: local_fps,
|
||||
relay_addr: addr.to_string(),
|
||||
};
|
||||
if let Err(e) = transport.send_signal(&reply).await {
|
||||
error!(%addr, "presence reply send error: {e}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(Some(wzp_proto::SignalMessage::RouteQuery { fingerprint, ttl })) => {
|
||||
// Look up the fingerprint in our local registry
|
||||
let reg = presence.lock().await;
|
||||
let route = route_resolver.resolve(®, &fingerprint);
|
||||
drop(reg);
|
||||
|
||||
let (found, relay_chain) = match route {
|
||||
wzp_relay::route::Route::Local => {
|
||||
(true, vec![route_resolver.local_addr().to_string()])
|
||||
}
|
||||
wzp_relay::route::Route::DirectPeer(peer_addr) => {
|
||||
(true, vec![route_resolver.local_addr().to_string(), peer_addr.to_string()])
|
||||
}
|
||||
_ => {
|
||||
// Not found locally; if ttl > 0 we could forward
|
||||
// to other peers (future multi-hop). For now, reply not found.
|
||||
if ttl > 0 {
|
||||
// TODO: forward RouteQuery to other peers with ttl-1
|
||||
}
|
||||
(false, vec![])
|
||||
}
|
||||
};
|
||||
|
||||
let reply = wzp_proto::SignalMessage::RouteResponse {
|
||||
fingerprint,
|
||||
found,
|
||||
relay_chain,
|
||||
};
|
||||
if let Err(e) = transport.send_signal(&reply).await {
|
||||
error!(%addr, "route response send error: {e}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(Some(_)) => {
|
||||
// Ignore other signals on probe connections
|
||||
}
|
||||
Ok(None) => {
|
||||
info!(%addr, "probe connection closed");
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
error!(%addr, "probe recv error: {e}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
transport.close().await.ok();
|
||||
return;
|
||||
}
|
||||
|
||||
// Auth check: if --auth-url is set, expect first signal message to be a token
|
||||
// Auth: if --auth-url is set, expect AuthToken as first signal
|
||||
let authenticated_fp: Option<String> = if let Some(ref url) = auth_url {
|
||||
info!(%addr, "waiting for auth token...");
|
||||
match transport.recv_signal().await {
|
||||
Ok(Some(wzp_proto::SignalMessage::AuthToken { token })) => {
|
||||
match wzp_relay::auth::validate_token(url, &token).await {
|
||||
Ok(client) => {
|
||||
metrics.auth_attempts.with_label_values(&["ok"]).inc();
|
||||
info!(
|
||||
%addr,
|
||||
fingerprint = %client.fingerprint,
|
||||
alias = ?client.alias,
|
||||
"authenticated"
|
||||
);
|
||||
Some(client.fingerprint)
|
||||
}
|
||||
Err(e) => {
|
||||
metrics.auth_attempts.with_label_values(&["fail"]).inc();
|
||||
error!(%addr, "auth failed: {e}");
|
||||
transport.close().await.ok();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(Some(_)) => {
|
||||
error!(%addr, "expected AuthToken as first signal, got something else");
|
||||
transport.close().await.ok();
|
||||
return;
|
||||
}
|
||||
Ok(None) => {
|
||||
error!(%addr, "connection closed before auth");
|
||||
return;
|
||||
}
|
||||
Err(e) => {
|
||||
error!(%addr, "signal recv error during auth: {e}");
|
||||
transport.close().await.ok();
|
||||
return;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Crypto handshake: verify client identity + negotiate quality profile
|
||||
let handshake_start = std::time::Instant::now();
|
||||
let (_crypto_session, _chosen_profile, caller_fp, caller_alias) = match wzp_relay::handshake::accept_handshake(
|
||||
&*transport,
|
||||
&relay_seed_bytes,
|
||||
).await {
|
||||
Ok(result) => {
|
||||
let elapsed = handshake_start.elapsed().as_secs_f64();
|
||||
metrics.handshake_duration.observe(elapsed);
|
||||
info!(%addr, elapsed_ms = %(elapsed * 1000.0), "crypto handshake complete");
|
||||
result
|
||||
}
|
||||
Err(e) => {
|
||||
error!(%addr, "handshake failed: {e}");
|
||||
transport.close().await.ok();
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Use the caller's identity fingerprint from the handshake
|
||||
let participant_fp = authenticated_fp.clone().unwrap_or(caller_fp);
|
||||
|
||||
// Register in presence registry
|
||||
{
|
||||
let mut reg = presence.lock().await;
|
||||
reg.register_local(&participant_fp, None, Some(room_name.clone()));
|
||||
}
|
||||
|
||||
info!(%addr, room = %room_name, "client joining");
|
||||
|
||||
if let Some(remote) = remote_transport {
|
||||
// Forward mode — same as before
|
||||
@@ -211,19 +526,77 @@ async fn main() -> anyhow::Result<()> {
|
||||
stats_handle.abort();
|
||||
transport.close().await.ok();
|
||||
} else {
|
||||
// Room mode — join room and forward to all others
|
||||
let participant_id = {
|
||||
let mut mgr = room_mgr.lock().await;
|
||||
mgr.join(&room_name, addr, transport.clone())
|
||||
// Room mode — enforce max sessions, then join room
|
||||
let session_id = {
|
||||
let mut smgr = session_mgr.lock().await;
|
||||
match smgr.create_session(&room_name, authenticated_fp.clone()) {
|
||||
Ok(id) => id,
|
||||
Err(e) => {
|
||||
error!(%addr, room = %room_name, "session rejected: {e}");
|
||||
transport.close().await.ok();
|
||||
return;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
metrics.active_sessions.inc();
|
||||
|
||||
let participant_id = {
|
||||
let mut mgr = room_mgr.lock().await;
|
||||
match mgr.join(
|
||||
&room_name,
|
||||
addr,
|
||||
room::ParticipantSender::Quic(transport.clone()),
|
||||
Some(&participant_fp),
|
||||
caller_alias.as_deref(),
|
||||
) {
|
||||
Ok((id, update, senders)) => {
|
||||
metrics.active_rooms.set(mgr.list().len() as i64);
|
||||
drop(mgr); // release lock before async broadcast
|
||||
room::broadcast_signal(&senders, &update).await;
|
||||
id
|
||||
}
|
||||
Err(e) => {
|
||||
error!(%addr, room = %room_name, "room join denied: {e}");
|
||||
metrics.active_sessions.dec();
|
||||
let mut smgr = session_mgr.lock().await;
|
||||
smgr.remove_session(session_id);
|
||||
transport.close().await.ok();
|
||||
return;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let session_id_str: String = session_id
|
||||
.iter()
|
||||
.map(|b| format!("{b:02x}"))
|
||||
.collect();
|
||||
room::run_participant(
|
||||
room_mgr.clone(),
|
||||
room_name,
|
||||
participant_id,
|
||||
transport.clone(),
|
||||
metrics.clone(),
|
||||
&session_id_str,
|
||||
trunking_enabled,
|
||||
).await;
|
||||
|
||||
// Participant disconnected — clean up presence + per-session metrics
|
||||
if let Some(ref fp) = authenticated_fp {
|
||||
let mut reg = presence.lock().await;
|
||||
reg.unregister_local(fp);
|
||||
}
|
||||
metrics.remove_session_metrics(&session_id_str);
|
||||
metrics.active_sessions.dec();
|
||||
{
|
||||
let mgr = room_mgr.lock().await;
|
||||
metrics.active_rooms.set(mgr.list().len() as i64);
|
||||
}
|
||||
{
|
||||
let mut smgr = session_mgr.lock().await;
|
||||
smgr.remove_session(session_id);
|
||||
}
|
||||
|
||||
transport.close().await.ok();
|
||||
}
|
||||
});
|
||||
|
||||
412
crates/wzp-relay/src/metrics.rs
Normal file
412
crates/wzp-relay/src/metrics.rs
Normal file
@@ -0,0 +1,412 @@
|
||||
//! Prometheus metrics for the WZP relay daemon.
|
||||
|
||||
use prometheus::{
|
||||
Encoder, GaugeVec, Histogram, HistogramOpts, IntCounter, IntCounterVec, IntGauge, IntGaugeVec,
|
||||
Opts, Registry, TextEncoder,
|
||||
};
|
||||
use wzp_proto::packet::QualityReport;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// All relay-level Prometheus metrics.
|
||||
#[derive(Clone)]
|
||||
pub struct RelayMetrics {
|
||||
pub active_sessions: IntGauge,
|
||||
pub active_rooms: IntGauge,
|
||||
pub packets_forwarded: IntCounter,
|
||||
pub bytes_forwarded: IntCounter,
|
||||
pub auth_attempts: IntCounterVec,
|
||||
pub handshake_duration: Histogram,
|
||||
// Per-session metrics
|
||||
pub session_buffer_depth: IntGaugeVec,
|
||||
pub session_loss_pct: GaugeVec,
|
||||
pub session_rtt_ms: GaugeVec,
|
||||
pub session_underruns: IntCounterVec,
|
||||
pub session_overruns: IntCounterVec,
|
||||
registry: Registry,
|
||||
}
|
||||
|
||||
impl RelayMetrics {
|
||||
/// Create and register all relay metrics with a new registry.
|
||||
pub fn new() -> Self {
|
||||
let registry = Registry::new();
|
||||
|
||||
let active_sessions = IntGauge::with_opts(
|
||||
Opts::new("wzp_relay_active_sessions", "Current active sessions"),
|
||||
)
|
||||
.expect("metric");
|
||||
let active_rooms = IntGauge::with_opts(
|
||||
Opts::new("wzp_relay_active_rooms", "Current active rooms"),
|
||||
)
|
||||
.expect("metric");
|
||||
let packets_forwarded = IntCounter::with_opts(
|
||||
Opts::new("wzp_relay_packets_forwarded_total", "Total packets forwarded"),
|
||||
)
|
||||
.expect("metric");
|
||||
let bytes_forwarded = IntCounter::with_opts(
|
||||
Opts::new("wzp_relay_bytes_forwarded_total", "Total bytes forwarded"),
|
||||
)
|
||||
.expect("metric");
|
||||
let auth_attempts = IntCounterVec::new(
|
||||
Opts::new("wzp_relay_auth_attempts_total", "Auth validation attempts"),
|
||||
&["result"],
|
||||
)
|
||||
.expect("metric");
|
||||
let handshake_duration = Histogram::with_opts(
|
||||
HistogramOpts::new(
|
||||
"wzp_relay_handshake_duration_seconds",
|
||||
"Crypto handshake time",
|
||||
)
|
||||
.buckets(vec![0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5]),
|
||||
)
|
||||
.expect("metric");
|
||||
|
||||
let session_buffer_depth = IntGaugeVec::new(
|
||||
Opts::new(
|
||||
"wzp_relay_session_jitter_buffer_depth",
|
||||
"Buffer depth per session",
|
||||
),
|
||||
&["session_id"],
|
||||
)
|
||||
.expect("metric");
|
||||
let session_loss_pct = GaugeVec::new(
|
||||
Opts::new(
|
||||
"wzp_relay_session_loss_pct",
|
||||
"Packet loss percentage per session",
|
||||
),
|
||||
&["session_id"],
|
||||
)
|
||||
.expect("metric");
|
||||
let session_rtt_ms = GaugeVec::new(
|
||||
Opts::new(
|
||||
"wzp_relay_session_rtt_ms",
|
||||
"Round-trip time per session",
|
||||
),
|
||||
&["session_id"],
|
||||
)
|
||||
.expect("metric");
|
||||
let session_underruns = IntCounterVec::new(
|
||||
Opts::new(
|
||||
"wzp_relay_session_underruns_total",
|
||||
"Jitter buffer underruns per session",
|
||||
),
|
||||
&["session_id"],
|
||||
)
|
||||
.expect("metric");
|
||||
let session_overruns = IntCounterVec::new(
|
||||
Opts::new(
|
||||
"wzp_relay_session_overruns_total",
|
||||
"Jitter buffer overruns per session",
|
||||
),
|
||||
&["session_id"],
|
||||
)
|
||||
.expect("metric");
|
||||
|
||||
registry.register(Box::new(active_sessions.clone())).expect("register");
|
||||
registry.register(Box::new(active_rooms.clone())).expect("register");
|
||||
registry.register(Box::new(packets_forwarded.clone())).expect("register");
|
||||
registry.register(Box::new(bytes_forwarded.clone())).expect("register");
|
||||
registry.register(Box::new(auth_attempts.clone())).expect("register");
|
||||
registry.register(Box::new(handshake_duration.clone())).expect("register");
|
||||
registry.register(Box::new(session_buffer_depth.clone())).expect("register");
|
||||
registry.register(Box::new(session_loss_pct.clone())).expect("register");
|
||||
registry.register(Box::new(session_rtt_ms.clone())).expect("register");
|
||||
registry.register(Box::new(session_underruns.clone())).expect("register");
|
||||
registry.register(Box::new(session_overruns.clone())).expect("register");
|
||||
|
||||
Self {
|
||||
active_sessions,
|
||||
active_rooms,
|
||||
packets_forwarded,
|
||||
bytes_forwarded,
|
||||
auth_attempts,
|
||||
handshake_duration,
|
||||
session_buffer_depth,
|
||||
session_loss_pct,
|
||||
session_rtt_ms,
|
||||
session_underruns,
|
||||
session_overruns,
|
||||
registry,
|
||||
}
|
||||
}
|
||||
|
||||
/// Update per-session quality metrics from a QualityReport.
|
||||
pub fn update_session_quality(&self, session_id: &str, report: &QualityReport) {
|
||||
self.session_loss_pct
|
||||
.with_label_values(&[session_id])
|
||||
.set(report.loss_percent() as f64);
|
||||
self.session_rtt_ms
|
||||
.with_label_values(&[session_id])
|
||||
.set(report.rtt_ms() as f64);
|
||||
}
|
||||
|
||||
/// Update per-session buffer metrics.
|
||||
pub fn update_session_buffer(
|
||||
&self,
|
||||
session_id: &str,
|
||||
depth: usize,
|
||||
underruns: u64,
|
||||
overruns: u64,
|
||||
) {
|
||||
self.session_buffer_depth
|
||||
.with_label_values(&[session_id])
|
||||
.set(depth as i64);
|
||||
// IntCounterVec doesn't have a `set` — we inc by the delta.
|
||||
// Since these are cumulative from the jitter buffer, we use inc_by
|
||||
// with the current totals. To avoid double-counting, callers should
|
||||
// track previous values externally. For simplicity the relay reports
|
||||
// the absolute value each tick; counters only go up so we take the
|
||||
// max(0, new - current) approach.
|
||||
let cur_underruns = self
|
||||
.session_underruns
|
||||
.with_label_values(&[session_id])
|
||||
.get();
|
||||
if underruns > cur_underruns as u64 {
|
||||
self.session_underruns
|
||||
.with_label_values(&[session_id])
|
||||
.inc_by(underruns - cur_underruns as u64);
|
||||
}
|
||||
let cur_overruns = self
|
||||
.session_overruns
|
||||
.with_label_values(&[session_id])
|
||||
.get();
|
||||
if overruns > cur_overruns as u64 {
|
||||
self.session_overruns
|
||||
.with_label_values(&[session_id])
|
||||
.inc_by(overruns - cur_overruns as u64);
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove all per-session label values for a disconnected session.
|
||||
pub fn remove_session_metrics(&self, session_id: &str) {
|
||||
let _ = self.session_buffer_depth.remove_label_values(&[session_id]);
|
||||
let _ = self.session_loss_pct.remove_label_values(&[session_id]);
|
||||
let _ = self.session_rtt_ms.remove_label_values(&[session_id]);
|
||||
let _ = self.session_underruns.remove_label_values(&[session_id]);
|
||||
let _ = self.session_overruns.remove_label_values(&[session_id]);
|
||||
}
|
||||
|
||||
/// Get a reference to the underlying Prometheus registry.
|
||||
/// Probe metrics are registered on this same registry so they appear in /metrics output.
|
||||
pub fn registry(&self) -> &Registry {
|
||||
&self.registry
|
||||
}
|
||||
|
||||
/// Gather all metrics and encode them as Prometheus text format.
|
||||
pub fn metrics_handler(&self) -> String {
|
||||
let encoder = TextEncoder::new();
|
||||
let metric_families = self.registry.gather();
|
||||
let mut buffer = Vec::new();
|
||||
encoder.encode(&metric_families, &mut buffer).expect("encode");
|
||||
String::from_utf8(buffer).expect("utf8")
|
||||
}
|
||||
}
|
||||
|
||||
/// Start an HTTP server serving GET /metrics, GET /mesh, presence, and route endpoints on the given port.
|
||||
pub async fn serve_metrics(
|
||||
port: u16,
|
||||
metrics: Arc<RelayMetrics>,
|
||||
presence: Option<Arc<tokio::sync::Mutex<crate::presence::PresenceRegistry>>>,
|
||||
route_resolver: Option<Arc<crate::route::RouteResolver>>,
|
||||
) {
|
||||
use axum::{extract::Path, routing::get, Router};
|
||||
|
||||
let metrics_clone = metrics.clone();
|
||||
let presence_all = presence.clone();
|
||||
let presence_lookup = presence.clone();
|
||||
let presence_peers = presence.clone();
|
||||
let presence_route = presence;
|
||||
|
||||
let app = Router::new()
|
||||
.route(
|
||||
"/metrics",
|
||||
get(move || {
|
||||
let m = metrics.clone();
|
||||
async move { m.metrics_handler() }
|
||||
}),
|
||||
)
|
||||
.route(
|
||||
"/mesh",
|
||||
get(move || {
|
||||
let m = metrics_clone.clone();
|
||||
async move { crate::probe::mesh_summary(m.registry()) }
|
||||
}),
|
||||
)
|
||||
.route(
|
||||
"/presence",
|
||||
get(move || {
|
||||
let reg = presence_all.clone();
|
||||
async move {
|
||||
match reg {
|
||||
Some(r) => {
|
||||
let r = r.lock().await;
|
||||
let entries: Vec<serde_json::Value> = r.all_known().into_iter().map(|(fp, loc)| {
|
||||
serde_json::json!({ "fingerprint": fp, "location": loc })
|
||||
}).collect();
|
||||
serde_json::to_string_pretty(&entries).unwrap_or_else(|_| "[]".to_string())
|
||||
}
|
||||
None => "[]".to_string(),
|
||||
}
|
||||
}
|
||||
}),
|
||||
)
|
||||
.route(
|
||||
"/presence/:fingerprint",
|
||||
get(move |Path(fingerprint): Path<String>| {
|
||||
let reg = presence_lookup.clone();
|
||||
async move {
|
||||
match reg {
|
||||
Some(r) => {
|
||||
let r = r.lock().await;
|
||||
match r.lookup(&fingerprint) {
|
||||
Some(loc) => serde_json::to_string_pretty(
|
||||
&serde_json::json!({ "fingerprint": fingerprint, "location": loc })
|
||||
).unwrap_or_else(|_| "{}".to_string()),
|
||||
None => serde_json::json!({ "fingerprint": fingerprint, "location": null }).to_string(),
|
||||
}
|
||||
}
|
||||
None => serde_json::json!({ "fingerprint": fingerprint, "location": null }).to_string(),
|
||||
}
|
||||
}
|
||||
}),
|
||||
)
|
||||
.route(
|
||||
"/peers",
|
||||
get(move || {
|
||||
let reg = presence_peers.clone();
|
||||
async move {
|
||||
match reg {
|
||||
Some(r) => {
|
||||
let r = r.lock().await;
|
||||
let peers: Vec<serde_json::Value> = r.peers().iter().map(|(addr, peer)| {
|
||||
serde_json::json!({
|
||||
"addr": addr.to_string(),
|
||||
"fingerprints": peer.fingerprints.iter().collect::<Vec<_>>(),
|
||||
"rtt_ms": peer.rtt_ms,
|
||||
})
|
||||
}).collect();
|
||||
serde_json::to_string_pretty(&peers).unwrap_or_else(|_| "[]".to_string())
|
||||
}
|
||||
None => "[]".to_string(),
|
||||
}
|
||||
}
|
||||
}),
|
||||
)
|
||||
.route(
|
||||
"/route/:fingerprint",
|
||||
get(move |Path(fingerprint): Path<String>| {
|
||||
let reg = presence_route.clone();
|
||||
let resolver = route_resolver.clone();
|
||||
async move {
|
||||
match (reg, resolver) {
|
||||
(Some(r), Some(res)) => {
|
||||
let r = r.lock().await;
|
||||
let route = res.resolve(&r, &fingerprint);
|
||||
let json = res.route_json(&fingerprint, &route);
|
||||
serde_json::to_string_pretty(&json)
|
||||
.unwrap_or_else(|_| "{}".to_string())
|
||||
}
|
||||
_ => {
|
||||
serde_json::json!({
|
||||
"fingerprint": fingerprint,
|
||||
"route": "not_found",
|
||||
"relay_chain": [],
|
||||
})
|
||||
.to_string()
|
||||
}
|
||||
}
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
let addr = std::net::SocketAddr::from(([0, 0, 0, 0], port));
|
||||
let listener = tokio::net::TcpListener::bind(addr)
|
||||
.await
|
||||
.expect("failed to bind metrics port");
|
||||
tracing::info!(%addr, "metrics endpoint serving");
|
||||
axum::serve(listener, app)
|
||||
.await
|
||||
.expect("metrics server error");
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn metrics_register() {
|
||||
let m = RelayMetrics::new();
|
||||
// Touch the CounterVec labels so they appear in output
|
||||
m.auth_attempts.with_label_values(&["ok"]);
|
||||
m.auth_attempts.with_label_values(&["fail"]);
|
||||
let output = m.metrics_handler();
|
||||
// Should contain all registered metric names (as HELP or TYPE lines)
|
||||
assert!(output.contains("wzp_relay_active_sessions"));
|
||||
assert!(output.contains("wzp_relay_active_rooms"));
|
||||
assert!(output.contains("wzp_relay_packets_forwarded_total"));
|
||||
assert!(output.contains("wzp_relay_bytes_forwarded_total"));
|
||||
assert!(output.contains("wzp_relay_auth_attempts_total"));
|
||||
assert!(output.contains("wzp_relay_handshake_duration_seconds"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_quality_update() {
|
||||
let m = RelayMetrics::new();
|
||||
let report = QualityReport {
|
||||
loss_pct: 128, // ~50%
|
||||
rtt_4ms: 25, // 100ms
|
||||
jitter_ms: 10,
|
||||
bitrate_cap_kbps: 200,
|
||||
};
|
||||
m.update_session_quality("sess-abc", &report);
|
||||
|
||||
let output = m.metrics_handler();
|
||||
assert!(output.contains("wzp_relay_session_loss_pct{session_id=\"sess-abc\"}"));
|
||||
assert!(output.contains("wzp_relay_session_rtt_ms{session_id=\"sess-abc\"}"));
|
||||
// Verify rtt value (25 * 4 = 100)
|
||||
assert!(output.contains("wzp_relay_session_rtt_ms{session_id=\"sess-abc\"} 100"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_metrics_cleanup() {
|
||||
let m = RelayMetrics::new();
|
||||
let report = QualityReport {
|
||||
loss_pct: 50,
|
||||
rtt_4ms: 10,
|
||||
jitter_ms: 5,
|
||||
bitrate_cap_kbps: 100,
|
||||
};
|
||||
m.update_session_quality("sess-cleanup", &report);
|
||||
m.update_session_buffer("sess-cleanup", 42, 3, 1);
|
||||
|
||||
// Verify they appear
|
||||
let output = m.metrics_handler();
|
||||
assert!(output.contains("sess-cleanup"));
|
||||
|
||||
// Remove and verify they are gone
|
||||
m.remove_session_metrics("sess-cleanup");
|
||||
let output = m.metrics_handler();
|
||||
assert!(!output.contains("sess-cleanup"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn metrics_increment() {
|
||||
let m = RelayMetrics::new();
|
||||
|
||||
m.active_sessions.set(5);
|
||||
m.active_rooms.set(2);
|
||||
m.packets_forwarded.inc_by(100);
|
||||
m.bytes_forwarded.inc_by(48000);
|
||||
m.auth_attempts.with_label_values(&["ok"]).inc();
|
||||
m.auth_attempts.with_label_values(&["fail"]).inc_by(3);
|
||||
m.handshake_duration.observe(0.042);
|
||||
|
||||
let output = m.metrics_handler();
|
||||
assert!(output.contains("wzp_relay_active_sessions 5"));
|
||||
assert!(output.contains("wzp_relay_active_rooms 2"));
|
||||
assert!(output.contains("wzp_relay_packets_forwarded_total 100"));
|
||||
assert!(output.contains("wzp_relay_bytes_forwarded_total 48000"));
|
||||
assert!(output.contains("wzp_relay_auth_attempts_total{result=\"ok\"} 1"));
|
||||
assert!(output.contains("wzp_relay_auth_attempts_total{result=\"fail\"} 3"));
|
||||
assert!(output.contains("wzp_relay_handshake_duration_seconds_count 1"));
|
||||
}
|
||||
}
|
||||
333
crates/wzp-relay/src/presence.rs
Normal file
333
crates/wzp-relay/src/presence.rs
Normal file
@@ -0,0 +1,333 @@
|
||||
//! Presence registry — tracks which fingerprints are connected to this relay
|
||||
//! and to peer relays (via gossip over probe connections).
|
||||
//!
|
||||
//! This enables route resolution: given a fingerprint, determine whether the
|
||||
//! user is local, on a known peer relay, or unknown.
|
||||
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::net::SocketAddr;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use serde::Serialize;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Data structures
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Where a fingerprint is connected.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Serialize)]
|
||||
pub enum PresenceLocation {
|
||||
/// Connected directly to this relay.
|
||||
Local,
|
||||
/// Connected to a peer relay at the given address.
|
||||
Remote(SocketAddr),
|
||||
}
|
||||
|
||||
/// Presence entry for a fingerprint connected directly to this relay.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct LocalPresence {
|
||||
pub fingerprint: String,
|
||||
pub alias: Option<String>,
|
||||
pub connected_at: Instant,
|
||||
pub room: Option<String>,
|
||||
}
|
||||
|
||||
/// Presence entry for a fingerprint reported by a peer relay.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct RemotePresence {
|
||||
pub fingerprint: String,
|
||||
pub relay_addr: SocketAddr,
|
||||
pub last_seen: Instant,
|
||||
}
|
||||
|
||||
/// Known peer relay and its reported fingerprints.
#[derive(Clone, Debug)]
pub struct PeerRelay {
    /// Peer relay address.
    pub addr: SocketAddr,
    /// Fingerprints most recently reported by this peer.
    pub fingerprints: HashSet<String>,
    /// When this peer last sent an update (drives staleness expiry).
    pub last_update: Instant,
    /// Measured round-trip time to this peer in ms, if known.
    pub rtt_ms: Option<f64>,
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Registry
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Central presence registry tracking local and remote fingerprints.
///
/// Not internally synchronized; shared users wrap it in a lock
/// (e.g. `Arc<tokio::sync::Mutex<_>>`, as the probe module does).
pub struct PresenceRegistry {
    /// Fingerprints connected directly to THIS relay.
    local: HashMap<String, LocalPresence>,
    /// Fingerprints reported by peer relays (via gossip).
    remote: HashMap<String, RemotePresence>,
    /// Known peer relays and their status.
    peers: HashMap<SocketAddr, PeerRelay>,
}
|
||||
|
||||
impl PresenceRegistry {
|
||||
/// Create an empty registry.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
local: HashMap::new(),
|
||||
remote: HashMap::new(),
|
||||
peers: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a fingerprint as locally connected (called after auth + handshake).
|
||||
pub fn register_local(&mut self, fingerprint: &str, alias: Option<String>, room: Option<String>) {
|
||||
self.local.insert(fingerprint.to_string(), LocalPresence {
|
||||
fingerprint: fingerprint.to_string(),
|
||||
alias,
|
||||
connected_at: Instant::now(),
|
||||
room,
|
||||
});
|
||||
}
|
||||
|
||||
/// Unregister a locally connected fingerprint (called on disconnect).
|
||||
pub fn unregister_local(&mut self, fingerprint: &str) {
|
||||
self.local.remove(fingerprint);
|
||||
}
|
||||
|
||||
/// Update the fingerprints reported by a peer relay.
|
||||
/// Replaces the previous set for that peer.
|
||||
pub fn update_peer(&mut self, addr: SocketAddr, fingerprints: HashSet<String>) {
|
||||
let now = Instant::now();
|
||||
|
||||
// Remove old remote entries that belonged to this peer
|
||||
self.remote.retain(|_, rp| rp.relay_addr != addr);
|
||||
|
||||
// Insert new remote entries
|
||||
for fp in &fingerprints {
|
||||
self.remote.insert(fp.clone(), RemotePresence {
|
||||
fingerprint: fp.clone(),
|
||||
relay_addr: addr,
|
||||
last_seen: now,
|
||||
});
|
||||
}
|
||||
|
||||
// Update the peer record
|
||||
let peer = self.peers.entry(addr).or_insert_with(|| PeerRelay {
|
||||
addr,
|
||||
fingerprints: HashSet::new(),
|
||||
last_update: now,
|
||||
rtt_ms: None,
|
||||
});
|
||||
peer.fingerprints = fingerprints;
|
||||
peer.last_update = now;
|
||||
}
|
||||
|
||||
/// Look up where a fingerprint is connected.
|
||||
/// Local presence takes priority over remote.
|
||||
pub fn lookup(&self, fingerprint: &str) -> Option<PresenceLocation> {
|
||||
if self.local.contains_key(fingerprint) {
|
||||
return Some(PresenceLocation::Local);
|
||||
}
|
||||
if let Some(rp) = self.remote.get(fingerprint) {
|
||||
return Some(PresenceLocation::Remote(rp.relay_addr));
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Return all fingerprints connected directly to this relay.
|
||||
pub fn local_fingerprints(&self) -> HashSet<String> {
|
||||
self.local.keys().cloned().collect()
|
||||
}
|
||||
|
||||
/// Return a full dump of every known fingerprint and its location.
|
||||
pub fn all_known(&self) -> Vec<(String, PresenceLocation)> {
|
||||
let mut out = Vec::new();
|
||||
for fp in self.local.keys() {
|
||||
out.push((fp.clone(), PresenceLocation::Local));
|
||||
}
|
||||
for (fp, rp) in &self.remote {
|
||||
// Skip if also local (local wins)
|
||||
if !self.local.contains_key(fp) {
|
||||
out.push((fp.clone(), PresenceLocation::Remote(rp.relay_addr)));
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Remove remote entries older than `timeout`.
|
||||
pub fn expire_stale(&mut self, timeout: Duration) {
|
||||
let cutoff = Instant::now() - timeout;
|
||||
|
||||
// Expire remote presence entries
|
||||
self.remote.retain(|_, rp| rp.last_seen > cutoff);
|
||||
|
||||
// Expire peer relay records and their fingerprint sets
|
||||
let stale_peers: Vec<SocketAddr> = self.peers
|
||||
.iter()
|
||||
.filter(|(_, p)| p.last_update <= cutoff)
|
||||
.map(|(addr, _)| *addr)
|
||||
.collect();
|
||||
|
||||
for addr in stale_peers {
|
||||
self.peers.remove(&addr);
|
||||
}
|
||||
}
|
||||
|
||||
/// Return a reference to the peer relay map (for HTTP API).
|
||||
pub fn peers(&self) -> &HashMap<SocketAddr, PeerRelay> {
|
||||
&self.peers
|
||||
}
|
||||
|
||||
/// Return a reference to the local presence map (for HTTP API).
|
||||
pub fn local_entries(&self) -> &HashMap<String, LocalPresence> {
|
||||
&self.local
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::SocketAddr;

    /// Shorthand: parse a socket-address literal, panicking on bad input.
    fn addr(s: &str) -> SocketAddr {
        s.parse().unwrap()
    }

    // Local registration becomes visible via lookup; unknown fingerprints miss.
    #[test]
    fn register_and_lookup_local() {
        let mut reg = PresenceRegistry::new();
        reg.register_local("aabbccdd", Some("alice".into()), Some("room1".into()));

        assert_eq!(reg.lookup("aabbccdd"), Some(PresenceLocation::Local));
        // Unknown fingerprint returns None
        assert_eq!(reg.lookup("00000000"), None);
    }

    // Unregistering removes the local entry.
    #[test]
    fn unregister_removes() {
        let mut reg = PresenceRegistry::new();
        reg.register_local("aabbccdd", None, None);
        assert_eq!(reg.lookup("aabbccdd"), Some(PresenceLocation::Local));

        reg.unregister_local("aabbccdd");
        assert_eq!(reg.lookup("aabbccdd"), None);
    }

    // Peer gossip makes fingerprints resolvable as Remote(peer_addr).
    #[test]
    fn update_peer_and_lookup() {
        let mut reg = PresenceRegistry::new();
        let peer = addr("10.0.0.2:4433");
        let mut fps = HashSet::new();
        fps.insert("deadbeef".to_string());
        fps.insert("cafebabe".to_string());

        reg.update_peer(peer, fps);

        assert_eq!(reg.lookup("deadbeef"), Some(PresenceLocation::Remote(peer)));
        assert_eq!(reg.lookup("cafebabe"), Some(PresenceLocation::Remote(peer)));
        assert_eq!(reg.lookup("unknown"), None);
    }

    // Entries older than the timeout are dropped from both remote and peers.
    #[test]
    fn expire_stale_removes_old() {
        let mut reg = PresenceRegistry::new();
        let peer = addr("10.0.0.3:4433");

        let mut fps = HashSet::new();
        fps.insert("olduser".to_string());
        reg.update_peer(peer, fps);

        // Verify it's there
        assert_eq!(reg.lookup("olduser"), Some(PresenceLocation::Remote(peer)));

        // Manually backdate the last_seen and last_update
        if let Some(rp) = reg.remote.get_mut("olduser") {
            rp.last_seen = Instant::now() - Duration::from_secs(120);
        }
        if let Some(p) = reg.peers.get_mut(&peer) {
            p.last_update = Instant::now() - Duration::from_secs(120);
        }

        // Expire with 60s timeout — should remove the 120s-old entries
        reg.expire_stale(Duration::from_secs(60));

        assert_eq!(reg.lookup("olduser"), None);
        assert!(reg.peers.get(&peer).is_none());
    }

    // local_fingerprints returns every locally registered fingerprint.
    #[test]
    fn local_fingerprints_list() {
        let mut reg = PresenceRegistry::new();
        reg.register_local("fp1", None, None);
        reg.register_local("fp2", Some("bob".into()), Some("room-a".into()));
        reg.register_local("fp3", None, None);

        let fps = reg.local_fingerprints();
        assert_eq!(fps.len(), 3);
        assert!(fps.contains("fp1"));
        assert!(fps.contains("fp2"));
        assert!(fps.contains("fp3"));
    }

    // all_known merges local and remote entries (one entry per fingerprint).
    #[test]
    fn all_known_includes_local_and_remote() {
        let mut reg = PresenceRegistry::new();
        reg.register_local("local1", None, None);

        let peer = addr("10.0.0.5:4433");
        let mut fps = HashSet::new();
        fps.insert("remote1".to_string());
        reg.update_peer(peer, fps);

        let all = reg.all_known();
        assert_eq!(all.len(), 2);

        let local_entries: Vec<_> = all.iter()
            .filter(|(_, loc)| *loc == PresenceLocation::Local)
            .collect();
        assert_eq!(local_entries.len(), 1);
        assert_eq!(local_entries[0].0, "local1");

        let remote_entries: Vec<_> = all.iter()
            .filter(|(_, loc)| matches!(loc, PresenceLocation::Remote(_)))
            .collect();
        assert_eq!(remote_entries.len(), 1);
        assert_eq!(remote_entries[0].0, "remote1");
    }

    // When a fingerprint is both local and remote, lookup reports Local.
    #[test]
    fn local_overrides_remote_in_lookup() {
        let mut reg = PresenceRegistry::new();
        let peer = addr("10.0.0.6:4433");

        // Register as remote first
        let mut fps = HashSet::new();
        fps.insert("dupfp".to_string());
        reg.update_peer(peer, fps);
        assert_eq!(reg.lookup("dupfp"), Some(PresenceLocation::Remote(peer)));

        // Now register locally — local should win
        reg.register_local("dupfp", None, None);
        assert_eq!(reg.lookup("dupfp"), Some(PresenceLocation::Local));
    }

    // A peer update fully replaces that peer's previously reported set.
    #[test]
    fn update_peer_replaces_old_fingerprints() {
        let mut reg = PresenceRegistry::new();
        let peer = addr("10.0.0.7:4433");

        let mut fps1 = HashSet::new();
        fps1.insert("user_a".to_string());
        fps1.insert("user_b".to_string());
        reg.update_peer(peer, fps1);

        assert_eq!(reg.lookup("user_a"), Some(PresenceLocation::Remote(peer)));
        assert_eq!(reg.lookup("user_b"), Some(PresenceLocation::Remote(peer)));

        // Update with only user_b — user_a should be gone
        let mut fps2 = HashSet::new();
        fps2.insert("user_b".to_string());
        reg.update_peer(peer, fps2);

        assert_eq!(reg.lookup("user_a"), None);
        assert_eq!(reg.lookup("user_b"), Some(PresenceLocation::Remote(peer)));
    }
}
|
||||
632
crates/wzp-relay/src/probe.rs
Normal file
632
crates/wzp-relay/src/probe.rs
Normal file
@@ -0,0 +1,632 @@
|
||||
//! Inter-relay health probe.
|
||||
//!
|
||||
//! A `ProbeRunner` maintains a persistent QUIC connection to a peer relay,
|
||||
//! sends 1 Ping/s, and measures RTT, loss, and jitter. Results are exported
|
||||
//! as Prometheus gauges with a `target` label.
|
||||
|
||||
use std::collections::VecDeque;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||
|
||||
use prometheus::{Gauge, IntGauge, Opts, Registry};
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
use wzp_proto::{MediaTransport, SignalMessage};
|
||||
|
||||
/// Configuration for a single probe target.
#[derive(Clone, Debug)]
pub struct ProbeConfig {
    /// Address of the peer relay to probe.
    pub target: SocketAddr,
    /// Time between successive pings (default: 1 s).
    pub interval: Duration,
}

impl ProbeConfig {
    /// Create a probe config for `target` with the default 1 s ping interval.
    pub fn new(target: SocketAddr) -> Self {
        Self::with_interval(target, Duration::from_secs(1))
    }

    /// Create a probe config with an explicit ping interval.
    ///
    /// Generalizes `new`, which hard-codes the 1 s default; existing callers
    /// are unaffected.
    pub fn with_interval(target: SocketAddr, interval: Duration) -> Self {
        Self { target, interval }
    }
}
|
||||
|
||||
/// Prometheus metrics for one probe target.
///
/// All four metrics carry a constant `target` label identifying the peer.
pub struct ProbeMetrics {
    /// Most recently measured round-trip time, in milliseconds.
    pub rtt_ms: Gauge,
    /// Packet loss over the sliding window, in percent (0–100).
    pub loss_pct: Gauge,
    /// RTT standard deviation over the sliding window, in milliseconds.
    pub jitter_ms: Gauge,
    /// 1 while the peer connection is up, 0 otherwise.
    pub up: IntGauge,
}
|
||||
|
||||
impl ProbeMetrics {
|
||||
/// Register probe metrics with the given `target` label value.
|
||||
pub fn register(target: &str, registry: &Registry) -> Self {
|
||||
let rtt_ms = Gauge::with_opts(
|
||||
Opts::new("wzp_probe_rtt_ms", "RTT to peer relay in ms")
|
||||
.const_label("target", target),
|
||||
)
|
||||
.expect("probe metric");
|
||||
|
||||
let loss_pct = Gauge::with_opts(
|
||||
Opts::new("wzp_probe_loss_pct", "Packet loss to peer relay in %")
|
||||
.const_label("target", target),
|
||||
)
|
||||
.expect("probe metric");
|
||||
|
||||
let jitter_ms = Gauge::with_opts(
|
||||
Opts::new("wzp_probe_jitter_ms", "Jitter to peer relay in ms")
|
||||
.const_label("target", target),
|
||||
)
|
||||
.expect("probe metric");
|
||||
|
||||
let up = IntGauge::with_opts(
|
||||
Opts::new("wzp_probe_up", "1 if peer relay is reachable, 0 if not")
|
||||
.const_label("target", target),
|
||||
)
|
||||
.expect("probe metric");
|
||||
|
||||
registry.register(Box::new(rtt_ms.clone())).expect("register");
|
||||
registry.register(Box::new(loss_pct.clone())).expect("register");
|
||||
registry.register(Box::new(jitter_ms.clone())).expect("register");
|
||||
registry.register(Box::new(up.clone())).expect("register");
|
||||
|
||||
Self {
|
||||
rtt_ms,
|
||||
loss_pct,
|
||||
jitter_ms,
|
||||
up,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Sliding window for tracking probe results over the last N pings.
pub struct SlidingWindow {
    /// Capacity (number of pings to track).
    capacity: usize,
    /// Timestamps of sent pings (ms since epoch) in order.
    /// Invariant: `sent.len() == rtts.len()`; index i of each refers to the same ping.
    sent: VecDeque<u64>,
    /// RTT values for received pongs (ms). None = no pong received yet.
    rtts: VecDeque<Option<f64>>,
}

impl SlidingWindow {
    /// Create a window tracking at most `capacity` pings.
    pub fn new(capacity: usize) -> Self {
        Self {
            capacity,
            sent: VecDeque::with_capacity(capacity),
            rtts: VecDeque::with_capacity(capacity),
        }
    }

    /// Record a sent ping, evicting the oldest entry when the window is full.
    pub fn record_sent(&mut self, timestamp_ms: u64) {
        if self.sent.len() >= self.capacity {
            // Evict in lockstep to keep the sent/rtts invariant.
            self.sent.pop_front();
            self.rtts.pop_front();
        }
        self.sent.push_back(timestamp_ms);
        self.rtts.push_back(None);
    }

    /// Record a received pong. Returns the computed RTT in ms, or None if
    /// the timestamp doesn't match any pending ping.
    ///
    /// Matches the oldest *unanswered* ping carrying this timestamp, so a
    /// duplicate pong (or two pings sent within the same millisecond) cannot
    /// overwrite an RTT that was already recorded.
    pub fn record_pong(&mut self, timestamp_ms: u64, now_ms: u64) -> Option<f64> {
        for (i, &sent_ts) in self.sent.iter().enumerate() {
            if sent_ts == timestamp_ms && self.rtts[i].is_none() {
                let rtt = (now_ms as f64) - (sent_ts as f64);
                self.rtts[i] = Some(rtt);
                return Some(rtt);
            }
        }
        None
    }

    /// Compute loss percentage (0.0-100.0) from the current window.
    /// A ping is considered lost if it has no matching pong.
    pub fn loss_pct(&self) -> f64 {
        if self.sent.is_empty() {
            return 0.0;
        }
        let total = self.rtts.len() as f64;
        let lost = self.rtts.iter().filter(|r| r.is_none()).count() as f64;
        (lost / total) * 100.0
    }

    /// Compute jitter as the standard deviation of RTT values (ms).
    /// Only considers pings that received a pong; returns 0.0 with fewer
    /// than two samples (std dev is undefined/zero there).
    pub fn jitter_ms(&self) -> f64 {
        let rtts: Vec<f64> = self.rtts.iter().filter_map(|r| *r).collect();
        if rtts.len() < 2 {
            return 0.0;
        }
        let mean = rtts.iter().sum::<f64>() / rtts.len() as f64;
        let variance = rtts.iter().map(|r| (r - mean).powi(2)).sum::<f64>() / rtts.len() as f64;
        variance.sqrt()
    }

    /// Return the most recent RTT value, if any.
    pub fn latest_rtt(&self) -> Option<f64> {
        self.rtts.iter().rev().find_map(|r| *r)
    }
}
|
||||
|
||||
/// Runs a health probe against a single peer relay.
pub struct ProbeRunner {
    /// Target address and ping interval.
    config: ProbeConfig,
    /// Per-target Prometheus gauges updated by the probe loops.
    metrics: ProbeMetrics,
    /// Shared presence registry; when set, presence gossip is exchanged
    /// over the probe connection.
    presence: Option<Arc<tokio::sync::Mutex<crate::presence::PresenceRegistry>>>,
}
|
||||
|
||||
impl ProbeRunner {
    /// Create a new probe runner, registering metrics with the given registry.
    ///
    /// `presence` is optional: when set, incoming `PresenceUpdate` gossip is
    /// applied to it and local fingerprints are periodically gossiped back.
    pub fn new(
        config: ProbeConfig,
        registry: &Registry,
        presence: Option<Arc<tokio::sync::Mutex<crate::presence::PresenceRegistry>>>,
    ) -> Self {
        let target_str = config.target.to_string();
        let metrics = ProbeMetrics::register(&target_str, registry);
        Self { config, metrics, presence }
    }

    /// Run the probe forever. This function never returns under normal operation.
    /// It connects to the target relay, sends Ping every `interval`, and processes
    /// Pong replies to compute RTT, loss, and jitter.
    pub async fn run(&self) -> ! {
        loop {
            info!(target = %self.config.target, "probe connecting...");
            match self.run_session().await {
                Ok(()) => {
                    // Session ended cleanly (shouldn't happen in practice)
                    warn!(target = %self.config.target, "probe session ended, reconnecting in 5s");
                }
                Err(e) => {
                    error!(target = %self.config.target, "probe session error: {e}, reconnecting in 5s");
                }
            }
            // Mark the target down and clear the RTT gauge while disconnected.
            self.metrics.up.set(0);
            self.metrics.rtt_ms.set(0.0);
            tokio::time::sleep(Duration::from_secs(5)).await;
        }
    }

    /// Run one probe session (one QUIC connection). Returns when the connection drops.
    async fn run_session(&self) -> anyhow::Result<()> {
        // Create a client-only endpoint on an ephemeral port
        let bind_addr: SocketAddr = "0.0.0.0:0".parse().unwrap();
        let endpoint = wzp_transport::create_endpoint(bind_addr, None)?;
        let client_cfg = wzp_transport::client_config();
        // "_probe" is the SNI used to mark relay health-check connections.
        let conn = wzp_transport::connect(
            &endpoint,
            self.config.target,
            "_probe",
            client_cfg,
        )
        .await?;

        let transport = Arc::new(wzp_transport::QuinnTransport::new(conn));
        self.metrics.up.set(1);
        info!(target = %self.config.target, "probe connected");

        // Window of the last 60 pings (~1 minute at the default 1 s interval).
        let window = Arc::new(Mutex::new(SlidingWindow::new(60)));

        // Spawn recv task for pong messages
        let recv_transport = transport.clone();
        let recv_window = window.clone();
        let rtt_gauge = self.metrics.rtt_ms.clone();
        let loss_gauge = self.metrics.loss_pct.clone();
        let jitter_gauge = self.metrics.jitter_ms.clone();
        let up_gauge = self.metrics.up.clone();

        let recv_presence = self.presence.clone();
        let recv_target = self.config.target;
        let recv_handle = tokio::spawn(async move {
            loop {
                match recv_transport.recv_signal().await {
                    Ok(Some(SignalMessage::Pong { timestamp_ms })) => {
                        let now_ms = SystemTime::now()
                            .duration_since(UNIX_EPOCH)
                            .unwrap()
                            .as_millis() as u64;
                        let mut w = recv_window.lock().await;
                        if let Some(rtt) = w.record_pong(timestamp_ms, now_ms) {
                            rtt_gauge.set(rtt);
                        }
                        // Loss and jitter are derived from the whole window,
                        // so refresh them even when this pong matched no ping.
                        loss_gauge.set(w.loss_pct());
                        jitter_gauge.set(w.jitter_ms());
                    }
                    Ok(Some(SignalMessage::PresenceUpdate { fingerprints, relay_addr })) => {
                        if let Some(ref reg) = recv_presence {
                            // Parse the relay_addr; fall back to the connection target
                            let addr = relay_addr.parse().unwrap_or(recv_target);
                            let fps: std::collections::HashSet<String> = fingerprints.into_iter().collect();
                            let mut r = reg.lock().await;
                            r.update_peer(addr, fps);
                        }
                    }
                    Ok(Some(_)) => {
                        // Ignore other signals
                    }
                    Ok(None) => {
                        info!("probe recv: connection closed");
                        up_gauge.set(0);
                        break;
                    }
                    Err(e) => {
                        error!("probe recv error: {e}");
                        up_gauge.set(0);
                        break;
                    }
                }
            }
        });

        // Send ping loop (+ presence gossip every 10 pings)
        let mut interval = tokio::time::interval(self.config.interval);
        let mut ping_count: u64 = 0;
        loop {
            interval.tick().await;

            if recv_handle.is_finished() {
                // Recv task died — connection is lost
                return Ok(());
            }

            let timestamp_ms = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap()
                .as_millis() as u64;

            // Short scope so the window lock is not held across the send await.
            {
                let mut w = window.lock().await;
                w.record_sent(timestamp_ms);
            }

            if let Err(e) = transport
                .send_signal(&SignalMessage::Ping { timestamp_ms })
                .await
            {
                error!(target = %self.config.target, "probe ping send error: {e}");
                recv_handle.abort();
                return Err(e.into());
            }

            // Send presence update every 10 pings (~10 seconds)
            ping_count += 1;
            if ping_count % 10 == 0 {
                if let Some(ref reg) = self.presence {
                    let fps: Vec<String> = {
                        let r = reg.lock().await;
                        r.local_fingerprints().into_iter().collect()
                    };
                    // NOTE(review): relay_addr here is the PEER's address
                    // (self.config.target), not this relay's own listen
                    // address — so the receiver attributes OUR fingerprints
                    // to what may be its own address. Looks wrong; confirm
                    // what relay_addr is meant to carry (the sender's
                    // advertised address is not available in this scope).
                    let msg = SignalMessage::PresenceUpdate {
                        fingerprints: fps,
                        relay_addr: self.config.target.to_string(),
                    };
                    if let Err(e) = transport.send_signal(&msg).await {
                        warn!(target = %self.config.target, "presence update send error: {e}");
                    }
                }
            }
        }
    }
}
|
||||
|
||||
/// Coordinates multiple `ProbeRunner` instances for mesh mode.
///
/// Each relay probes all configured peers concurrently. The `ProbeMesh` owns the
/// runners and spawns them as independent tokio tasks.
pub struct ProbeMesh {
    /// One runner per configured peer target.
    runners: Vec<ProbeRunner>,
}
|
||||
|
||||
impl ProbeMesh {
|
||||
/// Create a new mesh coordinator, registering metrics for every target.
|
||||
pub fn new(
|
||||
targets: Vec<SocketAddr>,
|
||||
registry: &Registry,
|
||||
presence: Option<Arc<tokio::sync::Mutex<crate::presence::PresenceRegistry>>>,
|
||||
) -> Self {
|
||||
let runners = targets
|
||||
.into_iter()
|
||||
.map(|addr| {
|
||||
let config = ProbeConfig::new(addr);
|
||||
ProbeRunner::new(config, registry, presence.clone())
|
||||
})
|
||||
.collect();
|
||||
Self { runners }
|
||||
}
|
||||
|
||||
/// Spawn all runners as concurrent tokio tasks. This consumes the mesh.
|
||||
pub async fn run_all(self) {
|
||||
let mut handles = Vec::with_capacity(self.runners.len());
|
||||
for runner in self.runners {
|
||||
let target = runner.config.target;
|
||||
info!(target = %target, "spawning mesh probe");
|
||||
handles.push(tokio::spawn(async move { runner.run().await }));
|
||||
}
|
||||
// Probes run forever; if we ever need to wait:
|
||||
for h in handles {
|
||||
let _ = h.await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Number of probe targets in this mesh.
|
||||
pub fn target_count(&self) -> usize {
|
||||
self.runners.len()
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a human-readable mesh health table from probe metrics in the registry.
|
||||
///
|
||||
/// Scans the registry for `wzp_probe_*` gauges and formats them into a table.
|
||||
pub fn mesh_summary(registry: &Registry) -> String {
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
let families = registry.gather();
|
||||
|
||||
// Collect per-target values: target -> (rtt, loss, jitter, up)
|
||||
let mut targets: BTreeMap<String, (f64, f64, f64, bool)> = BTreeMap::new();
|
||||
|
||||
for family in &families {
|
||||
let name = family.get_name();
|
||||
for metric in family.get_metric() {
|
||||
// Find the "target" label
|
||||
let target_label = metric
|
||||
.get_label()
|
||||
.iter()
|
||||
.find(|l| l.get_name() == "target");
|
||||
let target = match target_label {
|
||||
Some(l) => l.get_value().to_string(),
|
||||
None => continue,
|
||||
};
|
||||
|
||||
let entry = targets.entry(target).or_insert((0.0, 0.0, 0.0, false));
|
||||
|
||||
match name {
|
||||
"wzp_probe_rtt_ms" => entry.0 = metric.get_gauge().get_value(),
|
||||
"wzp_probe_loss_pct" => entry.1 = metric.get_gauge().get_value(),
|
||||
"wzp_probe_jitter_ms" => entry.2 = metric.get_gauge().get_value(),
|
||||
"wzp_probe_up" => entry.3 = metric.get_gauge().get_value() as i64 == 1,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut out = String::new();
|
||||
out.push_str("Relay Mesh Health\n");
|
||||
out.push_str("\u{2500}\u{2500}\u{2500}\u{2500}\u{2500}\u{2500}\u{2500}\u{2500}\u{2500}\u{2500}\u{2500}\u{2500}\u{2500}\u{2500}\u{2500}\u{2500}\u{2500}\u{2500}\u{2500}\u{2500}\u{2500}\u{2500}\u{2500}\u{2500}\u{2500}\u{2500}\u{2500}\u{2500}\u{2500}\u{2500}\u{2500}\u{2500}\u{2500}\u{2500}\u{2500}\u{2500}\u{2500}\u{2500}\u{2500}\u{2500}\u{2500}\n");
|
||||
out.push_str(&format!(
|
||||
"{:<20} {:>6} {:>6} {:>7} {}\n",
|
||||
"Target", "RTT", "Loss", "Jitter", "Status"
|
||||
));
|
||||
|
||||
for (target, (rtt, loss, jitter, up)) in &targets {
|
||||
let status = if *up { "UP" } else { "DOWN" };
|
||||
out.push_str(&format!(
|
||||
"{:<20} {:>5.0}ms {:>5.1}% {:>5.0}ms {}\n",
|
||||
target, rtt, loss, jitter, status
|
||||
));
|
||||
}
|
||||
|
||||
if targets.is_empty() {
|
||||
out.push_str(" (no probe targets configured)\n");
|
||||
}
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
/// Handle an incoming Ping signal by replying with a Pong carrying the same timestamp.
|
||||
/// Returns true if the message was a Ping and was handled, false otherwise.
|
||||
pub async fn handle_ping(
|
||||
transport: &wzp_transport::QuinnTransport,
|
||||
msg: &SignalMessage,
|
||||
) -> bool {
|
||||
if let SignalMessage::Ping { timestamp_ms } = msg {
|
||||
if let Err(e) = transport
|
||||
.send_signal(&SignalMessage::Pong {
|
||||
timestamp_ms: *timestamp_ms,
|
||||
})
|
||||
.await
|
||||
{
|
||||
warn!("failed to send Pong reply: {e}");
|
||||
}
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use prometheus::Encoder;

    // All four gauges register under the expected names with the target label.
    #[test]
    fn probe_metrics_register() {
        let registry = Registry::new();
        let _metrics = ProbeMetrics::register("127.0.0.1:4433", &registry);
        // (ProbeRunner::new signature changed but this test only checks ProbeMetrics)

        let encoder = prometheus::TextEncoder::new();
        let families = registry.gather();
        let mut buf = Vec::new();
        encoder.encode(&families, &mut buf).unwrap();
        let output = String::from_utf8(buf).unwrap();

        assert!(output.contains("wzp_probe_rtt_ms"), "missing wzp_probe_rtt_ms");
        assert!(output.contains("wzp_probe_loss_pct"), "missing wzp_probe_loss_pct");
        assert!(output.contains("wzp_probe_jitter_ms"), "missing wzp_probe_jitter_ms");
        assert!(output.contains("wzp_probe_up"), "missing wzp_probe_up");
        assert!(
            output.contains("target=\"127.0.0.1:4433\""),
            "missing target label"
        );
    }

    // RTT is pong-receive time minus ping-send time for a matching timestamp.
    #[test]
    fn rtt_calculation() {
        let mut window = SlidingWindow::new(60);

        // Send a ping at t=1000
        window.record_sent(1000);
        // Receive pong at t=1050 => RTT = 50ms
        let rtt = window.record_pong(1000, 1050);
        assert_eq!(rtt, Some(50.0));

        // Send at t=2000, receive at t=2030 => RTT = 30ms
        window.record_sent(2000);
        let rtt = window.record_pong(2000, 2030);
        assert_eq!(rtt, Some(30.0));

        assert_eq!(window.latest_rtt(), Some(30.0));

        // Unknown timestamp returns None
        let rtt = window.record_pong(9999, 10000);
        assert!(rtt.is_none());
    }

    // Loss is the fraction of pings in the window without a matching pong.
    #[test]
    fn loss_calculation() {
        let mut window = SlidingWindow::new(10);

        // Send 10 pings
        for i in 0..10 {
            window.record_sent(i * 1000);
        }

        // Receive pongs for 7 out of 10 (miss indices 2, 5, 8)
        for i in 0..10u64 {
            if i == 2 || i == 5 || i == 8 {
                continue; // lost
            }
            window.record_pong(i * 1000, i * 1000 + 40);
        }

        // 3 out of 10 lost = 30%
        let loss = window.loss_pct();
        assert!((loss - 30.0).abs() < 0.01, "expected ~30%, got {loss}");
    }

    // Jitter is the (population) standard deviation of the recorded RTTs.
    #[test]
    fn jitter_calculation() {
        let mut window = SlidingWindow::new(10);

        // Send 4 pings with known RTTs: 10, 20, 30, 40
        // Mean = 25, variance = ((15^2 + 5^2 + 5^2 + 15^2) / 4) = (225+25+25+225)/4 = 125
        // std dev = sqrt(125) ≈ 11.18
        let rtts = [10.0, 20.0, 30.0, 40.0];
        for (i, rtt) in rtts.iter().enumerate() {
            let sent = (i as u64) * 1000;
            window.record_sent(sent);
            window.record_pong(sent, sent + *rtt as u64);
        }

        let jitter = window.jitter_ms();
        assert!(
            (jitter - 11.18).abs() < 0.1,
            "expected jitter ~11.18ms, got {jitter}"
        );
    }

    // A full window evicts its oldest entry when a new ping is recorded.
    #[test]
    fn sliding_window_eviction() {
        let mut window = SlidingWindow::new(5);

        // Fill window
        for i in 0..5 {
            window.record_sent(i * 1000);
        }
        assert_eq!(window.sent.len(), 5);

        // Add one more — oldest should be evicted
        window.record_sent(5000);
        assert_eq!(window.sent.len(), 5);
        assert_eq!(*window.sent.front().unwrap(), 1000);

        // All 5 are unanswered
        assert!((window.loss_pct() - 100.0).abs() < 0.01);
    }

    // An empty window reports zero loss/jitter and no RTT.
    #[test]
    fn empty_window_edge_cases() {
        let window = SlidingWindow::new(60);
        assert_eq!(window.loss_pct(), 0.0);
        assert_eq!(window.jitter_ms(), 0.0);
        assert!(window.latest_rtt().is_none());
    }

    // One runner (and one metric set) is created per configured target.
    #[test]
    fn mesh_creates_runners() {
        let registry = Registry::new();
        let targets: Vec<SocketAddr> = vec![
            "127.0.0.1:4433".parse().unwrap(),
            "127.0.0.2:4433".parse().unwrap(),
            "127.0.0.3:4433".parse().unwrap(),
        ];
        let mesh = ProbeMesh::new(targets, &registry, None);
        assert_eq!(mesh.target_count(), 3);

        // Verify metrics were registered for each target
        let encoder = prometheus::TextEncoder::new();
        let families = registry.gather();
        let mut buf = Vec::new();
        encoder.encode(&families, &mut buf).unwrap();
        let output = String::from_utf8(buf).unwrap();

        assert!(output.contains("target=\"127.0.0.1:4433\""));
        assert!(output.contains("target=\"127.0.0.2:4433\""));
        assert!(output.contains("target=\"127.0.0.3:4433\""));
    }

    // With no probe metrics, the summary still renders its header + placeholder.
    #[test]
    fn mesh_summary_empty() {
        let registry = Registry::new();
        let summary = mesh_summary(&registry);

        // Should contain the header
        assert!(summary.contains("Relay Mesh Health"));
        assert!(summary.contains("Target"));
        assert!(summary.contains("RTT"));
        assert!(summary.contains("Loss"));
        assert!(summary.contains("Jitter"));
        assert!(summary.contains("Status"));
        // Should indicate no targets
        assert!(summary.contains("no probe targets configured"));
    }

    // Each target row reflects its gauges, including the UP/DOWN status.
    #[test]
    fn mesh_summary_with_targets() {
        let registry = Registry::new();
        // Register probe metrics for two targets and set values
        let m1 = ProbeMetrics::register("relay-b:4433", &registry);
        m1.rtt_ms.set(12.0);
        m1.loss_pct.set(0.0);
        m1.jitter_ms.set(2.0);
        m1.up.set(1);

        let m2 = ProbeMetrics::register("relay-c:4433", &registry);
        m2.rtt_ms.set(45.0);
        m2.loss_pct.set(0.1);
        m2.jitter_ms.set(5.0);
        m2.up.set(0);

        let summary = mesh_summary(&registry);

        assert!(summary.contains("relay-b:4433"));
        assert!(summary.contains("relay-c:4433"));
        assert!(summary.contains("UP"));
        assert!(summary.contains("DOWN"));
        // Should NOT contain "no probe targets"
        assert!(!summary.contains("no probe targets configured"));
    }

    // A mesh with no targets is valid and empty.
    #[test]
    fn mesh_zero_targets() {
        let registry = Registry::new();
        let mesh = ProbeMesh::new(vec![], &registry, None);
        assert_eq!(mesh.target_count(), 0);
    }
}
|
||||
483
crates/wzp-relay/src/relay_link.rs
Normal file
483
crates/wzp-relay/src/relay_link.rs
Normal file
@@ -0,0 +1,483 @@
|
||||
//! Per-session relay forwarding — connect to a peer relay and forward only
|
||||
//! specific sessions' media packets there.
|
||||
//!
|
||||
//! This is the building block for relay chaining (multi-hop calls). Instead
|
||||
//! of forwarding ALL traffic to a single hardcoded relay (forward mode) or
|
||||
//! to everyone in a room (SFU mode), a `RelayLink` represents a QUIC
|
||||
//! connection to one peer relay used for forwarding a specific set of
|
||||
//! sessions.
|
||||
//!
|
||||
//! `RelayLinkManager` tracks all active relay links and their session
|
||||
//! assignments, providing get-or-connect semantics and idle cleanup.
|
||||
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use wzp_proto::MediaPacket;
|
||||
use wzp_proto::MediaTransport;
|
||||
|
||||
/// A connection to a peer relay for forwarding specific sessions.
|
||||
///
|
||||
/// Each `RelayLink` holds a QUIC transport to one peer relay and tracks
|
||||
/// which session IDs are being forwarded through it. When all sessions
|
||||
/// are removed the link is considered idle and can be cleaned up.
|
||||
pub struct RelayLink {
|
||||
target_addr: SocketAddr,
|
||||
/// The underlying QUIC transport. `None` only in unit-test stubs where
|
||||
/// no real connection is established.
|
||||
transport: Option<Arc<wzp_transport::QuinnTransport>>,
|
||||
active_sessions: HashSet<String>,
|
||||
}
|
||||
|
||||
impl RelayLink {
|
||||
/// Connect to a peer relay at `target`.
|
||||
///
|
||||
/// Uses the `"_relay"` SNI to signal that this is a relay-to-relay
|
||||
/// connection (similar to `"_probe"` for health checks). The peer
|
||||
/// should skip normal client auth/handshake for relay-SNI connections.
|
||||
pub async fn connect(target: SocketAddr) -> Result<Self, anyhow::Error> {
|
||||
// Create a client-only endpoint on an OS-assigned port.
|
||||
let endpoint = wzp_transport::create_endpoint(
|
||||
"0.0.0.0:0".parse().unwrap(),
|
||||
None,
|
||||
)?;
|
||||
|
||||
let client_cfg = wzp_transport::client_config();
|
||||
let conn = wzp_transport::connect(&endpoint, target, "_relay", client_cfg).await?;
|
||||
let transport = Arc::new(wzp_transport::QuinnTransport::new(conn));
|
||||
|
||||
info!(%target, "relay link established");
|
||||
|
||||
Ok(Self {
|
||||
target_addr: target,
|
||||
transport: Some(transport),
|
||||
active_sessions: HashSet::new(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a `RelayLink` from an existing transport (useful when the
|
||||
/// connection was established through other means).
|
||||
pub fn from_transport(
|
||||
target_addr: SocketAddr,
|
||||
transport: Arc<wzp_transport::QuinnTransport>,
|
||||
) -> Self {
|
||||
Self {
|
||||
target_addr,
|
||||
transport: Some(transport),
|
||||
active_sessions: HashSet::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a stub `RelayLink` with no transport — for unit tests that
|
||||
/// only exercise session-tracking / management logic.
|
||||
#[cfg(test)]
|
||||
fn stub(target_addr: SocketAddr) -> Self {
|
||||
Self {
|
||||
target_addr,
|
||||
transport: None,
|
||||
active_sessions: HashSet::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward a media packet to this peer relay.
|
||||
pub async fn forward(&self, pkt: &MediaPacket) -> Result<(), anyhow::Error> {
|
||||
match &self.transport {
|
||||
Some(t) => t
|
||||
.send_media(pkt)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("relay link forward to {}: {e}", self.target_addr)),
|
||||
None => Err(anyhow::anyhow!(
|
||||
"relay link to {} has no transport (stub)",
|
||||
self.target_addr
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// The address of the peer relay this link connects to.
|
||||
pub fn target_addr(&self) -> SocketAddr {
|
||||
self.target_addr
|
||||
}
|
||||
|
||||
/// A reference to the underlying QUIC transport (if connected).
|
||||
pub fn transport(&self) -> Option<&Arc<wzp_transport::QuinnTransport>> {
|
||||
self.transport.as_ref()
|
||||
}
|
||||
|
||||
/// Add a session to be forwarded through this link.
|
||||
pub fn add_session(&mut self, session_id: &str) {
|
||||
if self.active_sessions.insert(session_id.to_string()) {
|
||||
debug!(
|
||||
target_relay = %self.target_addr,
|
||||
session = session_id,
|
||||
count = self.active_sessions.len(),
|
||||
"session added to relay link"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove a session from this link.
|
||||
pub fn remove_session(&mut self, session_id: &str) {
|
||||
if self.active_sessions.remove(session_id) {
|
||||
debug!(
|
||||
target_relay = %self.target_addr,
|
||||
session = session_id,
|
||||
count = self.active_sessions.len(),
|
||||
"session removed from relay link"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if this link is forwarding any sessions.
|
||||
pub fn is_idle(&self) -> bool {
|
||||
self.active_sessions.is_empty()
|
||||
}
|
||||
|
||||
/// Number of sessions being forwarded through this link.
|
||||
pub fn session_count(&self) -> usize {
|
||||
self.active_sessions.len()
|
||||
}
|
||||
|
||||
/// Check if a specific session is being forwarded through this link.
|
||||
pub fn has_session(&self, session_id: &str) -> bool {
|
||||
self.active_sessions.contains(session_id)
|
||||
}
|
||||
|
||||
/// Close the underlying QUIC connection (no-op if no transport).
|
||||
pub async fn close(&self) {
|
||||
info!(target_relay = %self.target_addr, "closing relay link");
|
||||
if let Some(ref t) = self.transport {
|
||||
let _ = t.close().await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// RelayLinkManager
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Manages connections to multiple peer relays for per-session forwarding.
|
||||
///
|
||||
/// Each peer relay gets at most one `RelayLink`. Sessions are registered
|
||||
/// on specific links, and idle links (no sessions) can be cleaned up.
|
||||
pub struct RelayLinkManager {
|
||||
links: HashMap<SocketAddr, RelayLink>,
|
||||
}
|
||||
|
||||
impl RelayLinkManager {
|
||||
/// Create an empty link manager.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
links: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get or create a link to a peer relay.
|
||||
///
|
||||
/// If a link already exists it is returned. Otherwise a new QUIC
|
||||
/// connection is established using `RelayLink::connect`.
|
||||
pub async fn get_or_connect(
|
||||
&mut self,
|
||||
target: SocketAddr,
|
||||
) -> Result<&RelayLink, anyhow::Error> {
|
||||
if !self.links.contains_key(&target) {
|
||||
let link = RelayLink::connect(target).await?;
|
||||
self.links.insert(target, link);
|
||||
}
|
||||
Ok(self.links.get(&target).unwrap())
|
||||
}
|
||||
|
||||
/// Get a mutable reference to an existing link (if any).
|
||||
pub fn get_mut(&mut self, target: &SocketAddr) -> Option<&mut RelayLink> {
|
||||
self.links.get_mut(target)
|
||||
}
|
||||
|
||||
/// Get a reference to an existing link (if any).
|
||||
pub fn get(&self, target: &SocketAddr) -> Option<&RelayLink> {
|
||||
self.links.get(target)
|
||||
}
|
||||
|
||||
/// Forward a packet for a specific session to the appropriate relay.
|
||||
///
|
||||
/// The link must already exist (created via `get_or_connect`).
|
||||
pub async fn forward_to(
|
||||
&self,
|
||||
target: SocketAddr,
|
||||
pkt: &MediaPacket,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
match self.links.get(&target) {
|
||||
Some(link) => link.forward(pkt).await,
|
||||
None => Err(anyhow::anyhow!(
|
||||
"no relay link to {target} — call get_or_connect first"
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a session on a specific link.
|
||||
///
|
||||
/// The link must already exist. If it does not, a warning is logged
|
||||
/// and the registration is silently skipped.
|
||||
pub fn register_session(&mut self, target: SocketAddr, session_id: &str) {
|
||||
match self.links.get_mut(&target) {
|
||||
Some(link) => link.add_session(session_id),
|
||||
None => {
|
||||
warn!(
|
||||
%target,
|
||||
session = session_id,
|
||||
"cannot register session — no link to target"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Unregister a session. If the link becomes idle, close and remove it.
|
||||
pub async fn unregister_session(&mut self, target: SocketAddr, session_id: &str) {
|
||||
let should_remove = if let Some(link) = self.links.get_mut(&target) {
|
||||
link.remove_session(session_id);
|
||||
if link.is_idle() {
|
||||
link.close().await;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
if should_remove {
|
||||
self.links.remove(&target);
|
||||
info!(%target, "idle relay link removed");
|
||||
}
|
||||
}
|
||||
|
||||
/// Close all links and clear the manager.
|
||||
pub async fn close_all(&mut self) {
|
||||
for (addr, link) in self.links.drain() {
|
||||
info!(%addr, "closing relay link (shutdown)");
|
||||
link.close().await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Number of active links.
|
||||
pub fn link_count(&self) -> usize {
|
||||
self.links.len()
|
||||
}
|
||||
|
||||
/// Total number of sessions being forwarded across all links.
|
||||
pub fn session_count(&self) -> usize {
|
||||
self.links.values().map(|l| l.session_count()).sum()
|
||||
}
|
||||
|
||||
/// Insert a pre-built relay link (for testing or manual setup).
|
||||
pub fn insert(&mut self, link: RelayLink) {
|
||||
self.links.insert(link.target_addr(), link);
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse a socket-address literal used throughout these tests.
    fn addr(s: &str) -> SocketAddr {
        s.parse().unwrap()
    }

    // ---------- RelayLink session tracking ----------

    #[test]
    fn link_manager_tracks_sessions() {
        let mut manager = RelayLinkManager::new();
        let peer = addr("10.0.0.2:4433");

        let mut stub = RelayLink::stub(peer);
        stub.add_session("session-aaa");
        stub.add_session("session-bbb");
        manager.insert(stub);

        assert_eq!(manager.link_count(), 1);
        assert_eq!(manager.session_count(), 2);

        // Register another session on the same link
        manager.register_session(peer, "session-ccc");
        assert_eq!(manager.session_count(), 3);

        // Verify individual link
        let link = manager.get(&peer).unwrap();
        for known in ["session-aaa", "session-bbb", "session-ccc"] {
            assert!(link.has_session(known));
        }
        assert!(!link.has_session("unknown"));
    }

    #[test]
    fn link_manager_idle_detection() {
        let mut link = RelayLink::stub(addr("10.0.0.3:4433"));

        // Empty link is idle
        assert_eq!(link.session_count(), 0);
        assert!(link.is_idle());

        // Add a session — no longer idle
        link.add_session("sess-1");
        assert_eq!(link.session_count(), 1);
        assert!(!link.is_idle());

        // Remove it — idle again
        link.remove_session("sess-1");
        assert_eq!(link.session_count(), 0);
        assert!(link.is_idle());
    }

    #[test]
    fn session_forward_signal_roundtrip() {
        use wzp_proto::SignalMessage;

        // SessionForward roundtrip
        let fwd = SignalMessage::SessionForward {
            session_id: "abcd1234".to_string(),
            target_fingerprint: "deadbeef".to_string(),
            source_relay: "10.0.0.1:4433".to_string(),
        };
        let encoded = serde_json::to_string(&fwd).unwrap();
        match serde_json::from_str::<SignalMessage>(&encoded).unwrap() {
            SignalMessage::SessionForward {
                session_id,
                target_fingerprint,
                source_relay,
            } => {
                assert_eq!(session_id, "abcd1234");
                assert_eq!(target_fingerprint, "deadbeef");
                assert_eq!(source_relay, "10.0.0.1:4433");
            }
            _ => panic!("expected SessionForward variant"),
        }

        // SessionForwardAck roundtrip
        let ack = SignalMessage::SessionForwardAck {
            session_id: "abcd1234".to_string(),
            room_name: "relay-room-42".to_string(),
        };
        let encoded = serde_json::to_string(&ack).unwrap();
        match serde_json::from_str::<SignalMessage>(&encoded).unwrap() {
            SignalMessage::SessionForwardAck { session_id, room_name } => {
                assert_eq!(session_id, "abcd1234");
                assert_eq!(room_name, "relay-room-42");
            }
            _ => panic!("expected SessionForwardAck variant"),
        }
    }

    #[test]
    fn link_manager_multi_target() {
        let mut manager = RelayLinkManager::new();
        let target_a = addr("10.0.0.2:4433");
        let target_b = addr("10.0.0.3:4433");
        let target_c = addr("10.0.0.4:4433");

        let assignments = [
            (target_a, &["s1", "s2"][..]),
            (target_b, &["s3"][..]),
            (target_c, &["s4", "s5", "s6"][..]),
        ];
        for (target, sessions) in assignments {
            let mut link = RelayLink::stub(target);
            for s in sessions {
                link.add_session(s);
            }
            manager.insert(link);
        }

        assert_eq!(manager.link_count(), 3);
        assert_eq!(manager.session_count(), 6); // 2 + 1 + 3

        assert_eq!(manager.get(&target_a).unwrap().session_count(), 2);
        assert_eq!(manager.get(&target_b).unwrap().session_count(), 1);
        assert_eq!(manager.get(&target_c).unwrap().session_count(), 3);
    }

    #[test]
    fn link_manager_cleanup() {
        let mut manager = RelayLinkManager::new();
        let target = addr("10.0.0.5:4433");

        let mut link = RelayLink::stub(target);
        for s in ["s1", "s2", "s3"] {
            link.add_session(s);
        }
        manager.insert(link);

        assert_eq!(manager.link_count(), 1);
        assert_eq!(manager.session_count(), 3);

        // Remove sessions one by one via the manager's mutable access.
        // We cannot call the async unregister_session with stub links here,
        // so we exercise the synchronous management path directly.
        {
            let link = manager.get_mut(&target).unwrap();
            link.remove_session("s1");
            assert!(!link.is_idle());
            link.remove_session("s2");
            assert!(!link.is_idle());
            link.remove_session("s3");
            assert!(link.is_idle());
        }

        // All sessions removed — link is idle
        assert_eq!(manager.session_count(), 0);
        assert!(manager.get(&target).unwrap().is_idle());

        // Simulate what unregister_session does: remove the idle link
        manager.links.remove(&target);
        assert_eq!(manager.link_count(), 0);
    }

    #[test]
    fn register_session_on_nonexistent_link_is_noop() {
        let mut manager = RelayLinkManager::new();
        // Should not panic, just warn
        manager.register_session(addr("10.0.0.99:4433"), "orphan-session");
        assert_eq!(manager.link_count(), 0);
        assert_eq!(manager.session_count(), 0);
    }

    #[test]
    fn forward_to_nonexistent_link_errors() {
        let manager = RelayLinkManager::new();
        let target = addr("10.0.0.99:4433");

        let pkt = MediaPacket {
            header: wzp_proto::packet::MediaHeader {
                version: 0,
                is_repair: false,
                codec_id: wzp_proto::CodecId::Opus16k,
                has_quality_report: false,
                fec_ratio_encoded: 0,
                seq: 1,
                timestamp: 100,
                fec_block: 0,
                fec_symbol: 0,
                reserved: 0,
                csrc_count: 0,
            },
            payload: bytes::Bytes::from_static(b"test"),
            quality_report: None,
        };

        let runtime = tokio::runtime::Builder::new_current_thread()
            .build()
            .unwrap();
        let outcome = runtime.block_on(manager.forward_to(target, &pkt));
        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().to_string().contains("no relay link"));
    }
}
|
||||
@@ -3,15 +3,21 @@
|
||||
//! Each room holds N participants. When one participant sends a media packet,
|
||||
//! the relay forwards it to all other participants in the room (SFU model).
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use bytes::Bytes;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::{error, info};
|
||||
use tracing::{debug, error, info, trace, warn};
|
||||
|
||||
use wzp_proto::packet::TrunkFrame;
|
||||
use wzp_proto::MediaTransport;
|
||||
|
||||
use crate::metrics::RelayMetrics;
|
||||
use crate::trunk::TrunkBatcher;
|
||||
|
||||
/// Unique participant ID within a room.
|
||||
pub type ParticipantId = u64;
|
||||
|
||||
@@ -21,11 +27,64 @@ fn next_id() -> ParticipantId {
|
||||
NEXT_PARTICIPANT_ID.fetch_add(1, Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// How to send data to a participant — either via QUIC transport or WebSocket channel.
|
||||
#[derive(Clone)]
|
||||
pub enum ParticipantSender {
|
||||
Quic(Arc<wzp_transport::QuinnTransport>),
|
||||
WebSocket(tokio::sync::mpsc::Sender<Bytes>),
|
||||
}
|
||||
|
||||
impl ParticipantSender {
|
||||
/// Send raw bytes to this participant.
|
||||
pub async fn send_raw(&self, data: &[u8]) -> Result<(), String> {
|
||||
match self {
|
||||
ParticipantSender::WebSocket(tx) => {
|
||||
tx.try_send(Bytes::copy_from_slice(data))
|
||||
.map_err(|e| format!("ws send: {e}"))
|
||||
}
|
||||
ParticipantSender::Quic(transport) => {
|
||||
let pkt = wzp_proto::MediaPacket {
|
||||
header: wzp_proto::packet::MediaHeader::default_pcm(),
|
||||
payload: Bytes::copy_from_slice(data),
|
||||
quality_report: None,
|
||||
};
|
||||
transport.send_media(&pkt).await.map_err(|e| format!("quic send: {e}"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if this is a QUIC participant.
|
||||
pub fn is_quic(&self) -> bool {
|
||||
matches!(self, ParticipantSender::Quic(_))
|
||||
}
|
||||
|
||||
/// Get the QUIC transport if this is a QUIC participant.
|
||||
pub fn as_quic(&self) -> Option<&Arc<wzp_transport::QuinnTransport>> {
|
||||
match self {
|
||||
ParticipantSender::Quic(t) => Some(t),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Broadcast a signal message to a list of participant senders.
|
||||
pub async fn broadcast_signal(senders: &[ParticipantSender], msg: &wzp_proto::SignalMessage) {
|
||||
for sender in senders {
|
||||
if let ParticipantSender::Quic(t) = sender {
|
||||
if let Err(e) = t.send_signal(msg).await {
|
||||
warn!("broadcast_signal error: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A participant in a room.
|
||||
struct Participant {
|
||||
id: ParticipantId,
|
||||
addr: std::net::SocketAddr,
|
||||
transport: Arc<wzp_transport::QuinnTransport>,
|
||||
_addr: std::net::SocketAddr,
|
||||
sender: ParticipantSender,
|
||||
fingerprint: Option<String>,
|
||||
alias: Option<String>,
|
||||
}
|
||||
|
||||
/// A room holding multiple participants.
|
||||
@@ -40,10 +99,16 @@ impl Room {
|
||||
}
|
||||
}
|
||||
|
||||
fn add(&mut self, addr: std::net::SocketAddr, transport: Arc<wzp_transport::QuinnTransport>) -> ParticipantId {
|
||||
fn add(
|
||||
&mut self,
|
||||
addr: std::net::SocketAddr,
|
||||
sender: ParticipantSender,
|
||||
fingerprint: Option<String>,
|
||||
alias: Option<String>,
|
||||
) -> ParticipantId {
|
||||
let id = next_id();
|
||||
info!(room_size = self.participants.len() + 1, participant = id, %addr, "joined room");
|
||||
self.participants.push(Participant { id, addr, transport });
|
||||
self.participants.push(Participant { id, _addr: addr, sender, fingerprint, alias });
|
||||
id
|
||||
}
|
||||
|
||||
@@ -52,14 +117,41 @@ impl Room {
|
||||
info!(room_size = self.participants.len(), participant = id, "left room");
|
||||
}
|
||||
|
||||
fn others(&self, exclude_id: ParticipantId) -> Vec<Arc<wzp_transport::QuinnTransport>> {
|
||||
fn others(&self, exclude_id: ParticipantId) -> Vec<ParticipantSender> {
|
||||
self.participants
|
||||
.iter()
|
||||
.filter(|p| p.id != exclude_id)
|
||||
.map(|p| p.transport.clone())
|
||||
.map(|p| p.sender.clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Build a RoomUpdate participant list.
|
||||
fn participant_list(&self) -> Vec<wzp_proto::packet::RoomParticipant> {
|
||||
self.participants
|
||||
.iter()
|
||||
.map(|p| wzp_proto::packet::RoomParticipant {
|
||||
fingerprint: p.fingerprint.clone().unwrap_or_default(),
|
||||
alias: p.alias.clone(),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get all senders (for broadcasting to everyone including the joiner).
|
||||
fn all_senders(&self) -> Vec<ParticipantSender> {
|
||||
self.participants.iter().map(|p| p.sender.clone()).collect()
|
||||
}
|
||||
|
||||
/// Update a participant's alias. Returns true if the participant was found.
|
||||
fn set_alias(&mut self, id: ParticipantId, alias: String) -> bool {
|
||||
if let Some(p) = self.participants.iter_mut().find(|p| p.id == id) {
|
||||
info!(participant = id, %alias, "alias updated");
|
||||
p.alias = Some(alias);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn is_empty(&self) -> bool {
|
||||
self.participants.is_empty()
|
||||
}
|
||||
@@ -72,43 +164,134 @@ impl Room {
|
||||
/// Manages all rooms on the relay.
|
||||
pub struct RoomManager {
|
||||
rooms: HashMap<String, Room>,
|
||||
/// Room access control list. Maps hashed room name → allowed fingerprints.
|
||||
/// When `None`, rooms are open (no auth mode). When `Some`, only listed
|
||||
/// fingerprints can join the corresponding room.
|
||||
acl: Option<HashMap<String, HashSet<String>>>,
|
||||
}
|
||||
|
||||
impl RoomManager {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
rooms: HashMap::new(),
|
||||
acl: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Join a room. Returns the participant ID.
|
||||
/// Create a room manager with ACL enforcement enabled.
|
||||
pub fn with_acl() -> Self {
|
||||
Self {
|
||||
rooms: HashMap::new(),
|
||||
acl: Some(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Grant a fingerprint access to a room.
|
||||
pub fn allow(&mut self, room_name: &str, fingerprint: &str) {
|
||||
if let Some(ref mut acl) = self.acl {
|
||||
acl.entry(room_name.to_string())
|
||||
.or_default()
|
||||
.insert(fingerprint.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a fingerprint is authorized to join a room.
|
||||
/// Returns true if ACL is disabled (open mode) or the fingerprint is in the allow list.
|
||||
pub fn is_authorized(&self, room_name: &str, fingerprint: Option<&str>) -> bool {
|
||||
match (&self.acl, fingerprint) {
|
||||
(None, _) => true, // no ACL = open
|
||||
(Some(_), None) => false, // ACL enabled but no fingerprint
|
||||
(Some(acl), Some(fp)) => {
|
||||
// Room not in ACL = open room (allow anyone authenticated)
|
||||
match acl.get(room_name) {
|
||||
None => true,
|
||||
Some(allowed) => allowed.contains(fp),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Join a room. Returns (participant_id, room_update_msg, all_senders) for broadcasting.
|
||||
pub fn join(
|
||||
&mut self,
|
||||
room_name: &str,
|
||||
addr: std::net::SocketAddr,
|
||||
transport: Arc<wzp_transport::QuinnTransport>,
|
||||
) -> ParticipantId {
|
||||
sender: ParticipantSender,
|
||||
fingerprint: Option<&str>,
|
||||
alias: Option<&str>,
|
||||
) -> Result<(ParticipantId, wzp_proto::SignalMessage, Vec<ParticipantSender>), String> {
|
||||
if !self.is_authorized(room_name, fingerprint) {
|
||||
warn!(room = room_name, fingerprint = ?fingerprint, "unauthorized room join attempt");
|
||||
return Err("not authorized for this room".to_string());
|
||||
}
|
||||
let room = self.rooms.entry(room_name.to_string()).or_insert_with(Room::new);
|
||||
room.add(addr, transport)
|
||||
let id = room.add(addr, sender, fingerprint.map(|s| s.to_string()), alias.map(|s| s.to_string()));
|
||||
let update = wzp_proto::SignalMessage::RoomUpdate {
|
||||
count: room.len() as u32,
|
||||
participants: room.participant_list(),
|
||||
};
|
||||
let senders = room.all_senders();
|
||||
Ok((id, update, senders))
|
||||
}
|
||||
|
||||
/// Leave a room. Removes the room if empty.
|
||||
pub fn leave(&mut self, room_name: &str, participant_id: ParticipantId) {
|
||||
/// Join a room via WebSocket. Convenience wrapper around `join()`.
|
||||
pub fn join_ws(
|
||||
&mut self,
|
||||
room_name: &str,
|
||||
addr: std::net::SocketAddr,
|
||||
sender: tokio::sync::mpsc::Sender<Bytes>,
|
||||
fingerprint: Option<&str>,
|
||||
) -> Result<ParticipantId, String> {
|
||||
let (id, _update, _senders) = self.join(room_name, addr, ParticipantSender::WebSocket(sender), fingerprint, None)?;
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
/// Leave a room. Returns (room_update_msg, remaining_senders) for broadcasting, or None if room is now empty.
|
||||
pub fn leave(&mut self, room_name: &str, participant_id: ParticipantId) -> Option<(wzp_proto::SignalMessage, Vec<ParticipantSender>)> {
|
||||
if let Some(room) = self.rooms.get_mut(room_name) {
|
||||
room.remove(participant_id);
|
||||
if room.is_empty() {
|
||||
self.rooms.remove(room_name);
|
||||
info!(room = room_name, "room closed (empty)");
|
||||
return None;
|
||||
}
|
||||
let update = wzp_proto::SignalMessage::RoomUpdate {
|
||||
count: room.len() as u32,
|
||||
participants: room.participant_list(),
|
||||
};
|
||||
let senders = room.all_senders();
|
||||
Some((update, senders))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Get transports for all OTHER participants in a room.
|
||||
/// Update a participant's alias and return a RoomUpdate + senders for broadcasting.
|
||||
pub fn set_alias(
|
||||
&mut self,
|
||||
room_name: &str,
|
||||
participant_id: ParticipantId,
|
||||
alias: String,
|
||||
) -> Option<(wzp_proto::SignalMessage, Vec<ParticipantSender>)> {
|
||||
if let Some(room) = self.rooms.get_mut(room_name) {
|
||||
if room.set_alias(participant_id, alias) {
|
||||
let update = wzp_proto::SignalMessage::RoomUpdate {
|
||||
count: room.len() as u32,
|
||||
participants: room.participant_list(),
|
||||
};
|
||||
let senders = room.all_senders();
|
||||
return Some((update, senders));
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Get senders for all OTHER participants in a room.
|
||||
pub fn others(
|
||||
&self,
|
||||
room_name: &str,
|
||||
participant_id: ParticipantId,
|
||||
) -> Vec<Arc<wzp_transport::QuinnTransport>> {
|
||||
) -> Vec<ParticipantSender> {
|
||||
self.rooms
|
||||
.get(room_name)
|
||||
.map(|r| r.others(participant_id))
|
||||
@@ -126,64 +309,472 @@ impl RoomManager {
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TrunkedForwarder — wraps a transport and batches outgoing media into trunk
|
||||
// frames so multiple packets ride a single QUIC datagram.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Wraps a [`QuinnTransport`] with a [`TrunkBatcher`] so that small media
|
||||
/// packets are accumulated and sent together in a single QUIC datagram.
|
||||
pub struct TrunkedForwarder {
|
||||
transport: Arc<wzp_transport::QuinnTransport>,
|
||||
batcher: TrunkBatcher,
|
||||
session_id: [u8; 2],
|
||||
}
|
||||
|
||||
impl TrunkedForwarder {
|
||||
/// Create a new trunked forwarder.
|
||||
///
|
||||
/// `session_id` tags every entry pushed into the batcher so the receiver
|
||||
/// can demultiplex packets by session.
|
||||
pub fn new(transport: Arc<wzp_transport::QuinnTransport>, session_id: [u8; 2]) -> Self {
|
||||
Self {
|
||||
transport,
|
||||
batcher: TrunkBatcher::new(),
|
||||
session_id,
|
||||
}
|
||||
}
|
||||
|
||||
/// Push a media packet into the batcher. If the batcher is full it will
|
||||
/// flush automatically and the resulting trunk frame is sent immediately.
|
||||
pub async fn send(&mut self, pkt: &wzp_proto::MediaPacket) -> anyhow::Result<()> {
|
||||
let payload: Bytes = pkt.to_bytes();
|
||||
if let Some(frame) = self.batcher.push(self.session_id, payload) {
|
||||
self.send_frame(&frame)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Flush any pending packets — called on the 5 ms timer tick.
|
||||
pub async fn flush(&mut self) -> anyhow::Result<()> {
|
||||
if let Some(frame) = self.batcher.flush() {
|
||||
self.send_frame(&frame)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Return the flush interval configured on the inner batcher.
|
||||
pub fn flush_interval(&self) -> Duration {
|
||||
self.batcher.flush_interval
|
||||
}
|
||||
|
||||
fn send_frame(&self, frame: &TrunkFrame) -> anyhow::Result<()> {
|
||||
self.transport.send_trunk(frame).map_err(|e| anyhow::anyhow!(e))
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// run_participant — the hot-path forwarding loop
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Run the receive loop for one participant in a room.
|
||||
/// Forwards all received packets to every other participant.
|
||||
///
|
||||
/// When `trunking_enabled` is true, outgoing packets are accumulated per-peer
|
||||
/// into [`TrunkedForwarder`]s and flushed every 5 ms or when the batcher is
|
||||
/// full, reducing QUIC datagram overhead.
|
||||
pub async fn run_participant(
|
||||
room_mgr: Arc<Mutex<RoomManager>>,
|
||||
room_name: String,
|
||||
participant_id: ParticipantId,
|
||||
transport: Arc<wzp_transport::QuinnTransport>,
|
||||
metrics: Arc<RelayMetrics>,
|
||||
session_id: &str,
|
||||
trunking_enabled: bool,
|
||||
) {
|
||||
if trunking_enabled {
|
||||
run_participant_trunked(
|
||||
room_mgr, room_name, participant_id, transport, metrics, session_id,
|
||||
)
|
||||
.await;
|
||||
} else {
|
||||
run_participant_plain(
|
||||
room_mgr, room_name, participant_id, transport, metrics, session_id,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Plain (non-trunked) forwarding loop — original behaviour.
|
||||
async fn run_participant_plain(
|
||||
room_mgr: Arc<Mutex<RoomManager>>,
|
||||
room_name: String,
|
||||
participant_id: ParticipantId,
|
||||
transport: Arc<wzp_transport::QuinnTransport>,
|
||||
metrics: Arc<RelayMetrics>,
|
||||
session_id: &str,
|
||||
) {
|
||||
let addr = transport.connection().remote_address();
|
||||
let mut packets_forwarded = 0u64;
|
||||
|
||||
loop {
|
||||
let pkt = match transport.recv_media().await {
|
||||
Ok(Some(pkt)) => pkt,
|
||||
Ok(None) => {
|
||||
info!(%addr, participant = participant_id, "disconnected");
|
||||
break;
|
||||
// Media forwarding task (with debug logging from Android fixes)
|
||||
let media_room_mgr = room_mgr.clone();
|
||||
let media_room_name = room_name.clone();
|
||||
let media_transport = transport.clone();
|
||||
let media_metrics = metrics.clone();
|
||||
let media_session_id = session_id.to_string();
|
||||
let media_task = async move {
|
||||
let mut packets_forwarded = 0u64;
|
||||
let mut last_recv_instant = std::time::Instant::now();
|
||||
let mut max_recv_gap_ms = 0u64;
|
||||
let mut max_forward_ms = 0u64;
|
||||
let mut send_errors = 0u64;
|
||||
let mut last_log_instant = std::time::Instant::now();
|
||||
|
||||
info!(
|
||||
room = %media_room_name,
|
||||
participant = participant_id,
|
||||
%addr,
|
||||
session = %media_session_id,
|
||||
"forwarding loop started (plain)"
|
||||
);
|
||||
|
||||
loop {
|
||||
let pkt = match media_transport.recv_media().await {
|
||||
Ok(Some(pkt)) => pkt,
|
||||
Ok(None) => {
|
||||
info!(%addr, participant = participant_id, forwarded = packets_forwarded, "disconnected (stream ended)");
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
let msg = e.to_string();
|
||||
if msg.contains("timed out") || msg.contains("reset") || msg.contains("closed") {
|
||||
info!(%addr, participant = participant_id, forwarded = packets_forwarded, "connection closed: {e}");
|
||||
} else {
|
||||
error!(%addr, participant = participant_id, forwarded = packets_forwarded, "recv error: {e}");
|
||||
}
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
let recv_gap_ms = last_recv_instant.elapsed().as_millis() as u64;
|
||||
last_recv_instant = std::time::Instant::now();
|
||||
if recv_gap_ms > max_recv_gap_ms {
|
||||
max_recv_gap_ms = recv_gap_ms;
|
||||
}
|
||||
Err(e) => {
|
||||
error!(%addr, participant = participant_id, "recv error: {e}");
|
||||
break;
|
||||
if recv_gap_ms > 200 {
|
||||
warn!(
|
||||
room = %media_room_name,
|
||||
participant = participant_id,
|
||||
recv_gap_ms,
|
||||
seq = pkt.header.seq,
|
||||
"large recv gap"
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
// Get current list of other participants
|
||||
let others = {
|
||||
let mgr = room_mgr.lock().await;
|
||||
mgr.others(&room_name, participant_id)
|
||||
};
|
||||
if let Some(ref report) = pkt.quality_report {
|
||||
media_metrics.update_session_quality(&media_session_id, report);
|
||||
}
|
||||
|
||||
// Forward to all others
|
||||
for other in &others {
|
||||
// Best-effort: if one send fails, continue to others
|
||||
if let Err(e) = other.send_media(&pkt).await {
|
||||
// Don't log every failure — they'll be cleaned up when their recv loop breaks
|
||||
let _ = e;
|
||||
let lock_start = std::time::Instant::now();
|
||||
let others = {
|
||||
let mgr = media_room_mgr.lock().await;
|
||||
mgr.others(&media_room_name, participant_id)
|
||||
};
|
||||
let lock_ms = lock_start.elapsed().as_millis() as u64;
|
||||
if lock_ms > 10 {
|
||||
warn!(room = %media_room_name, participant = participant_id, lock_ms, "slow room_mgr lock");
|
||||
}
|
||||
|
||||
let fwd_start = std::time::Instant::now();
|
||||
let pkt_bytes = pkt.payload.len() as u64;
|
||||
for other in &others {
|
||||
match other {
|
||||
ParticipantSender::Quic(t) => {
|
||||
if let Err(e) = t.send_media(&pkt).await {
|
||||
send_errors += 1;
|
||||
if send_errors <= 5 || send_errors % 100 == 0 {
|
||||
warn!(
|
||||
room = %media_room_name,
|
||||
participant = participant_id,
|
||||
peer = %t.connection().remote_address(),
|
||||
total_send_errors = send_errors,
|
||||
"send_media error: {e}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
ParticipantSender::WebSocket(_) => {
|
||||
let _ = other.send_raw(&pkt.payload).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
let fwd_ms = fwd_start.elapsed().as_millis() as u64;
|
||||
if fwd_ms > max_forward_ms { max_forward_ms = fwd_ms; }
|
||||
if fwd_ms > 50 {
|
||||
warn!(room = %media_room_name, participant = participant_id, fwd_ms, fan_out = others.len(), "slow forward");
|
||||
}
|
||||
|
||||
let fan_out = others.len() as u64;
|
||||
media_metrics.packets_forwarded.inc_by(fan_out);
|
||||
media_metrics.bytes_forwarded.inc_by(pkt_bytes * fan_out);
|
||||
packets_forwarded += 1;
|
||||
|
||||
if last_log_instant.elapsed() >= Duration::from_secs(5) {
|
||||
let room_size = {
|
||||
let mgr = media_room_mgr.lock().await;
|
||||
mgr.room_size(&media_room_name)
|
||||
};
|
||||
info!(
|
||||
room = %media_room_name,
|
||||
participant = participant_id,
|
||||
forwarded = packets_forwarded,
|
||||
room_size, fan_out, max_recv_gap_ms, max_forward_ms, send_errors,
|
||||
"participant stats"
|
||||
);
|
||||
max_recv_gap_ms = 0;
|
||||
max_forward_ms = 0;
|
||||
last_log_instant = std::time::Instant::now();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
packets_forwarded += 1;
|
||||
if packets_forwarded % 500 == 0 {
|
||||
let room_size = {
|
||||
let mgr = room_mgr.lock().await;
|
||||
mgr.room_size(&room_name)
|
||||
};
|
||||
info!(
|
||||
room = %room_name,
|
||||
participant = participant_id,
|
||||
forwarded = packets_forwarded,
|
||||
room_size,
|
||||
"participant stats"
|
||||
);
|
||||
// Signal handling task — processes SetAlias and other in-call signals
|
||||
let signal_room_mgr = room_mgr.clone();
|
||||
let signal_room_name = room_name.clone();
|
||||
let signal_transport = transport.clone();
|
||||
let signal_task = async move {
|
||||
loop {
|
||||
match signal_transport.recv_signal().await {
|
||||
Ok(Some(wzp_proto::SignalMessage::SetAlias { alias })) => {
|
||||
info!(%addr, participant = participant_id, %alias, "SetAlias received");
|
||||
let mut mgr = signal_room_mgr.lock().await;
|
||||
if let Some((update, senders)) =
|
||||
mgr.set_alias(&signal_room_name, participant_id, alias)
|
||||
{
|
||||
drop(mgr);
|
||||
broadcast_signal(&senders, &update).await;
|
||||
}
|
||||
}
|
||||
Ok(Some(wzp_proto::SignalMessage::Hangup { .. })) => {
|
||||
info!(%addr, participant = participant_id, "hangup received");
|
||||
break;
|
||||
}
|
||||
Ok(Some(msg)) => {
|
||||
info!(%addr, participant = participant_id, "signal: {:?}", std::mem::discriminant(&msg));
|
||||
}
|
||||
Ok(None) => break,
|
||||
Err(e) => {
|
||||
warn!(%addr, participant = participant_id, "signal recv error: {e}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Run both in parallel — exit when either finishes (disconnection)
|
||||
tokio::select! {
|
||||
_ = media_task => {}
|
||||
_ = signal_task => {}
|
||||
}
|
||||
|
||||
// Clean up — leave room and broadcast update to remaining participants
|
||||
let mut mgr = room_mgr.lock().await;
|
||||
if let Some((update, senders)) = mgr.leave(&room_name, participant_id) {
|
||||
drop(mgr); // release lock before async broadcast
|
||||
broadcast_signal(&senders, &update).await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Trunked forwarding loop — batches outgoing packets per peer.
|
||||
async fn run_participant_trunked(
|
||||
room_mgr: Arc<Mutex<RoomManager>>,
|
||||
room_name: String,
|
||||
participant_id: ParticipantId,
|
||||
transport: Arc<wzp_transport::QuinnTransport>,
|
||||
metrics: Arc<RelayMetrics>,
|
||||
session_id: &str,
|
||||
) {
|
||||
use std::collections::HashMap;
|
||||
|
||||
let addr = transport.connection().remote_address();
|
||||
let mut packets_forwarded = 0u64;
|
||||
let mut last_recv_instant = std::time::Instant::now();
|
||||
let mut max_recv_gap_ms = 0u64;
|
||||
let mut max_forward_ms = 0u64;
|
||||
let mut send_errors = 0u64;
|
||||
let mut last_log_instant = std::time::Instant::now();
|
||||
|
||||
info!(
|
||||
room = %room_name,
|
||||
participant = participant_id,
|
||||
%addr,
|
||||
session = session_id,
|
||||
"forwarding loop started (trunked)"
|
||||
);
|
||||
|
||||
// Per-peer TrunkedForwarders, keyed by the raw pointer of the peer
|
||||
// transport (stable for the Arc's lifetime). We use the remote address
|
||||
// string as the key since it is unique per connection.
|
||||
let mut forwarders: HashMap<std::net::SocketAddr, TrunkedForwarder> = HashMap::new();
|
||||
|
||||
// Derive a 2-byte session tag from the session_id hex string.
|
||||
let sid_bytes: [u8; 2] = parse_session_id_bytes(session_id);
|
||||
|
||||
let mut flush_interval = tokio::time::interval(Duration::from_millis(5));
|
||||
// Don't let missed ticks pile up — skip them and move on.
|
||||
flush_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
biased;
|
||||
|
||||
result = transport.recv_media() => {
|
||||
let pkt = match result {
|
||||
Ok(Some(pkt)) => pkt,
|
||||
Ok(None) => {
|
||||
info!(%addr, participant = participant_id, forwarded = packets_forwarded, "disconnected (stream ended)");
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
error!(%addr, participant = participant_id, forwarded = packets_forwarded, "recv error: {e}");
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
let recv_gap_ms = last_recv_instant.elapsed().as_millis() as u64;
|
||||
last_recv_instant = std::time::Instant::now();
|
||||
if recv_gap_ms > max_recv_gap_ms {
|
||||
max_recv_gap_ms = recv_gap_ms;
|
||||
}
|
||||
if recv_gap_ms > 200 {
|
||||
warn!(
|
||||
room = %room_name,
|
||||
participant = participant_id,
|
||||
recv_gap_ms,
|
||||
seq = pkt.header.seq,
|
||||
"large recv gap (trunked)"
|
||||
);
|
||||
}
|
||||
|
||||
if let Some(ref report) = pkt.quality_report {
|
||||
metrics.update_session_quality(session_id, report);
|
||||
}
|
||||
|
||||
let lock_start = std::time::Instant::now();
|
||||
let others = {
|
||||
let mgr = room_mgr.lock().await;
|
||||
mgr.others(&room_name, participant_id)
|
||||
};
|
||||
let lock_ms = lock_start.elapsed().as_millis() as u64;
|
||||
if lock_ms > 10 {
|
||||
warn!(
|
||||
room = %room_name,
|
||||
participant = participant_id,
|
||||
lock_ms,
|
||||
"slow room_mgr lock (trunked)"
|
||||
);
|
||||
}
|
||||
|
||||
let fwd_start = std::time::Instant::now();
|
||||
let pkt_bytes = pkt.payload.len() as u64;
|
||||
for other in &others {
|
||||
match other {
|
||||
ParticipantSender::Quic(t) => {
|
||||
let peer_addr = t.connection().remote_address();
|
||||
let fwd = forwarders
|
||||
.entry(peer_addr)
|
||||
.or_insert_with(|| TrunkedForwarder::new(t.clone(), sid_bytes));
|
||||
if let Err(e) = fwd.send(&pkt).await {
|
||||
send_errors += 1;
|
||||
if send_errors <= 5 || send_errors % 100 == 0 {
|
||||
warn!(
|
||||
room = %room_name,
|
||||
participant = participant_id,
|
||||
peer = %peer_addr,
|
||||
total_send_errors = send_errors,
|
||||
"trunked send error: {e}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
ParticipantSender::WebSocket(_) => {
|
||||
let _ = other.send_raw(&pkt.payload).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
let fwd_ms = fwd_start.elapsed().as_millis() as u64;
|
||||
if fwd_ms > max_forward_ms {
|
||||
max_forward_ms = fwd_ms;
|
||||
}
|
||||
if fwd_ms > 50 {
|
||||
warn!(
|
||||
room = %room_name,
|
||||
participant = participant_id,
|
||||
fwd_ms,
|
||||
fan_out = others.len(),
|
||||
"slow forward (trunked)"
|
||||
);
|
||||
}
|
||||
|
||||
let fan_out = others.len() as u64;
|
||||
metrics.packets_forwarded.inc_by(fan_out);
|
||||
metrics.bytes_forwarded.inc_by(pkt_bytes * fan_out);
|
||||
packets_forwarded += 1;
|
||||
|
||||
// Periodic stats every 5 seconds
|
||||
if last_log_instant.elapsed() >= Duration::from_secs(5) {
|
||||
let room_size = {
|
||||
let mgr = room_mgr.lock().await;
|
||||
mgr.room_size(&room_name)
|
||||
};
|
||||
info!(
|
||||
room = %room_name,
|
||||
participant = participant_id,
|
||||
forwarded = packets_forwarded,
|
||||
room_size,
|
||||
fan_out,
|
||||
max_recv_gap_ms,
|
||||
max_forward_ms,
|
||||
send_errors,
|
||||
"participant stats (trunked)"
|
||||
);
|
||||
max_recv_gap_ms = 0;
|
||||
max_forward_ms = 0;
|
||||
last_log_instant = std::time::Instant::now();
|
||||
}
|
||||
}
|
||||
|
||||
_ = flush_interval.tick() => {
|
||||
for fwd in forwarders.values_mut() {
|
||||
if let Err(e) = fwd.flush().await {
|
||||
send_errors += 1;
|
||||
if send_errors <= 5 || send_errors % 100 == 0 {
|
||||
warn!(
|
||||
room = %room_name,
|
||||
participant = participant_id,
|
||||
total_send_errors = send_errors,
|
||||
"trunk flush error: {e}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up
|
||||
// Final flush — send any remaining buffered packets.
|
||||
for fwd in forwarders.values_mut() {
|
||||
let _ = fwd.flush().await;
|
||||
}
|
||||
|
||||
let mut mgr = room_mgr.lock().await;
|
||||
mgr.leave(&room_name, participant_id);
|
||||
if let Some((update, senders)) = mgr.leave(&room_name, participant_id) {
|
||||
drop(mgr);
|
||||
broadcast_signal(&senders, &update).await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse up to the first 2 bytes of a hex session-id string into `[u8; 2]`.
///
/// Degrades gracefully on malformed input: hex pairs that fail to parse
/// (or a trailing odd nibble) are skipped, and any byte that cannot be
/// filled stays zero.
fn parse_session_id_bytes(session_id: &str) -> [u8; 2] {
    let mut decoded = (0..session_id.len())
        .step_by(2)
        .filter_map(|i| session_id.get(i..i + 2))
        .filter_map(|pair| u8::from_str_radix(pair, 16).ok());

    let mut out = [0u8; 2];
    for slot in out.iter_mut() {
        match decoded.next() {
            Some(byte) => *slot = byte,
            None => break,
        }
    }
    out
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -193,8 +784,125 @@ mod tests {
|
||||
#[test]
|
||||
fn room_join_leave() {
|
||||
let mut mgr = RoomManager::new();
|
||||
// Can't test with real transports, but test the room logic
|
||||
assert_eq!(mgr.room_size("test"), 0);
|
||||
assert!(mgr.list().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn acl_open_mode_allows_all() {
|
||||
let mgr = RoomManager::new();
|
||||
assert!(mgr.is_authorized("any-room", None));
|
||||
assert!(mgr.is_authorized("any-room", Some("abc")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn acl_enforced_requires_fingerprint() {
|
||||
let mgr = RoomManager::with_acl();
|
||||
assert!(!mgr.is_authorized("room1", None));
|
||||
// Room not in ACL = open to any authenticated user
|
||||
assert!(mgr.is_authorized("room1", Some("abc")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn acl_restricts_to_allowed() {
|
||||
let mut mgr = RoomManager::with_acl();
|
||||
mgr.allow("room1", "alice");
|
||||
mgr.allow("room1", "bob");
|
||||
assert!(mgr.is_authorized("room1", Some("alice")));
|
||||
assert!(mgr.is_authorized("room1", Some("bob")));
|
||||
assert!(!mgr.is_authorized("room1", Some("eve")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_session_id_bytes_works() {
|
||||
assert_eq!(parse_session_id_bytes("abcd"), [0xab, 0xcd]);
|
||||
assert_eq!(parse_session_id_bytes("ff00"), [0xff, 0x00]);
|
||||
assert_eq!(parse_session_id_bytes(""), [0x00, 0x00]);
|
||||
// Longer hex strings: only first 2 bytes taken
|
||||
assert_eq!(parse_session_id_bytes("aabbccdd"), [0xaa, 0xbb]);
|
||||
}
|
||||
|
||||
/// Helper: create a minimal MediaPacket with the given payload bytes.
|
||||
fn make_test_packet(payload: &[u8]) -> wzp_proto::MediaPacket {
|
||||
wzp_proto::MediaPacket {
|
||||
header: wzp_proto::packet::MediaHeader {
|
||||
version: 0,
|
||||
is_repair: false,
|
||||
codec_id: wzp_proto::CodecId::Opus16k,
|
||||
has_quality_report: false,
|
||||
fec_ratio_encoded: 0,
|
||||
seq: 1,
|
||||
timestamp: 100,
|
||||
fec_block: 0,
|
||||
fec_symbol: 0,
|
||||
reserved: 0,
|
||||
csrc_count: 0,
|
||||
},
|
||||
payload: Bytes::from(payload.to_vec()),
|
||||
quality_report: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Push 3 packets into a batcher (simulating TrunkedForwarder.send),
|
||||
/// then flush and verify all 3 appear in a single TrunkFrame.
|
||||
#[test]
|
||||
fn trunked_forwarder_batches() {
|
||||
let session_id: [u8; 2] = [0x00, 0x01];
|
||||
let mut batcher = TrunkBatcher::new();
|
||||
// Ensure max_entries is high enough that 3 packets don't auto-flush.
|
||||
batcher.max_entries = 10;
|
||||
batcher.max_bytes = 4096;
|
||||
|
||||
let pkts = [
|
||||
make_test_packet(b"aaa"),
|
||||
make_test_packet(b"bbb"),
|
||||
make_test_packet(b"ccc"),
|
||||
];
|
||||
|
||||
for pkt in &pkts {
|
||||
let payload = pkt.to_bytes();
|
||||
let flushed = batcher.push(session_id, payload);
|
||||
// Should NOT auto-flush — we are below max_entries.
|
||||
assert!(flushed.is_none(), "unexpected auto-flush");
|
||||
}
|
||||
|
||||
// Explicit flush (simulates the 5 ms timer tick).
|
||||
let frame = batcher.flush().expect("expected a frame with 3 entries");
|
||||
assert_eq!(frame.len(), 3);
|
||||
for entry in &frame.packets {
|
||||
assert_eq!(entry.session_id, session_id);
|
||||
}
|
||||
}
|
||||
|
||||
/// Push exactly max_entries packets and verify the batcher auto-flushes
|
||||
/// on the last push (simulating TrunkedForwarder.send triggering a send).
|
||||
#[test]
|
||||
fn trunked_forwarder_auto_flushes() {
|
||||
let session_id: [u8; 2] = [0x00, 0x02];
|
||||
let mut batcher = TrunkBatcher::new();
|
||||
batcher.max_entries = 5;
|
||||
batcher.max_bytes = 8192;
|
||||
|
||||
let pkt = make_test_packet(b"hello");
|
||||
let mut auto_flushed: Option<wzp_proto::packet::TrunkFrame> = None;
|
||||
|
||||
for i in 0..5 {
|
||||
let payload = pkt.to_bytes();
|
||||
if let Some(frame) = batcher.push(session_id, payload) {
|
||||
assert!(auto_flushed.is_none(), "should auto-flush exactly once");
|
||||
auto_flushed = Some(frame);
|
||||
// The auto-flush should happen on the 5th push (max_entries = 5).
|
||||
assert_eq!(i, 4, "expected auto-flush on the last push");
|
||||
}
|
||||
}
|
||||
|
||||
let frame = auto_flushed.expect("batcher should have auto-flushed at max_entries");
|
||||
assert_eq!(frame.len(), 5);
|
||||
for entry in &frame.packets {
|
||||
assert_eq!(entry.session_id, session_id);
|
||||
}
|
||||
|
||||
// Batcher should now be empty — nothing to flush.
|
||||
assert!(batcher.flush().is_none());
|
||||
}
|
||||
}
|
||||
|
||||
265
crates/wzp-relay/src/route.rs
Normal file
265
crates/wzp-relay/src/route.rs
Normal file
@@ -0,0 +1,265 @@
|
||||
//! Route resolution — given a target fingerprint, find the relay chain
|
||||
//! needed to reach that user.
|
||||
//!
|
||||
//! Uses the [`PresenceRegistry`] as its data source. Currently supports
|
||||
//! single-hop resolution (local or direct peer). The `resolve_multi_hop`
|
||||
//! method has the signature for future multi-hop expansion but falls back
|
||||
//! to single-hop for now.
|
||||
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use serde::Serialize;
|
||||
|
||||
use crate::presence::{PresenceLocation, PresenceRegistry};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Route type
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// The resolved route to a target fingerprint.
///
/// Produced by `RouteResolver::resolve` and serialized (via `Serialize`)
/// into HTTP API responses.
#[derive(Clone, Debug, PartialEq, Eq, Serialize)]
pub enum Route {
    /// Target is connected to this relay directly.
    Local,
    /// Target is on a directly connected peer relay.
    DirectPeer(SocketAddr),
    /// Target is reachable via a chain of relays (multi-hop).
    Chain(Vec<SocketAddr>),
    /// Target not found in any known relay.
    NotFound,
}
|
||||
|
||||
impl std::fmt::Display for Route {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Route::Local => write!(f, "local"),
|
||||
Route::DirectPeer(addr) => write!(f, "direct_peer({})", addr),
|
||||
Route::Chain(chain) => {
|
||||
let addrs: Vec<String> = chain.iter().map(|a| a.to_string()).collect();
|
||||
write!(f, "chain({})", addrs.join(" -> "))
|
||||
}
|
||||
Route::NotFound => write!(f, "not_found"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// RouteResolver
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Resolves fingerprints to relay routes using the presence registry.
///
/// Stateless apart from knowing its own address; all presence data comes
/// from the `PresenceRegistry` passed into each call.
pub struct RouteResolver {
    /// Our own relay address (how peers know us).
    local_addr: SocketAddr,
}
|
||||
|
||||
impl RouteResolver {
|
||||
/// Create a new route resolver for the relay at `local_addr`.
|
||||
pub fn new(local_addr: SocketAddr) -> Self {
|
||||
Self { local_addr }
|
||||
}
|
||||
|
||||
/// Our local relay address.
|
||||
pub fn local_addr(&self) -> SocketAddr {
|
||||
self.local_addr
|
||||
}
|
||||
|
||||
/// Look up a fingerprint in the registry and return the route.
|
||||
///
|
||||
/// - If `registry.lookup()` returns `Local` -> `Route::Local`
|
||||
/// - If returns `Remote(addr)` -> `Route::DirectPeer(addr)`
|
||||
/// - If not found -> `Route::NotFound`
|
||||
pub fn resolve(&self, registry: &PresenceRegistry, target_fingerprint: &str) -> Route {
|
||||
match registry.lookup(target_fingerprint) {
|
||||
Some(PresenceLocation::Local) => Route::Local,
|
||||
Some(PresenceLocation::Remote(addr)) => Route::DirectPeer(addr),
|
||||
None => Route::NotFound,
|
||||
}
|
||||
}
|
||||
|
||||
/// Multi-hop route resolution (future expansion).
|
||||
///
|
||||
/// For now this is equivalent to `resolve()` — single-hop only.
|
||||
/// When multi-hop is implemented, this will query peers transitively
|
||||
/// up to `max_hops` relays deep, using `RouteQuery` / `RouteResponse`
|
||||
/// signals over probe connections.
|
||||
pub fn resolve_multi_hop(
|
||||
&self,
|
||||
registry: &PresenceRegistry,
|
||||
target: &str,
|
||||
_max_hops: usize,
|
||||
) -> Route {
|
||||
// Phase 1: single-hop only (same as resolve).
|
||||
// Future: if resolve returns NotFound and max_hops > 0,
|
||||
// send RouteQuery to each known peer with ttl = max_hops - 1,
|
||||
// collect RouteResponse, and build a Chain.
|
||||
self.resolve(registry, target)
|
||||
}
|
||||
|
||||
/// Build a JSON-serializable route response for the HTTP API.
|
||||
pub fn route_json(
|
||||
&self,
|
||||
fingerprint: &str,
|
||||
route: &Route,
|
||||
) -> serde_json::Value {
|
||||
let (route_type, relay_chain) = match route {
|
||||
Route::Local => ("local", vec![self.local_addr.to_string()]),
|
||||
Route::DirectPeer(addr) => ("direct_peer", vec![self.local_addr.to_string(), addr.to_string()]),
|
||||
Route::Chain(chain) => {
|
||||
let mut addrs = vec![self.local_addr.to_string()];
|
||||
addrs.extend(chain.iter().map(|a| a.to_string()));
|
||||
("chain", addrs)
|
||||
}
|
||||
Route::NotFound => ("not_found", vec![]),
|
||||
};
|
||||
|
||||
serde_json::json!({
|
||||
"fingerprint": fingerprint,
|
||||
"route": route_type,
|
||||
"relay_chain": relay_chain,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use std::net::SocketAddr;

    // Shorthand: parse a socket-address literal (panicking on bad input
    // is fine in tests).
    fn addr(s: &str) -> SocketAddr {
        s.parse().unwrap()
    }

    // All tests resolve from the same local relay address.
    fn make_resolver() -> RouteResolver {
        RouteResolver::new(addr("10.0.0.1:4433"))
    }

    #[test]
    fn resolve_local() {
        let resolver = make_resolver();
        let mut reg = PresenceRegistry::new();
        reg.register_local("aabbccdd", Some("alice".into()), Some("room1".into()));

        let route = resolver.resolve(&reg, "aabbccdd");
        assert_eq!(route, Route::Local);
    }

    #[test]
    fn resolve_direct_peer() {
        let resolver = make_resolver();
        let mut reg = PresenceRegistry::new();
        let peer = addr("10.0.0.2:4433");
        let mut fps = HashSet::new();
        fps.insert("deadbeef".to_string());
        reg.update_peer(peer, fps);

        let route = resolver.resolve(&reg, "deadbeef");
        assert_eq!(route, Route::DirectPeer(peer));
    }

    #[test]
    fn resolve_not_found() {
        let resolver = make_resolver();
        let reg = PresenceRegistry::new();

        let route = resolver.resolve(&reg, "unknown_fp");
        assert_eq!(route, Route::NotFound);
    }

    #[test]
    fn resolve_multi_hop_fallback() {
        // multi-hop currently falls back to single-hop behavior
        let resolver = make_resolver();
        let mut reg = PresenceRegistry::new();
        reg.register_local("local_fp", None, None);

        let peer = addr("10.0.0.3:4433");
        let mut fps = HashSet::new();
        fps.insert("remote_fp".to_string());
        reg.update_peer(peer, fps);

        // Local lookup works via multi-hop
        assert_eq!(resolver.resolve_multi_hop(&reg, "local_fp", 3), Route::Local);
        // Remote lookup works via multi-hop
        assert_eq!(
            resolver.resolve_multi_hop(&reg, "remote_fp", 3),
            Route::DirectPeer(peer)
        );
        // Not-found works via multi-hop
        assert_eq!(
            resolver.resolve_multi_hop(&reg, "nobody", 3),
            Route::NotFound
        );
    }

    #[test]
    fn route_query_signal_roundtrip() {
        use wzp_proto::SignalMessage;

        // Both RouteQuery and RouteResponse must survive a JSON
        // serialize/deserialize round trip unchanged.
        let query = SignalMessage::RouteQuery {
            fingerprint: "aabbccdd".to_string(),
            ttl: 3,
        };
        let json = serde_json::to_string(&query).unwrap();
        let decoded: SignalMessage = serde_json::from_str(&json).unwrap();
        assert!(matches!(
            decoded,
            SignalMessage::RouteQuery { ref fingerprint, ttl }
                if fingerprint == "aabbccdd" && ttl == 3
        ));

        let response = SignalMessage::RouteResponse {
            fingerprint: "aabbccdd".to_string(),
            found: true,
            relay_chain: vec!["10.0.0.1:4433".to_string(), "10.0.0.2:4433".to_string()],
        };
        let json = serde_json::to_string(&response).unwrap();
        let decoded: SignalMessage = serde_json::from_str(&json).unwrap();
        assert!(matches!(
            decoded,
            SignalMessage::RouteResponse { ref fingerprint, found, ref relay_chain }
                if fingerprint == "aabbccdd" && found && relay_chain.len() == 2
        ));
    }

    #[test]
    fn route_display() {
        assert_eq!(Route::Local.to_string(), "local");
        assert_eq!(
            Route::DirectPeer(addr("10.0.0.2:4433")).to_string(),
            "direct_peer(10.0.0.2:4433)"
        );
        assert_eq!(
            Route::Chain(vec![addr("10.0.0.2:4433"), addr("10.0.0.3:4433")]).to_string(),
            "chain(10.0.0.2:4433 -> 10.0.0.3:4433)"
        );
        assert_eq!(Route::NotFound.to_string(), "not_found");

        // Debug is also useful
        let debug = format!("{:?}", Route::Local);
        assert!(debug.contains("Local"));
    }

    #[test]
    fn route_json_output() {
        let resolver = make_resolver();

        // Local: chain contains just this relay.
        let json = resolver.route_json("fp1", &Route::Local);
        assert_eq!(json["route"], "local");
        assert_eq!(json["fingerprint"], "fp1");
        assert_eq!(json["relay_chain"].as_array().unwrap().len(), 1);

        // DirectPeer: chain is [local, peer].
        let json = resolver.route_json("fp2", &Route::DirectPeer(addr("10.0.0.2:4433")));
        assert_eq!(json["route"], "direct_peer");
        assert_eq!(json["relay_chain"].as_array().unwrap().len(), 2);

        // NotFound: empty chain.
        let json = resolver.route_json("fp3", &Route::NotFound);
        assert_eq!(json["route"], "not_found");
        assert_eq!(json["relay_chain"].as_array().unwrap().len(), 0);
    }
}
|
||||
@@ -1,6 +1,7 @@
|
||||
//! Session manager — tracks active call sessions on the relay.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::time::Instant;
|
||||
|
||||
use wzp_proto::{QualityProfile, Session};
|
||||
|
||||
@@ -9,6 +10,26 @@ use crate::pipeline::{PipelineConfig, RelayPipeline};
|
||||
/// Unique identifier for a relay session.
|
||||
pub type SessionId = [u8; 16];
|
||||
|
||||
/// Lifecycle state of a concurrent session.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionState {
    /// Session is live; counted by `SessionManager::active_count`.
    Active,
    /// Session is shutting down; still tracked but no longer "active".
    Closing,
}
|
||||
|
||||
/// Lightweight metadata for a concurrent session (room-mode tracking).
///
/// Unlike `RelaySession`, this carries no pipeline state — just enough
/// to enforce capacity limits and answer lifecycle queries.
#[derive(Debug, Clone)]
pub struct SessionInfo {
    /// Which room this session belongs to.
    pub room_name: String,
    /// Client fingerprint (present when auth is enabled).
    pub fingerprint: Option<String>,
    /// When the session was created.
    pub connected_at: Instant,
    /// Current lifecycle state.
    pub state: SessionState,
}
|
||||
|
||||
/// A single active call session on the relay.
|
||||
pub struct RelaySession {
|
||||
/// Protocol session state machine.
|
||||
@@ -47,8 +68,14 @@ impl RelaySession {
|
||||
}
|
||||
|
||||
/// Manages all active sessions on a relay.
///
/// Combines two layers of tracking:
/// - `sessions`: heavy `RelaySession` objects (pipeline state machines, used in forward mode)
/// - `tracked`: lightweight `SessionInfo` entries (room + fingerprint, used in room mode to
///   enforce `max_sessions` and answer lifecycle queries)
pub struct SessionManager {
    // Forward-mode pipeline sessions, keyed by session id.
    sessions: HashMap<SessionId, RelaySession>,
    // Room-mode lightweight sessions, keyed by session id.
    tracked: HashMap<SessionId, SessionInfo>,
    // Capacity cap applied across BOTH layers (see `total_count`).
    max_sessions: usize,
}
|
||||
|
||||
@@ -56,17 +83,20 @@ impl SessionManager {
|
||||
pub fn new(max_sessions: usize) -> Self {
|
||||
Self {
|
||||
sessions: HashMap::new(),
|
||||
tracked: HashMap::new(),
|
||||
max_sessions,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new session. Returns None if at capacity.
|
||||
pub fn create_session(
|
||||
// ── Heavy session API (forward-mode pipelines) ──────────────────────
|
||||
|
||||
/// Create a new pipeline session. Returns None if at capacity.
|
||||
pub fn create_pipeline_session(
|
||||
&mut self,
|
||||
session_id: SessionId,
|
||||
config: PipelineConfig,
|
||||
) -> Option<&mut RelaySession> {
|
||||
if self.sessions.len() >= self.max_sessions {
|
||||
if self.total_count() >= self.max_sessions {
|
||||
return None;
|
||||
}
|
||||
self.sessions
|
||||
@@ -75,53 +105,124 @@ impl SessionManager {
|
||||
self.sessions.get_mut(&session_id)
|
||||
}
|
||||
|
||||
/// Get a session by ID.
|
||||
/// Get a pipeline session by ID.
|
||||
pub fn get_session(&mut self, id: &SessionId) -> Option<&mut RelaySession> {
|
||||
self.sessions.get_mut(id)
|
||||
}
|
||||
|
||||
/// Remove a session.
|
||||
pub fn remove_session(&mut self, id: &SessionId) -> Option<RelaySession> {
|
||||
/// Remove a pipeline session.
|
||||
pub fn remove_pipeline_session(&mut self, id: &SessionId) -> Option<RelaySession> {
|
||||
self.sessions.remove(id)
|
||||
}
|
||||
|
||||
/// Number of active sessions.
|
||||
pub fn active_count(&self) -> usize {
|
||||
/// Number of active pipeline sessions.
|
||||
pub fn pipeline_active_count(&self) -> usize {
|
||||
self.sessions.values().filter(|s| s.is_active()).count()
|
||||
}
|
||||
|
||||
/// Total sessions (including inactive/closing).
|
||||
pub fn total_count(&self) -> usize {
|
||||
/// Total pipeline sessions (including inactive/closing).
|
||||
pub fn pipeline_total_count(&self) -> usize {
|
||||
self.sessions.len()
|
||||
}
|
||||
|
||||
/// Remove sessions idle for longer than `timeout_ms`.
|
||||
/// Remove pipeline sessions idle for longer than `timeout_ms`.
|
||||
pub fn expire_idle(&mut self, now_ms: u64, timeout_ms: u64) -> usize {
|
||||
let before = self.sessions.len();
|
||||
self.sessions
|
||||
.retain(|_, s| now_ms.saturating_sub(s.last_activity_ms) < timeout_ms);
|
||||
before - self.sessions.len()
|
||||
}
|
||||
|
||||
// ── Lightweight concurrent-session API (room mode) ──────────────────
|
||||
|
||||
/// Register a new concurrent session.
|
||||
/// Returns the `SessionId` on success, or an error string if `max_sessions` is exceeded.
|
||||
pub fn create_session(
|
||||
&mut self,
|
||||
room: &str,
|
||||
fingerprint: Option<String>,
|
||||
) -> Result<SessionId, String> {
|
||||
if self.total_count() >= self.max_sessions {
|
||||
return Err(format!(
|
||||
"max sessions ({}) exceeded",
|
||||
self.max_sessions
|
||||
));
|
||||
}
|
||||
let id = rand_session_id();
|
||||
self.tracked.insert(id, SessionInfo {
|
||||
room_name: room.to_string(),
|
||||
fingerprint,
|
||||
connected_at: Instant::now(),
|
||||
state: SessionState::Active,
|
||||
});
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
/// Remove a tracked session.
|
||||
pub fn remove_session(&mut self, id: SessionId) {
|
||||
self.tracked.remove(&id);
|
||||
}
|
||||
|
||||
/// Number of currently tracked (room-mode) sessions.
|
||||
pub fn active_count(&self) -> usize {
|
||||
self.tracked.values().filter(|s| s.state == SessionState::Active).count()
|
||||
}
|
||||
|
||||
/// Return all session IDs that belong to a given room.
|
||||
pub fn sessions_in_room(&self, room: &str) -> Vec<SessionId> {
|
||||
self.tracked
|
||||
.iter()
|
||||
.filter(|(_, info)| info.room_name == room)
|
||||
.map(|(id, _)| *id)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get metadata for a tracked session.
|
||||
pub fn session_info(&self, id: SessionId) -> Option<&SessionInfo> {
|
||||
self.tracked.get(&id)
|
||||
}
|
||||
|
||||
/// Total sessions across both tracking layers.
|
||||
pub fn total_count(&self) -> usize {
|
||||
self.sessions.len() + self.tracked.len()
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate a random 16-byte session identifier.
|
||||
fn rand_session_id() -> SessionId {
|
||||
let mut id = [0u8; 16];
|
||||
// Use a simple monotonic + random source to avoid pulling in `rand` crate.
|
||||
// Hash the instant + a counter for uniqueness.
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
static CTR: AtomicU64 = AtomicU64::new(1);
|
||||
let ctr = CTR.fetch_add(1, Ordering::Relaxed);
|
||||
let bytes = ctr.to_le_bytes();
|
||||
id[..8].copy_from_slice(&bytes);
|
||||
// Mix in some time-based entropy for the upper half.
|
||||
let t = Instant::now().elapsed().as_nanos() as u64;
|
||||
id[8..16].copy_from_slice(&t.to_le_bytes());
|
||||
id
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // ── Pipeline session tests (pre-existing, adapted to renamed API) ───
    //
    // NOTE(review): the previous revision contained unresolved diff residue
    // (both old `create_session` and new `create_pipeline_session` call
    // lines, plus duplicate `fn` signatures). Resolved here to the renamed
    // pipeline API, consistent with `max_sessions_shared_across_both_layers`.

    #[test]
    fn create_and_get_pipeline_session() {
        let mut mgr = SessionManager::new(10);
        let id = [1u8; 16];
        mgr.create_pipeline_session(id, PipelineConfig::default());
        assert_eq!(mgr.total_count(), 1);
        assert!(mgr.get_session(&id).is_some());
    }

    #[test]
    fn respects_max_pipeline_sessions() {
        let mut mgr = SessionManager::new(1);
        mgr.create_pipeline_session([1u8; 16], PipelineConfig::default());
        let result = mgr.create_pipeline_session([2u8; 16], PipelineConfig::default());
        assert!(result.is_none());
    }

    #[test]
    fn expire_idle_removes_old() {
        let mut mgr = SessionManager::new(10);
        let id = [1u8; 16];
        mgr.create_pipeline_session(id, PipelineConfig::default());
        // Session has last_activity_ms = 0, current time = 60000, timeout = 30000
        let expired = mgr.expire_idle(60_000, 30_000);
        assert_eq!(expired, 1);
        assert_eq!(mgr.pipeline_total_count(), 0);
    }

    // ── Concurrent session (room-mode) tests ────────────────────────────

    #[test]
    fn create_and_remove() {
        let mut mgr = SessionManager::new(10);
        let id = mgr.create_session("room-a", Some("fp123".into())).unwrap();
        assert_eq!(mgr.active_count(), 1);
        mgr.remove_session(id);
        assert_eq!(mgr.active_count(), 0);
    }

    #[test]
    fn max_sessions_enforced() {
        let mut mgr = SessionManager::new(2);
        mgr.create_session("r1", None).unwrap();
        mgr.create_session("r2", None).unwrap();
        let err = mgr.create_session("r3", None);
        assert!(err.is_err());
        assert!(err.unwrap_err().contains("max sessions"));
    }

    #[test]
    fn sessions_in_room_tracking() {
        let mut mgr = SessionManager::new(10);
        let a1 = mgr.create_session("alpha", None).unwrap();
        let _a2 = mgr.create_session("alpha", None).unwrap();
        let _b1 = mgr.create_session("beta", None).unwrap();

        let alpha_ids = mgr.sessions_in_room("alpha");
        assert_eq!(alpha_ids.len(), 2);
        assert!(alpha_ids.contains(&a1));

        let beta_ids = mgr.sessions_in_room("beta");
        assert_eq!(beta_ids.len(), 1);

        let empty = mgr.sessions_in_room("gamma");
        assert!(empty.is_empty());
    }

    #[test]
    fn session_info_returns_correct_data() {
        let mut mgr = SessionManager::new(10);
        let id = mgr.create_session("room-x", Some("alice-fp".into())).unwrap();

        let info = mgr.session_info(id).expect("session should exist");
        assert_eq!(info.room_name, "room-x");
        assert_eq!(info.fingerprint.as_deref(), Some("alice-fp"));
        assert_eq!(info.state, SessionState::Active);

        // Non-existent session returns None
        assert!(mgr.session_info([0xFFu8; 16]).is_none());
    }

    #[test]
    fn max_sessions_shared_across_both_layers() {
        let mut mgr = SessionManager::new(2);
        // One pipeline session + one tracked session = 2 = at capacity
        mgr.create_pipeline_session([1u8; 16], PipelineConfig::default());
        mgr.create_session("room", None).unwrap();
        // Both layers should now reject
        assert!(mgr.create_session("room", None).is_err());
        assert!(mgr.create_pipeline_session([2u8; 16], PipelineConfig::default()).is_none());
    }
}
|
||||
|
||||
152
crates/wzp-relay/src/trunk.rs
Normal file
152
crates/wzp-relay/src/trunk.rs
Normal file
@@ -0,0 +1,152 @@
|
||||
//! Trunk batching — accumulates media packets from multiple sessions into
|
||||
//! [`TrunkFrame`]s that fit inside a single QUIC datagram.
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
use bytes::Bytes;
|
||||
use wzp_proto::packet::{TrunkEntry, TrunkFrame};
|
||||
|
||||
/// Batches individual session packets into [`TrunkFrame`]s.
///
/// A trunk frame is flushed when any of the following thresholds are hit:
/// - `max_entries` — maximum number of packets per trunk.
/// - `max_bytes` — maximum total wire size (should fit one UDP datagram).
///
/// The caller is responsible for timer-based flushing using [`flush_interval`]
/// and calling [`flush`] when the interval expires.
pub struct TrunkBatcher {
    // Frame currently being filled; swapped out wholesale on flush.
    pending: TrunkFrame,
    /// Current accumulated wire size of the pending frame.
    /// Starts at (and resets to) the 2-byte frame header size.
    pending_bytes: usize,
    /// Maximum packets per trunk (default 10).
    pub max_entries: usize,
    /// Maximum total wire bytes per trunk (default 1200, fits in one UDP datagram).
    pub max_bytes: usize,
    /// Maximum wait before flushing (default 5 ms). Used by the caller for timer scheduling.
    pub flush_interval: Duration,
}
|
||||
|
||||
impl TrunkBatcher {
|
||||
/// Header size: the 2-byte count prefix present in every TrunkFrame.
|
||||
const FRAME_HEADER: usize = 2;
|
||||
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
pending: TrunkFrame::new(),
|
||||
pending_bytes: Self::FRAME_HEADER,
|
||||
max_entries: 10,
|
||||
max_bytes: 1200,
|
||||
flush_interval: Duration::from_millis(5),
|
||||
}
|
||||
}
|
||||
|
||||
/// Push a session packet. Returns `Some(frame)` if the batch is now full
|
||||
/// and was flushed, `None` if more room remains.
|
||||
pub fn push(&mut self, session_id: [u8; 2], payload: Bytes) -> Option<TrunkFrame> {
|
||||
let entry_wire = TrunkEntry::OVERHEAD + payload.len();
|
||||
|
||||
// If adding this entry would exceed limits, flush first.
|
||||
if self.should_flush_with(entry_wire) && !self.pending.is_empty() {
|
||||
let frame = self.take_pending();
|
||||
// Then start a new batch with this entry.
|
||||
self.pending.push(session_id, payload);
|
||||
self.pending_bytes += entry_wire;
|
||||
return Some(frame);
|
||||
}
|
||||
|
||||
self.pending.push(session_id, payload);
|
||||
self.pending_bytes += entry_wire;
|
||||
|
||||
if self.should_flush() {
|
||||
Some(self.take_pending())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Flush the current pending frame if non-empty.
|
||||
pub fn flush(&mut self) -> Option<TrunkFrame> {
|
||||
if self.pending.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(self.take_pending())
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `true` if the pending batch has reached `max_entries` or `max_bytes`.
|
||||
pub fn should_flush(&self) -> bool {
|
||||
self.pending.len() >= self.max_entries || self.pending_bytes >= self.max_bytes
|
||||
}
|
||||
|
||||
// --- private helpers ---
|
||||
|
||||
/// Would adding `extra_bytes` exceed a threshold?
|
||||
fn should_flush_with(&self, extra_bytes: usize) -> bool {
|
||||
self.pending.len() + 1 > self.max_entries
|
||||
|| self.pending_bytes + extra_bytes > self.max_bytes
|
||||
}
|
||||
|
||||
/// Take the pending frame out, resetting state.
|
||||
fn take_pending(&mut self) -> TrunkFrame {
|
||||
let frame = std::mem::replace(&mut self.pending, TrunkFrame::new());
|
||||
self.pending_bytes = Self::FRAME_HEADER;
|
||||
frame
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for TrunkBatcher {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Filling to `max_entries` must emit a frame on the final push.
    #[test]
    fn trunk_batcher_fills_and_flushes() {
        let mut bat = TrunkBatcher::new();
        bat.max_entries = 3;
        bat.max_bytes = 4096; // large enough to not interfere

        // The first two entries fit without triggering a flush.
        let first = bat.push([0, 1], Bytes::from_static(b"aaa"));
        assert!(first.is_none());
        let second = bat.push([0, 2], Bytes::from_static(b"bbb"));
        assert!(second.is_none());

        // Entry number three reaches max_entries and flushes the batch.
        let frame = bat
            .push([0, 3], Bytes::from_static(b"ccc"))
            .expect("should flush at max_entries");
        assert_eq!(frame.len(), 3);
        assert_eq!(frame.packets[0].session_id, [0, 1]);
        assert_eq!(frame.packets[2].payload, Bytes::from_static(b"ccc"));

        // Nothing is left pending afterwards.
        assert!(bat.flush().is_none());
    }

    /// The byte budget flushes the existing batch before an oversized add.
    #[test]
    fn trunk_batcher_respects_max_bytes() {
        let mut bat = TrunkBatcher::new();
        bat.max_entries = 100; // won't be the trigger
        // Wire math: frame header (2) + one entry overhead (4) + 50 payload = 56.
        // Two entries: 2 + 2*(4+50) = 110; three entries: 2 + 3*54 = 164.
        bat.max_bytes = 120; // allow at most 2 entries of 50-byte payload

        let payload = Bytes::from(vec![0xAA; 50]);
        assert!(bat.push([0, 1], payload.clone()).is_none()); // 56 bytes pending
        assert!(bat.push([0, 2], payload.clone()).is_none()); // 110 < 120, fits

        // A third entry would need 164 > 120 bytes, so the batch flushes first.
        let frame = bat
            .push([0, 3], payload.clone())
            .expect("should flush on max_bytes");
        assert_eq!(frame.len(), 2);

        // The third entry became the start of the next batch.
        let remaining = bat.flush().unwrap();
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining.packets[0].session_id, [0, 3]);
    }
}
|
||||
243
crates/wzp-relay/src/ws.rs
Normal file
243
crates/wzp-relay/src/ws.rs
Normal file
@@ -0,0 +1,243 @@
|
||||
//! WebSocket transport for browser clients.
|
||||
//!
|
||||
//! Browsers connect via `GET /ws/{room}` → WebSocket upgrade.
|
||||
//! First message must be auth JSON (if auth is enabled).
|
||||
//! Subsequent messages are binary PCM frames forwarded to/from the room.
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::{
|
||||
extract::{
|
||||
ws::{Message, WebSocket},
|
||||
Path, State, WebSocketUpgrade,
|
||||
},
|
||||
response::IntoResponse,
|
||||
routing::get,
|
||||
Router,
|
||||
};
|
||||
use bytes::Bytes;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
use tower_http::services::ServeDir;
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
use crate::auth;
|
||||
use crate::metrics::RelayMetrics;
|
||||
use crate::presence::PresenceRegistry;
|
||||
use crate::room::RoomManager;
|
||||
use crate::session_mgr::SessionManager;
|
||||
|
||||
/// Shared state for WebSocket handlers.
///
/// Cloned per connection by axum's `State` extractor; the heavyweight fields
/// are `Arc`-backed, so cloning is cheap.
#[derive(Clone)]
pub struct WsState {
    // Room membership and fan-out management (shared with other transports).
    pub room_mgr: Arc<Mutex<RoomManager>>,
    // Session bookkeeping: capacity limits and per-session metadata.
    pub session_mgr: Arc<Mutex<SessionManager>>,
    // Auth service URL; `None` disables the auth handshake entirely.
    pub auth_url: Option<String>,
    // Relay observability counters/gauges.
    pub metrics: Arc<RelayMetrics>,
    // Fingerprint → presence tracking, updated on join/leave.
    pub presence: Arc<Mutex<PresenceRegistry>>,
}
|
||||
|
||||
/// Start the WebSocket + static file server.
|
||||
pub async fn run_ws_server(port: u16, state: WsState, static_dir: Option<String>) {
|
||||
let mut app = Router::new()
|
||||
.route("/ws/{room}", get(ws_upgrade_handler))
|
||||
.with_state(state);
|
||||
|
||||
if let Some(dir) = static_dir {
|
||||
info!(dir = %dir, "serving static files");
|
||||
app = app.fallback_service(ServeDir::new(dir));
|
||||
}
|
||||
|
||||
let addr: SocketAddr = ([0, 0, 0, 0], port).into();
|
||||
info!(%addr, "WebSocket server listening");
|
||||
|
||||
let listener = tokio::net::TcpListener::bind(addr)
|
||||
.await
|
||||
.expect("failed to bind WS listener");
|
||||
axum::serve(listener, app).await.expect("WS server failed");
|
||||
}
|
||||
|
||||
async fn ws_upgrade_handler(
|
||||
Path(room): Path<String>,
|
||||
State(state): State<WsState>,
|
||||
ws: WebSocketUpgrade,
|
||||
) -> impl IntoResponse {
|
||||
ws.on_upgrade(move |socket| handle_ws_connection(socket, room, state))
|
||||
}
|
||||
|
||||
/// Drive one browser WebSocket connection from auth through cleanup.
///
/// Lifecycle: optional auth handshake → session creation → room join →
/// presence registration → bidirectional relay loop → teardown in reverse
/// order. Early `return`s before step 3 leave nothing to clean up; after
/// step 3 the error paths undo exactly what was set up so far.
async fn handle_ws_connection(socket: WebSocket, room: String, state: WsState) {
    let (mut ws_tx, mut ws_rx) = socket.split();

    // 1. Auth: if auth_url is set, first message must be {"type":"auth","token":"..."}
    let fingerprint: Option<String> = if let Some(ref auth_url) = state.auth_url {
        match ws_rx.next().await {
            Some(Ok(Message::Text(text))) => {
                match serde_json::from_str::<serde_json::Value>(&text) {
                    Ok(parsed) if parsed["type"] == "auth" => {
                        if let Some(token) = parsed["token"].as_str() {
                            match auth::validate_token(auth_url, token).await {
                                Ok(client) => {
                                    state.metrics.auth_attempts.with_label_values(&["ok"]).inc();
                                    info!(fingerprint = %client.fingerprint, "WS authenticated");
                                    // Best-effort notification; a failed send will
                                    // surface as a closed socket in the loop below.
                                    let _ = ws_tx
                                        .send(Message::Text(r#"{"type":"auth_ok"}"#.into()))
                                        .await;
                                    Some(client.fingerprint)
                                }
                                Err(e) => {
                                    state
                                        .metrics
                                        .auth_attempts
                                        .with_label_values(&["fail"])
                                        .inc();
                                    // NOTE(review): `e` is interpolated into JSON
                                    // without escaping — assumes error text never
                                    // contains quotes/backslashes; confirm.
                                    let _ = ws_tx
                                        .send(Message::Text(
                                            format!(r#"{{"type":"auth_error","error":"{e}"}}"#)
                                                .into(),
                                        ))
                                        .await;
                                    warn!("WS auth failed: {e}");
                                    return;
                                }
                            }
                        } else {
                            warn!("WS auth: missing token field");
                            return;
                        }
                    }
                    _ => {
                        warn!("WS: expected auth message as first frame");
                        return;
                    }
                }
            }
            _ => {
                warn!("WS: connection closed before auth");
                return;
            }
        }
    } else {
        // Auth disabled: acknowledge immediately, no fingerprint known.
        let _ = ws_tx
            .send(Message::Text(r#"{"type":"auth_ok"}"#.into()))
            .await;
        None
    };

    // 2. Create mpsc channel for outbound frames (room → browser)
    let (tx, mut rx) = mpsc::channel::<Bytes>(64);

    // 3. Create session (enforces the session-capacity limit)
    let session_id = {
        let mut smgr = state.session_mgr.lock().await;
        match smgr.create_session(&room, fingerprint.clone()) {
            Ok(id) => id,
            Err(e) => {
                error!(room = %room, "WS session rejected: {e}");
                return;
            }
        }
    };
    state.metrics.active_sessions.inc();

    // 4. Join room with WS sender
    // Placeholder address: WS peers are identified by participant id, not addr.
    let addr: SocketAddr = ([0, 0, 0, 0], 0).into();
    let participant_id = {
        let mut mgr = state.room_mgr.lock().await;
        match mgr.join_ws(&room, addr, tx, fingerprint.as_deref()) {
            Ok(id) => {
                state.metrics.active_rooms.set(mgr.list().len() as i64);
                id
            }
            Err(e) => {
                // Roll back step 3 before bailing out.
                error!(room = %room, "WS room join denied: {e}");
                state.metrics.active_sessions.dec();
                let mut smgr = state.session_mgr.lock().await;
                smgr.remove_session(session_id);
                return;
            }
        }
    };

    // 5. Register presence (only when the peer has an authenticated fingerprint)
    if let Some(ref fp) = fingerprint {
        let mut reg = state.presence.lock().await;
        reg.register_local(fp, None, Some(room.clone()));
    }

    info!(room = %room, participant = participant_id, "WS client joined");

    // 6. Outbound task: mpsc rx → WS binary frames
    let send_task = tokio::spawn(async move {
        while let Some(data) = rx.recv().await {
            if ws_tx
                .send(Message::Binary(data.to_vec().into()))
                .await
                .is_err()
            {
                break;
            }
        }
    });

    // 7. Inbound: WS recv → fan-out to room
    loop {
        match ws_rx.next().await {
            Some(Ok(Message::Binary(data))) => {
                // Snapshot the recipient list under the lock, then send
                // without holding it so slow peers don't block the room.
                let others = {
                    let mgr = state.room_mgr.lock().await;
                    mgr.others(&room, participant_id)
                };
                for other in &others {
                    let _ = other.send_raw(&data).await;
                }
                state
                    .metrics
                    .packets_forwarded
                    .inc_by(others.len() as u64);
                state
                    .metrics
                    .bytes_forwarded
                    .inc_by(data.len() as u64 * others.len() as u64);
            }
            Some(Ok(Message::Close(_))) | None => break,
            // Text/Ping/Pong and transport errors are ignored here.
            _ => continue,
        }
    }

    // 8. Cleanup — reverse order of setup: outbound task, presence, room,
    // metrics, then the session itself.
    send_task.abort();
    info!(room = %room, participant = participant_id, "WS client disconnected");

    if let Some(ref fp) = fingerprint {
        let mut reg = state.presence.lock().await;
        reg.unregister_local(fp);
    }

    {
        let mut mgr = state.room_mgr.lock().await;
        mgr.leave(&room, participant_id);
        state.metrics.active_rooms.set(mgr.list().len() as i64);
    }

    // Lowercase-hex session id used as the per-session metrics label.
    let session_id_str: String = session_id.iter().map(|b| format!("{b:02x}")).collect();
    state.metrics.remove_session_metrics(&session_id_str);
    state.metrics.active_sessions.dec();

    {
        let mut smgr = state.session_mgr.lock().await;
        smgr.remove_session(session_id);
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// `WsState` must implement `Clone` for axum's `State` extractor.
    #[test]
    fn ws_state_is_clone() {
        fn require_clone<T: Clone>() {}
        require_clone::<WsState>();
    }
}
|
||||
295
crates/wzp-relay/tests/handshake_integration.rs
Normal file
295
crates/wzp-relay/tests/handshake_integration.rs
Normal file
@@ -0,0 +1,295 @@
|
||||
//! WZP-S-5 integration tests: crypto handshake wired into live QUIC path.
|
||||
//!
|
||||
//! Verifies that `perform_handshake` (client/caller) and `accept_handshake`
|
||||
//! (relay/callee) complete successfully over a real in-process QUIC connection
|
||||
//! and produce usable `CryptoSession` values.
|
||||
|
||||
use std::net::{Ipv4Addr, SocketAddr};
|
||||
use std::sync::Arc;
|
||||
|
||||
use wzp_client::perform_handshake;
|
||||
use wzp_crypto::{KeyExchange, WarzoneKeyExchange};
|
||||
use wzp_proto::{MediaTransport, SignalMessage};
|
||||
use wzp_relay::handshake::accept_handshake;
|
||||
use wzp_transport::{client_config, create_endpoint, server_config, QuinnTransport};
|
||||
|
||||
/// Establish a QUIC connection and wrap both sides in `QuinnTransport`.
///
/// Returns (client_transport, server_transport, _endpoints) where the endpoint
/// tuple must be kept alive for the duration of the test to avoid premature
/// connection teardown.
async fn connected_pair() -> (Arc<QuinnTransport>, Arc<QuinnTransport>, (quinn::Endpoint, quinn::Endpoint)) {
    // Installing the provider can fail if another test already did it;
    // that's fine, so the result is ignored.
    let _ = rustls::crypto::ring::default_provider().install_default();

    // Port 0 → OS-assigned port; the real address is read back below.
    let (sc, _cert_der) = server_config();
    let server_addr: SocketAddr = (Ipv4Addr::LOCALHOST, 0).into();
    let server_ep = create_endpoint(server_addr, Some(sc)).expect("server endpoint");
    let server_listen = server_ep.local_addr().expect("server local addr");

    let client_addr: SocketAddr = (Ipv4Addr::LOCALHOST, 0).into();
    let client_ep = create_endpoint(client_addr, None).expect("client endpoint");

    // Spawn the accept BEFORE connecting so the server is already waiting
    // when the client dials in.
    let server_ep_clone = server_ep.clone();
    let accept_fut = tokio::spawn(async move {
        let conn = wzp_transport::accept(&server_ep_clone).await.expect("accept");
        Arc::new(QuinnTransport::new(conn))
    });

    let client_conn =
        wzp_transport::connect(&client_ep, server_listen, "localhost", client_config())
            .await
            .expect("connect");
    let client_transport = Arc::new(QuinnTransport::new(client_conn));

    let server_transport = accept_fut.await.expect("join accept task");

    (client_transport, server_transport, (server_ep, client_ep))
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Test 1: handshake_succeeds
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
// Happy path: both sides complete the handshake and derive compatible
// crypto sessions (proved by a round-trip encrypt/decrypt).
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn handshake_succeeds() {
    let (client_transport, server_transport, _endpoints) = connected_pair().await;

    // Arbitrary fixed seeds — determinism keeps the test reproducible.
    let caller_seed: [u8; 32] = [0xAA; 32];
    let callee_seed: [u8; 32] = [0xBB; 32];

    // Clone Arc so the server transport stays alive in the main task too.
    let server_t = Arc::clone(&server_transport);
    let callee_handle = tokio::spawn(async move {
        accept_handshake(server_t.as_ref(), &callee_seed).await
    });

    let caller_session = perform_handshake(client_transport.as_ref(), &caller_seed)
        .await
        .expect("perform_handshake should succeed");

    let (callee_session, chosen_profile) = callee_handle
        .await
        .expect("join callee task")
        .expect("accept_handshake should succeed");

    // Both sides should have derived a working CryptoSession.
    // Verify by encrypting on one side and decrypting on the other.
    let header = b"test-header";
    let plaintext = b"hello warzone";

    let mut ciphertext = Vec::new();
    // Rebind mutably: encrypt/decrypt take &mut self.
    let mut caller_session = caller_session;
    let mut callee_session = callee_session;

    caller_session
        .encrypt(header, plaintext, &mut ciphertext)
        .expect("encrypt");

    let mut decrypted = Vec::new();
    callee_session
        .decrypt(header, &ciphertext, &mut decrypted)
        .expect("decrypt");

    assert_eq!(&decrypted, plaintext);
    assert_eq!(chosen_profile, wzp_proto::QualityProfile::GOOD);

    // Keep transports alive until test completes.
    drop(server_transport);
    drop(client_transport);
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Test 2: handshake_verifies_identity
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
// Peers with distinct identities still complete the handshake and share a
// working session (identity difference is verified up front as a sanity check).
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn handshake_verifies_identity() {
    let (client_transport, server_transport, _endpoints) = connected_pair().await;

    // Two completely different seeds => different identity keys.
    let caller_seed: [u8; 32] = [0x11; 32];
    let callee_seed: [u8; 32] = [0x22; 32];

    // Confirm the seeds produce different identity public keys.
    let caller_kx = WarzoneKeyExchange::from_identity_seed(&caller_seed);
    let callee_kx = WarzoneKeyExchange::from_identity_seed(&callee_seed);
    assert_ne!(
        caller_kx.identity_public_key(),
        callee_kx.identity_public_key(),
        "different seeds must produce different identity keys"
    );

    // Callee runs concurrently; spawn it before the caller starts.
    let server_t = Arc::clone(&server_transport);
    let callee_handle = tokio::spawn(async move {
        accept_handshake(server_t.as_ref(), &callee_seed).await
    });

    let caller_session = perform_handshake(client_transport.as_ref(), &caller_seed)
        .await
        .expect("handshake must succeed even with different identities");

    let (callee_session, _profile) = callee_handle
        .await
        .expect("join")
        .expect("accept_handshake must succeed");

    // Cross-encrypt/decrypt to prove the shared session works.
    let header = b"id-test";
    let plaintext = b"identity verified";

    let mut ct = Vec::new();
    // Rebind mutably: encrypt/decrypt take &mut self.
    let mut caller_session = caller_session;
    let mut callee_session = callee_session;

    caller_session
        .encrypt(header, plaintext, &mut ct)
        .expect("encrypt");

    let mut pt = Vec::new();
    callee_session
        .decrypt(header, &ct, &mut pt)
        .expect("decrypt");

    assert_eq!(&pt, plaintext);

    // Keep transports alive until test completes.
    drop(server_transport);
    drop(client_transport);
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Test 3: auth_then_handshake
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
// An AuthToken preamble on the signalling channel must not interfere with
// the handshake that follows it.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn auth_then_handshake() {
    let (client_transport, server_transport, _endpoints) = connected_pair().await;

    let caller_seed: [u8; 32] = [0xCC; 32];
    let callee_seed: [u8; 32] = [0xDD; 32];

    // The callee side: first consume the AuthToken, then run accept_handshake.
    let server_t = Arc::clone(&server_transport);
    let callee_handle = tokio::spawn(async move {
        // 1. Receive AuthToken
        let auth_msg = server_t
            .recv_signal()
            .await
            .expect("recv_signal should succeed")
            .expect("should receive a message");

        let token = match auth_msg {
            SignalMessage::AuthToken { token } => token,
            // Discriminant only: the payload may not be Debug-printable.
            other => panic!("expected AuthToken, got {:?}", std::mem::discriminant(&other)),
        };

        // 2. Run the cryptographic handshake
        let (session, profile) = accept_handshake(server_t.as_ref(), &callee_seed)
            .await
            .expect("accept_handshake after auth");

        (token, session, profile)
    });

    // Caller side: send AuthToken first, then perform_handshake.
    let auth = SignalMessage::AuthToken {
        token: "bearer-test-token-12345".to_string(),
    };
    client_transport
        .send_signal(&auth)
        .await
        .expect("send AuthToken");

    let caller_session = perform_handshake(client_transport.as_ref(), &caller_seed)
        .await
        .expect("perform_handshake after auth");

    let (received_token, callee_session, _profile) = callee_handle
        .await
        .expect("join callee task");

    // Verify the auth token was received correctly.
    assert_eq!(received_token, "bearer-test-token-12345");

    // Verify the crypto session works after the auth preamble.
    let header = b"auth-hdr";
    let plaintext = b"post-auth payload";

    let mut ct = Vec::new();
    // Rebind mutably: encrypt/decrypt take &mut self.
    let mut caller_session = caller_session;
    let mut callee_session = callee_session;

    caller_session
        .encrypt(header, plaintext, &mut ct)
        .expect("encrypt");

    let mut pt = Vec::new();
    callee_session
        .decrypt(header, &ct, &mut pt)
        .expect("decrypt");

    assert_eq!(&pt, plaintext);

    // Keep transports alive until test completes.
    drop(server_transport);
    drop(client_transport);
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Test 4: handshake_rejects_bad_signature
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
// Negative path: a CallOffer whose signature has been tampered with must be
// rejected by the callee with a signature-verification error.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn handshake_rejects_bad_signature() {
    let (client_transport, server_transport, _endpoints) = connected_pair().await;

    let caller_seed: [u8; 32] = [0xEE; 32];
    let callee_seed: [u8; 32] = [0xFF; 32];

    // Spawn callee -- it should reject the tampered CallOffer.
    let server_t = Arc::clone(&server_transport);
    let callee_handle = tokio::spawn(async move {
        accept_handshake(server_t.as_ref(), &callee_seed).await
    });

    // Manually build a CallOffer with a corrupted signature.
    let mut kx = WarzoneKeyExchange::from_identity_seed(&caller_seed);
    let identity_pub = kx.identity_public_key();
    let ephemeral_pub = kx.generate_ephemeral();

    // Sign the same transcript a legitimate caller would (ephemeral key +
    // the "call-offer" domain-separation tag), then corrupt the signature.
    let mut sign_data = Vec::with_capacity(32 + 10);
    sign_data.extend_from_slice(&ephemeral_pub);
    sign_data.extend_from_slice(b"call-offer");
    let mut signature = kx.sign(&sign_data);

    // Tamper: flip bits in the signature.
    for byte in signature.iter_mut().take(8) {
        *byte ^= 0xFF;
    }

    let bad_offer = SignalMessage::CallOffer {
        identity_pub,
        ephemeral_pub,
        signature,
        supported_profiles: vec![wzp_proto::QualityProfile::GOOD],
    };

    client_transport
        .send_signal(&bad_offer)
        .await
        .expect("send tampered CallOffer");

    // The callee should return an error about signature verification.
    let result = callee_handle.await.expect("join callee task");
    match result {
        Ok(_) => panic!("accept_handshake must reject a bad signature"),
        Err(e) => {
            let err_msg = e.to_string();
            assert!(
                err_msg.contains("signature verification failed"),
                "error should mention signature verification, got: {err_msg}"
            );
        }
    }

    // Keep transports alive until test completes.
    drop(server_transport);
    drop(client_transport);
}
|
||||
@@ -136,6 +136,11 @@ impl PathMonitor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Get raw packet counts for debugging.
///
/// Returns `(total_sent, total_received)` as accumulated by this monitor.
pub fn counts(&self) -> (u64, u64) {
    (self.total_sent, self.total_received)
}
|
||||
|
||||
/// Estimate bandwidth in kbps from bytes received over time.
|
||||
fn estimate_bandwidth_kbps(&self) -> u32 {
|
||||
if let (Some(first), Some(last)) = (self.first_recv_time_ms, self.last_recv_time_ms) {
|
||||
@@ -149,6 +154,27 @@ impl PathMonitor {
|
||||
}
|
||||
0
|
||||
}
|
||||
|
||||
/// Detect whether a network handoff likely occurred.
|
||||
///
|
||||
/// Returns `true` if the most recent RTT jitter measurement exceeds 3x
|
||||
/// the EWMA-smoothed jitter average, which is characteristic of a cellular
|
||||
/// network handoff (tower switch, WiFi-to-cellular transition, etc.).
|
||||
pub fn detect_handoff(&self) -> bool {
|
||||
// We need at least two RTT observations to have a meaningful jitter value,
|
||||
// and the EWMA must be non-zero to avoid division/multiplication by zero.
|
||||
if self.jitter_ewma <= 0.0 {
|
||||
return false;
|
||||
}
|
||||
|
||||
if let (Some(last_rtt), Some(_)) = (self.last_rtt_ms, Some(self.rtt_ewma)) {
|
||||
// Compute the most recent instantaneous jitter (RTT deviation from EWMA)
|
||||
let instant_jitter = (last_rtt - self.rtt_ewma).abs();
|
||||
instant_jitter > self.jitter_ewma * 3.0
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PathMonitor {
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
use async_trait::async_trait;
|
||||
use std::sync::Mutex;
|
||||
|
||||
use wzp_proto::packet::TrunkFrame;
|
||||
use wzp_proto::{MediaPacket, MediaTransport, PathQuality, SignalMessage, TransportError};
|
||||
|
||||
use crate::datagram;
|
||||
@@ -32,10 +33,67 @@ impl QuinnTransport {
|
||||
&self.connection
|
||||
}
|
||||
|
||||
/// Close the QUIC connection immediately (synchronous, no async needed).
/// The relay will detect the close and remove this participant from the room.
///
/// Sends application error code 0 with the reason bytes `b"hangup"`.
pub fn close_now(&self) {
    self.connection.close(quinn::VarInt::from_u32(0), b"hangup");
}
|
||||
|
||||
/// Feed an external RTT observation (e.g. from QUIC path stats) into the path monitor.
///
/// Briefly takes the path-monitor mutex; panics only if that lock is poisoned.
pub fn feed_rtt(&self, rtt_ms: u32) {
    self.path_monitor.lock().unwrap().observe_rtt(rtt_ms);
}
|
||||
|
||||
/// Get raw packet counts from path monitor (sent, received).
|
||||
pub fn monitor_counts(&self) -> (u64, u64) {
|
||||
self.path_monitor.lock().unwrap().counts()
|
||||
}
|
||||
|
||||
/// Get the maximum datagram payload size, if datagrams are supported.
|
||||
pub fn max_datagram_size(&self) -> Option<usize> {
|
||||
datagram::max_datagram_payload(&self.connection)
|
||||
}
|
||||
|
||||
/// Send an encoded [`TrunkFrame`] as a single QUIC datagram.
|
||||
pub fn send_trunk(&self, frame: &TrunkFrame) -> Result<(), TransportError> {
|
||||
let data = frame.encode();
|
||||
|
||||
if let Some(max_size) = self.connection.max_datagram_size() {
|
||||
if data.len() > max_size {
|
||||
return Err(TransportError::DatagramTooLarge {
|
||||
size: data.len(),
|
||||
max: max_size,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
self.connection.send_datagram(data).map_err(|e| {
|
||||
TransportError::Internal(format!("send trunk datagram error: {e}"))
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Receive a single QUIC datagram and decode it as a [`TrunkFrame`].
|
||||
///
|
||||
/// Returns `Ok(None)` on connection close, `Ok(Some(frame))` on success,
|
||||
/// or an error on malformed data / transport failure.
|
||||
pub async fn recv_trunk(&self) -> Result<Option<TrunkFrame>, TransportError> {
|
||||
let data = match self.connection.read_datagram().await {
|
||||
Ok(data) => data,
|
||||
Err(quinn::ConnectionError::ApplicationClosed(_)) => return Ok(None),
|
||||
Err(quinn::ConnectionError::LocallyClosed) => return Ok(None),
|
||||
Err(e) => {
|
||||
return Err(TransportError::Internal(format!(
|
||||
"recv trunk datagram error: {e}"
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
TrunkFrame::decode(&data)
|
||||
.map(Some)
|
||||
.ok_or_else(|| TransportError::Internal("malformed trunk frame".into()))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
|
||||
@@ -18,6 +18,9 @@ tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
bytes = { workspace = true }
|
||||
anyhow = "1"
|
||||
wzp-relay = { path = "../wzp-relay" }
|
||||
serde_json = "1"
|
||||
rustls-pemfile = "2"
|
||||
axum = { version = "0.8", features = ["ws"] }
|
||||
tower-http = { version = "0.6", features = ["fs"] }
|
||||
futures = "0.3"
|
||||
@@ -26,6 +29,7 @@ rcgen = "0.13"
|
||||
rustls = { version = "0.23", default-features = false, features = ["ring", "std"] }
|
||||
rustls-pki-types = "1"
|
||||
tokio-rustls = "0.26"
|
||||
prometheus = "0.13"
|
||||
|
||||
[[bin]]
|
||||
name = "wzp-web"
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user