Compare commits
457 Commits
6228ab32c1
...
experiment
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
01f55caa96 | ||
|
|
0f93a2b745 | ||
|
|
2b93bd4b45 | ||
|
|
bc021517c0 | ||
|
|
739bdaf3ab | ||
|
|
bc1668ed96 | ||
|
|
77b036439b | ||
|
|
0ebc73ab13 | ||
|
|
394987a349 | ||
|
|
2aa6582585 | ||
|
|
ca987d547c | ||
|
|
5a13f12334 | ||
|
|
b0a3b1f18e | ||
|
|
32c07d1b61 | ||
|
|
5d05b021aa | ||
|
|
4ac62d99e0 | ||
|
|
4ebb2dac2d | ||
|
|
52a6f5e048 | ||
|
|
15af58a95d | ||
|
|
ed8a7ae5aa | ||
|
|
12b0d9738f | ||
|
|
f78794f4b6 | ||
|
|
f3e3ee5ed0 | ||
|
|
f28f39d814 | ||
|
|
1e729e4b1d | ||
|
|
086d0a4845 | ||
|
|
9334aa5ccd | ||
|
|
553c8a4ce1 | ||
|
|
8d8dddbd35 | ||
|
|
f16d650721 | ||
|
|
31f2fdef1e | ||
|
|
fc9908cd4c | ||
|
|
517d0ebfe0 | ||
|
|
cf4940417e | ||
|
|
ffded2a913 | ||
|
|
283edd38eb | ||
|
|
fdfaed5390 | ||
|
|
dbbab0decf | ||
|
|
5fda5ecc52 | ||
|
|
2bbb664df4 | ||
|
|
2f1a9f74d5 | ||
|
|
b197651557 | ||
|
|
9c41d1acdd | ||
|
|
e34c40dc0f | ||
|
|
c48cb6fbcb | ||
|
|
2e0bdc5904 | ||
|
|
276ecc660e | ||
|
|
001d94f9ae | ||
|
|
36b0421d68 | ||
|
|
828fbea2ea | ||
|
|
cc5aef2534 | ||
|
|
397f9d2141 | ||
|
|
410c2a4335 | ||
|
|
81042ac190 | ||
|
|
e177e63843 | ||
|
|
1f7d130de9 | ||
|
|
3356ba94c6 | ||
|
|
bb153a331d | ||
|
|
490d2d31c6 | ||
|
|
db69f7e9d1 | ||
|
|
f1b86e0fed | ||
|
|
8454835c18 | ||
|
|
017c371611 | ||
|
|
3220bd6151 | ||
|
|
e73f8a7150 | ||
|
|
1b4f7b0772 | ||
|
|
f3398adb95 | ||
|
|
54c1a35186 | ||
|
|
3de56cf1f9 | ||
|
|
fe1f9484bd | ||
|
|
0ef1f574ff | ||
|
|
b1c5837495 | ||
|
|
6f81487778 | ||
|
|
5cdb50160a | ||
|
|
30d26fc7f6 | ||
|
|
c93d302656 | ||
|
|
9680b6ff34 | ||
|
|
6b15b8f97c | ||
|
|
6385b93391 | ||
|
|
6eb94f079d | ||
|
|
5580b794a4 | ||
|
|
7c9ede9227 | ||
|
|
e8866c6632 | ||
|
|
8c6e88ea68 | ||
|
|
ffb92237be | ||
|
|
6af0539a72 | ||
|
|
217567383d | ||
|
|
98ed981805 | ||
|
|
01a3133544 | ||
|
|
25471c694f | ||
|
|
a058a83c91 | ||
|
|
9b8013ba7f | ||
|
|
defd8eab07 | ||
|
|
cc23e829b2 | ||
|
|
18c204c1ff | ||
|
|
1120c7b579 | ||
|
|
7e7391fdbb | ||
|
|
aa0362f318 | ||
|
|
bb23976076 | ||
|
|
18e5e75f33 | ||
|
|
488efcb614 | ||
|
|
8c360186df | ||
|
|
f06f9073ae | ||
|
|
6c49d7436f | ||
|
|
1de280fe04 | ||
|
|
bc6d327ebb | ||
|
|
c478224d67 | ||
|
|
16dcc75514 | ||
|
|
db5751985e | ||
|
|
c0dd6c06ff | ||
|
|
6805caae0e | ||
|
|
5a03da72d3 | ||
|
|
e3e63a40a0 | ||
|
|
7b4bce69d5 | ||
|
|
ec1bdf3cd5 | ||
|
|
ee14862376 | ||
|
|
f83361895e | ||
|
|
0857d190ed | ||
|
|
5d431c0721 | ||
|
|
8fcf1be341 | ||
|
|
9377a9009c | ||
|
|
4471797edf | ||
|
|
425c67a08a | ||
|
|
88ca3e099a | ||
|
|
1e82811cc1 | ||
|
|
81b5522942 | ||
|
|
d539a6dfb9 | ||
|
|
ba12aae439 | ||
|
|
fdb78e08bd | ||
|
|
3a51db998a | ||
|
|
a52b011fb5 | ||
|
|
2514151a89 | ||
|
|
f265fd772d | ||
|
|
9ae9441de4 | ||
|
|
d9e7e72978 | ||
|
|
8ff0c548a7 | ||
|
|
f17420aa98 | ||
|
|
d424515542 | ||
|
|
ea5fc17c34 | ||
|
|
1a7dd935ee | ||
|
|
a7c2261b70 | ||
|
|
eca0bb7531 | ||
|
|
d249b32ee5 | ||
|
|
22045bc5e6 | ||
|
|
766c9df442 | ||
|
|
6f43415285 | ||
|
|
24cc74d93c | ||
|
|
300ea66d13 | ||
|
|
114d69e488 | ||
|
|
15c237ceea | ||
|
|
a37c8b30fe | ||
|
|
137fe5f084 | ||
|
|
5dfb5b3581 | ||
|
|
fd0ccf8e99 | ||
|
|
2d4948a7b3 | ||
|
|
19703ff66c | ||
|
|
7e8dc400dc | ||
|
|
a798634b3d | ||
|
|
d89376016a | ||
|
|
678695776e | ||
|
|
4c1ad841e1 | ||
|
|
29cd23fe39 | ||
|
|
4d66d3769d | ||
|
|
002df15c5e | ||
|
|
1eb82d77b8 | ||
|
|
f843a934fe | ||
|
|
b79073c649 | ||
|
|
82b439595c | ||
|
|
1904b19d05 | ||
|
|
40955bd11c | ||
|
|
7554959baa | ||
|
|
0b62d3e22f | ||
|
|
4cfcd5117f | ||
|
|
bd6733b2e5 | ||
|
|
7d1b8f1fdc | ||
|
|
c2d298beb5 | ||
|
|
aee41a638d | ||
|
|
9fb92967eb | ||
|
|
9f2ff6a6ec | ||
|
|
134ee3a77f | ||
|
|
e61397ca85 | ||
|
|
f5542ef822 | ||
|
|
de007ec2fd | ||
|
|
0a973b234b | ||
|
|
026940d492 | ||
|
|
0ccf4ed6b5 | ||
|
|
847699bf66 | ||
|
|
6cd61fc63b | ||
|
|
50e6a50de4 | ||
|
|
0cb8d34b21 | ||
|
|
2427630472 | ||
|
|
16793be36f | ||
|
|
fa038df057 | ||
|
|
8990514417 | ||
|
|
1618ff6c9d | ||
|
|
05ec926317 | ||
|
|
b7a48bf13b | ||
|
|
e75b045470 | ||
|
|
20375eceb9 | ||
|
|
00deb97a5d | ||
|
|
da08723fe7 | ||
|
|
8cdf8d486a | ||
|
|
59ce52f8e8 | ||
|
|
39277bf3a0 | ||
|
|
8d903f16c6 | ||
|
|
921856eba9 | ||
|
|
7e7968b2f9 | ||
|
|
578ff8cff4 | ||
|
|
16890576fb | ||
|
|
daf7bcd9ba | ||
|
|
df1a45a5f5 | ||
|
|
dd0c714caa | ||
|
|
a7b2f850f1 | ||
|
|
575a39d07a | ||
|
|
d63d50cdc0 | ||
|
|
d269600aa7 | ||
|
|
dfbe21fe6e | ||
|
|
b83c31b5d1 | ||
|
|
1f607281fd | ||
|
|
7515417202 | ||
|
|
505a834c5b | ||
|
|
27bc264738 | ||
|
|
c27b39d553 | ||
|
|
6db5c25b54 | ||
|
|
54cbebd34e | ||
|
|
86526a7ad4 | ||
|
|
56e3417063 | ||
|
|
8ceb6f45d5 | ||
|
|
07873ea598 | ||
|
|
cc00f7cace | ||
|
|
eb9de988d6 | ||
|
|
4ba77c8c0e | ||
|
|
7b8a2d0fba | ||
|
|
5cd7a20152 | ||
|
|
a5c00fe5cb | ||
|
|
ec41f179cd | ||
|
|
4e9244eb00 | ||
|
|
03a80a3196 | ||
|
|
7fecf285ea | ||
|
|
0683dde5d3 | ||
|
|
53f57eea07 | ||
|
|
ff3f7e8e4f | ||
|
|
48d2bd4f65 | ||
|
|
234a798df2 | ||
|
|
fa042b130c | ||
|
|
990b6f1ee0 | ||
|
|
7949266e11 | ||
|
|
d774f5f8c5 | ||
|
|
2fd94651e4 | ||
|
|
da09fdb6e9 | ||
|
|
510eae2089 | ||
|
|
76a4c53e21 | ||
|
|
4c6aac654a | ||
|
|
4f2ad65418 | ||
|
|
0178cbd91d | ||
|
|
9e37201198 | ||
|
|
da106bd939 | ||
|
|
8c36fb5651 | ||
|
|
cfa9ff67cf | ||
|
|
96be740fd9 | ||
|
|
8c4d640f89 | ||
|
|
49f101d785 | ||
|
|
d7b37a5749 | ||
|
|
b35a6b7d92 | ||
|
|
0105b0fbf3 | ||
|
|
5beea7de40 | ||
|
|
fdbe502524 | ||
|
|
c769a476a2 | ||
|
|
7cc53aedc7 | ||
|
|
711137da96 | ||
|
|
6071eb1b02 | ||
|
|
c9cd043657 | ||
|
|
6dd62c94c9 | ||
|
|
4c998312aa | ||
|
|
22701830c2 | ||
|
|
47a037368c | ||
|
|
191e8761d5 | ||
|
|
0d74366592 | ||
|
|
0224ce654c | ||
|
|
aa240c6d83 | ||
|
|
d216dcc7a3 | ||
|
|
4250f1b44a | ||
|
|
a852cad15e | ||
|
|
19fd3dd9cc | ||
|
|
c69195fe06 | ||
|
|
ae4f366b05 | ||
|
|
f96d7ce3e1 | ||
|
|
530993854f | ||
|
|
e2e023d2bc | ||
|
|
5df9d418c9 | ||
|
|
2718402e96 | ||
|
|
1a8288c95f | ||
|
|
f015be63ec | ||
|
|
79e876126c | ||
|
|
903a07c1d4 | ||
|
|
af20fa418a | ||
|
|
b314138caf | ||
|
|
35642d1c54 | ||
|
|
6b8107504e | ||
|
|
7639aaf08d | ||
|
|
69ee3115b6 | ||
|
|
e6f77a78a7 | ||
|
|
04a985912a | ||
|
|
2288c1ae07 | ||
|
|
0d3f0d4dcb | ||
|
|
c184d5e1f3 | ||
|
|
5d8e743cbf | ||
|
|
6694aebfd9 | ||
|
|
d27e85ecf2 | ||
|
|
39ac181d63 | ||
|
|
3351cb6473 | ||
|
|
54a4d91f3e | ||
|
|
3b962bd4cb | ||
|
|
1118eac752 | ||
|
|
f935bd69cd | ||
|
|
1c684f6b47 | ||
|
|
c92db7e9b7 | ||
|
|
c3bd657224 | ||
|
|
8b79cdc6fc | ||
|
|
2eab56beec | ||
|
|
7dadc1ddd6 | ||
|
|
be0441295a | ||
|
|
b9f4e7f102 | ||
|
|
28f4a0fb6f | ||
|
|
3d76acf528 | ||
|
|
f4b5996bdf | ||
|
|
fc721c4217 | ||
|
|
5c24adf1c1 | ||
|
|
8dbda3e052 | ||
|
|
c8a3aaacb6 | ||
|
|
395a0c557e | ||
|
|
54cb6c3b71 | ||
|
|
da593f9510 | ||
|
|
a3ebf5616f | ||
|
|
ff6d0444c0 | ||
|
|
8080713098 | ||
|
|
e813362395 | ||
|
|
d52b8befd6 | ||
|
|
0abecf7fd8 | ||
|
|
f4cc3b1a6b | ||
|
|
af4c89f5f0 | ||
|
|
406461d460 | ||
|
|
7064f484af | ||
|
|
1d2222a25a | ||
|
|
270e139f20 | ||
|
|
d9b2e0fd53 | ||
|
|
898c1ea32b | ||
|
|
b00db5dfdc | ||
|
|
bc8bb3d790 | ||
|
|
ea51d068e6 | ||
|
|
7271942c6a | ||
|
|
da84ed332c | ||
|
|
e50925e05a | ||
|
|
6be36e43c2 | ||
|
|
2f2720802d | ||
|
|
087bfd2335 | ||
|
|
0a05e62c7f | ||
|
|
b97f32ce46 | ||
|
|
d66d583583 | ||
|
|
d06cf66538 | ||
|
|
7bddc6b5a6 | ||
|
|
c8bcc5c974 | ||
|
|
760126b6ab | ||
|
|
53f8bf8fff | ||
|
|
3b85604b41 | ||
|
|
a8c2011445 | ||
|
|
ded49bdb7b | ||
|
|
b3cdad0c75 | ||
|
|
fa3c7f1cef | ||
|
|
369347ce54 | ||
|
|
44f04b55e8 | ||
|
|
85c2146760 | ||
|
|
96ccb4f333 | ||
|
|
95a905e1b5 | ||
|
|
f7ccb67b02 | ||
|
|
4df08eadbd | ||
|
|
6d776097c8 | ||
|
|
68b56d9172 | ||
|
|
7973c8c6a3 | ||
|
|
3e9539e5da | ||
|
|
a1ccb3f390 | ||
|
|
7751439e2b | ||
|
|
20bc290c18 | ||
|
|
a8dc350a65 | ||
|
|
00fa109f07 | ||
|
|
1e40dec468 | ||
|
|
aecef0905d | ||
|
|
18f7faa279 | ||
|
|
eeb85aeac2 | ||
|
|
00b405aa87 | ||
|
|
d09e21965e | ||
|
|
97bcc79f9b | ||
|
|
9f7962a6cd | ||
|
|
264ef9c4d4 | ||
|
|
a9adb5cfd7 | ||
|
|
a39b074d6e | ||
|
|
9cab6e2347 | ||
|
|
8c9befb15d | ||
|
|
d36feb2b59 | ||
|
|
5e93cb74f2 | ||
|
|
b56b4a759c | ||
|
|
6f99841cc7 | ||
|
|
3f869a4cd7 | ||
|
|
baf82d935b | ||
|
|
3b0811ce2e | ||
|
|
9eed94850d | ||
|
|
5e9718aeb2 | ||
|
|
3093933602 | ||
|
|
4c6c909732 | ||
|
|
33fab9a049 | ||
|
|
31d2306915 | ||
|
|
2263e898e5 | ||
|
|
4af7c5f94c | ||
|
|
6597b5bd86 | ||
|
|
ae9d8526dd | ||
|
|
9ab57ba037 | ||
|
|
7806d4ec04 | ||
|
|
d31b81a21d | ||
|
|
4d54b6f9e4 | ||
|
|
c268ce419a | ||
|
|
61b6e67610 | ||
|
|
dddf5d2e2d | ||
|
|
ed272d29f8 | ||
|
|
2b3bdae440 | ||
|
|
21f5b24cbf | ||
|
|
9b733010ab | ||
|
|
80d5bd7628 | ||
|
|
4a195a923a | ||
|
|
f726f8cfa4 | ||
|
|
20922455bd | ||
|
|
e468454464 | ||
|
|
d1c96cd71f | ||
|
|
1b00b5e2a4 | ||
|
|
e6564bab57 | ||
|
|
cfb48df1ef | ||
|
|
aebf9156c0 | ||
|
|
9bbaec6b35 | ||
|
|
ba29d8354f | ||
|
|
0908507a7a | ||
|
|
860c90394d | ||
|
|
dc66b60d18 | ||
|
|
a9c4260b4e | ||
|
|
7eb136fcb3 | ||
|
|
550a124972 | ||
|
|
0835c36d0f | ||
|
|
6eb10327c1 | ||
|
|
50339542fa | ||
|
|
c67fa18f14 | ||
|
|
6c5c4cb671 | ||
|
|
8816f13df8 | ||
|
|
3804b0bf46 | ||
|
|
234f3c4bfe | ||
|
|
e97f278390 | ||
|
|
f6a77da948 | ||
|
|
82015a78af | ||
|
|
cb13af8abd | ||
|
|
0b8276b9c7 |
72
.agents/skills/caveman/SKILL.md
Normal file
72
.agents/skills/caveman/SKILL.md
Normal file
@@ -0,0 +1,72 @@
|
||||
---
|
||||
name: caveman
|
||||
description: >
|
||||
Ultra-compressed communication mode. Slash token usage ~75% by speaking like caveman
|
||||
while keeping full technical accuracy. Use when user says "caveman mode", "talk like caveman",
|
||||
"use caveman", "less tokens", "be brief", or invokes /caveman. Also auto-triggers
|
||||
when token efficiency is requested.
|
||||
---
|
||||
|
||||
# Caveman Mode
|
||||
|
||||
## Core Rule
|
||||
|
||||
Respond like smart caveman. Cut articles, filler, pleasantries. Keep all technical substance.
|
||||
|
||||
## Grammar
|
||||
|
||||
- Drop articles (a, an, the)
|
||||
- Drop filler (just, really, basically, actually, simply)
|
||||
- Drop pleasantries (sure, certainly, of course, happy to)
|
||||
- Short synonyms (big not extensive, fix not "implement a solution for")
|
||||
- No hedging (skip "it might be worth considering")
|
||||
- Fragments fine. No need full sentence
|
||||
- Technical terms stay exact. "Polymorphism" stays "polymorphism"
|
||||
- Code blocks unchanged. Caveman speak around code, not in code
|
||||
- Error messages quoted exact. Caveman only for explanation
|
||||
|
||||
## Pattern
|
||||
|
||||
```
|
||||
[thing] [action] [reason]. [next step].
|
||||
```
|
||||
|
||||
Not:
|
||||
> Sure! I'd be happy to help you with that. The issue you're experiencing is likely caused by...
|
||||
|
||||
Yes:
|
||||
> Bug in auth middleware. Token expiry check use `<` not `<=`. Fix:
|
||||
|
||||
## Examples
|
||||
|
||||
**User:** Why is my React component re-rendering?
|
||||
|
||||
**Normal (69 tokens):** "The reason your React component is re-rendering is likely because you're creating a new object reference on each render cycle. When you pass an inline object as a prop, React's shallow comparison sees it as a different object every time, which triggers a re-render. I'd recommend using useMemo to memoize the object."
|
||||
|
||||
**Caveman (19 tokens):** "New object ref each render. Inline object prop = new ref = re-render. Wrap in `useMemo`."
|
||||
|
||||
---
|
||||
|
||||
**User:** How do I set up a PostgreSQL connection pool?
|
||||
|
||||
**Caveman:**
|
||||
```
|
||||
Use `pg` pool:
|
||||
```
|
||||
```js
|
||||
const pool = new Pool({
|
||||
max: 20,
|
||||
idleTimeoutMillis: 30000,
|
||||
connectionTimeoutMillis: 2000,
|
||||
})
|
||||
```
|
||||
```
|
||||
max = concurrent connections. Keep under DB limit. idleTimeout kill stale conn.
|
||||
```
|
||||
|
||||
## Boundaries
|
||||
|
||||
- Code: write normal. Caveman English only
|
||||
- Git commits: normal
|
||||
- PR descriptions: normal
|
||||
- User say "stop caveman" or "normal mode": revert immediately
|
||||
@@ -7,6 +7,8 @@ on:
|
||||
- 'feat/*'
|
||||
tags:
|
||||
- 'v*'
|
||||
paths-ignore:
|
||||
- '.gitea/**'
|
||||
workflow_dispatch:
|
||||
|
||||
env:
|
||||
|
||||
43
.gitea/workflows/mirror-github.yml
Normal file
43
.gitea/workflows/mirror-github.yml
Normal file
@@ -0,0 +1,43 @@
|
||||
name: Mirror to GitHub
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- 'feat/*'
|
||||
- 'feature/*'
|
||||
tags:
|
||||
- '*'
|
||||
|
||||
jobs:
|
||||
mirror:
|
||||
runs-on: ubuntu-latest
|
||||
container:
|
||||
image: catthehacker/ubuntu:act-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Push to GitHub
|
||||
env:
|
||||
GH_SSH_KEY: ${{ secrets.GH_SSH_KEY }}
|
||||
run: |
|
||||
mkdir -p ~/.ssh
|
||||
echo "${GH_SSH_KEY}" > ~/.ssh/id_ed25519
|
||||
chmod 600 ~/.ssh/id_ed25519
|
||||
ssh-keyscan github.com >> ~/.ssh/known_hosts 2>/dev/null
|
||||
|
||||
git remote add github git@github.com:manawenuz/wzp.git
|
||||
|
||||
# Push the current branch
|
||||
BRANCH="${GITHUB_REF#refs/heads/}"
|
||||
TAG="${GITHUB_REF#refs/tags/}"
|
||||
|
||||
if [ "${GITHUB_REF}" != "${GITHUB_REF#refs/tags/}" ]; then
|
||||
echo "Pushing tag: ${TAG}"
|
||||
git push github "refs/tags/${TAG}" --force
|
||||
else
|
||||
echo "Pushing branch: ${BRANCH}"
|
||||
git push github "HEAD:refs/heads/${BRANCH}" --force
|
||||
fi
|
||||
25
.gitignore
vendored
25
.gitignore
vendored
@@ -4,3 +4,28 @@
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# Logs
|
||||
logs
|
||||
*.log
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
dev-debug.log
|
||||
# Dependency directories
|
||||
node_modules/
|
||||
# Environment variables
|
||||
.env
|
||||
# Editor directories and files
|
||||
.idea
|
||||
.vscode
|
||||
*.suo
|
||||
*.ntvs*
|
||||
*.njsproj
|
||||
*.sln
|
||||
*.sw?
|
||||
# OS specific
|
||||
|
||||
# Taskmaster (local workflow tool)
|
||||
.taskmaster/
|
||||
.env.example
|
||||
|
||||
14
.gitleaks.toml
Normal file
14
.gitleaks.toml
Normal file
@@ -0,0 +1,14 @@
|
||||
[extend]
|
||||
useDefault = true
|
||||
|
||||
[[allowlists]]
|
||||
description = "Pre-existing historical findings already on fj/main and github/main. The two PASTE_AUTH tokens in scripts/build.sh and scripts/build-linux-notify.sh are real — rotate if those endpoints still authenticate; this allowlist only silences the pre-push hook, it does not remove the exposure."
|
||||
commits = [
|
||||
# wzp-crypto module doc: false positive on "SHA-256(Ed25519 pub)[:16]"
|
||||
"51e893590c1b9fa49e9f6ae5c96c26deb58f353b",
|
||||
# build.sh PASTE_AUTH (paste.tbs.amn.gg)
|
||||
"bd6733b2e5d76b5259020f1c30a5223a9773b6aa",
|
||||
# build-linux-notify Authorization header (paste.dk.manko.yoga)
|
||||
"6d776097c83bc6fbe3f3565e080513d8af93b550",
|
||||
"7751439e2bca9eacf2c30929c8124a4eb6136df2",
|
||||
]
|
||||
4574
Cargo.lock
generated
4574
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
41
Cargo.toml
41
Cargo.toml
@@ -10,6 +10,9 @@ members = [
|
||||
"crates/wzp-client",
|
||||
"crates/wzp-web",
|
||||
"crates/wzp-android",
|
||||
"crates/wzp-native",
|
||||
"crates/wzp-video",
|
||||
"desktop/src-tauri",
|
||||
]
|
||||
|
||||
[workspace.package]
|
||||
@@ -30,17 +33,25 @@ serde = { version = "1", features = ["derive"] }
|
||||
|
||||
# Transport
|
||||
quinn = "0.11"
|
||||
socket2 = "0.5"
|
||||
|
||||
# FEC
|
||||
raptorq = "2"
|
||||
|
||||
# Codec
|
||||
audiopus = "0.3.0-rc.0"
|
||||
# opusic-c: high-level safe bindings over libopus 1.5.2 (encoder side).
|
||||
# opusic-sys: raw FFI for the decoder side — we build our own DecoderHandle
|
||||
# because opusic-c::Decoder.inner is pub(crate) and cannot be reached for the
|
||||
# Phase 3 DRED reconstruction path. See docs/PRD-dred-integration.md.
|
||||
# Pinned exactly (no caret) for reproducible libopus 1.5.2 across the fleet.
|
||||
opusic-c = { version = "=1.5.5", default-features = false, features = ["bundled", "dred"] }
|
||||
opusic-sys = { version = "=0.6.0", default-features = false, features = ["bundled"] }
|
||||
bytemuck = "1"
|
||||
codec2 = "0.3"
|
||||
|
||||
# Crypto
|
||||
x25519-dalek = { version = "2", features = ["static_secrets"] }
|
||||
ed25519-dalek = { version = "2", features = ["rand_core"] }
|
||||
ed25519-dalek = { version = "2", features = ["rand_core", "pkcs8"] }
|
||||
chacha20poly1305 = "0.10"
|
||||
hkdf = "0.12"
|
||||
sha2 = "0.10"
|
||||
@@ -53,3 +64,29 @@ wzp-fec = { path = "crates/wzp-fec" }
|
||||
wzp-crypto = { path = "crates/wzp-crypto" }
|
||||
wzp-transport = { path = "crates/wzp-transport" }
|
||||
wzp-client = { path = "crates/wzp-client" }
|
||||
|
||||
# Fast dev profile: optimized but with debug info and incremental compilation.
|
||||
# Use with: cargo run --profile dev-fast
|
||||
[profile.dev-fast]
|
||||
inherits = "dev"
|
||||
opt-level = 2
|
||||
|
||||
# Optimize heavy compute deps even in debug builds —
|
||||
# real-time audio needs < 20ms per frame, impossible unoptimized.
|
||||
[profile.dev.package.nnnoiseless]
|
||||
opt-level = 3
|
||||
[profile.dev.package.opusic-sys]
|
||||
opt-level = 3
|
||||
[profile.dev.package.raptorq]
|
||||
opt-level = 3
|
||||
[profile.dev.package.wzp-codec]
|
||||
opt-level = 3
|
||||
[profile.dev.package.wzp-fec]
|
||||
opt-level = 3
|
||||
|
||||
# Phase 0 (opus-DRED): removed the [patch.crates-io] audiopus_sys = { path =
|
||||
# "vendor/audiopus_sys" } block. That patch existed to fix a Windows clang-cl
|
||||
# SIMD compile bug in libopus 1.3.1. With the swap to opusic-sys (libopus
|
||||
# 1.5.2), the upstream SIMD gating was fixed and the vendor patch is
|
||||
# obsolete. The vendor/audiopus_sys directory itself should be deleted as
|
||||
# part of the same cleanup — see the commit that follows this Phase 0.
|
||||
|
||||
@@ -29,5 +29,15 @@
|
||||
android:name="com.wzp.service.CallService"
|
||||
android:foregroundServiceType="microphone"
|
||||
android:exported="false" />
|
||||
|
||||
<provider
|
||||
android:name="androidx.core.content.FileProvider"
|
||||
android:authorities="${applicationId}.fileprovider"
|
||||
android:exported="false"
|
||||
android:grantUriPermissions="true">
|
||||
<meta-data
|
||||
android:name="android.support.FILE_PROVIDER_PATHS"
|
||||
android:resource="@xml/file_paths" />
|
||||
</provider>
|
||||
</application>
|
||||
</manifest>
|
||||
|
||||
@@ -8,10 +8,21 @@ import android.media.AudioFormat
|
||||
import android.media.AudioRecord
|
||||
import android.media.AudioTrack
|
||||
import android.media.MediaRecorder
|
||||
import android.media.audiofx.AcousticEchoCanceler
|
||||
import android.media.audiofx.NoiseSuppressor
|
||||
import android.util.Log
|
||||
import androidx.core.content.ContextCompat
|
||||
import com.wzp.engine.WzpEngine
|
||||
import java.io.BufferedOutputStream
|
||||
import java.io.File
|
||||
import java.io.FileOutputStream
|
||||
import java.io.OutputStreamWriter
|
||||
import java.nio.ByteBuffer
|
||||
import java.nio.ByteOrder
|
||||
import java.util.concurrent.CountDownLatch
|
||||
import java.util.concurrent.TimeUnit
|
||||
import kotlin.math.pow
|
||||
import kotlin.math.sqrt
|
||||
|
||||
/**
|
||||
* Audio pipeline that captures mic audio and plays received audio using
|
||||
@@ -43,15 +54,38 @@ class AudioPipeline(private val context: Context) {
|
||||
/** Capture (mic) gain in dB. 0 = unity. */
|
||||
@Volatile
|
||||
var captureGainDb: Float = 0f
|
||||
/** Whether to attach hardware AEC. Must be set before start(). */
|
||||
var aecEnabled: Boolean = true
|
||||
/** Enable debug recording of PCM + RMS histogram to cache dir. */
|
||||
var debugRecording: Boolean = false
|
||||
private var captureThread: Thread? = null
|
||||
private var playoutThread: Thread? = null
|
||||
|
||||
// DirectByteBuffers for zero-copy JNI audio transfer.
|
||||
// Allocated as class fields (NOT locals) because ART's JIT OSR
|
||||
// can null local variables when it replaces the stack frame mid-loop.
|
||||
// These survive OSR because they're on the heap.
|
||||
private val captureDirectBuf: ByteBuffer =
|
||||
ByteBuffer.allocateDirect(FRAME_SAMPLES * 2).order(ByteOrder.LITTLE_ENDIAN)
|
||||
private val playoutDirectBuf: ByteBuffer =
|
||||
ByteBuffer.allocateDirect(FRAME_SAMPLES * 2).order(ByteOrder.LITTLE_ENDIAN)
|
||||
|
||||
/** Latch counted down by each audio thread after exiting its loop.
|
||||
* stop() does NOT wait on this — teardown waits via awaitDrain(). */
|
||||
private var drainLatch: CountDownLatch? = null
|
||||
|
||||
private val debugDir: File by lazy {
|
||||
File(context.cacheDir, "wzp_debug").also { it.mkdirs() }
|
||||
}
|
||||
|
||||
fun start(engine: WzpEngine) {
|
||||
if (running) return
|
||||
running = true
|
||||
drainLatch = CountDownLatch(2) // one for capture, one for playout
|
||||
|
||||
captureThread = Thread({
|
||||
runCapture(engine)
|
||||
drainLatch?.countDown() // signal: capture loop exited, no more JNI calls
|
||||
// Park thread forever — exiting triggers a libcrypto TLS destructor
|
||||
// crash (SIGSEGV in OPENSSL_free) on Android when a JNI-calling thread exits.
|
||||
parkThread()
|
||||
@@ -63,6 +97,7 @@ class AudioPipeline(private val context: Context) {
|
||||
|
||||
playoutThread = Thread({
|
||||
runPlayout(engine)
|
||||
drainLatch?.countDown() // signal: playout loop exited
|
||||
parkThread()
|
||||
}, "wzp-playout").apply {
|
||||
isDaemon = true
|
||||
@@ -75,10 +110,20 @@ class AudioPipeline(private val context: Context) {
|
||||
|
||||
fun stop() {
|
||||
running = false
|
||||
// Don't join — threads are parked as daemons to avoid native TLS crash
|
||||
// Don't join threads — they are parked as daemons to avoid native TLS crash.
|
||||
// Don't null thread refs or drainLatch — teardown() needs awaitDrain().
|
||||
Log.i(TAG, "audio pipeline stopped (running=false)")
|
||||
}
|
||||
|
||||
/** Block until both audio threads have exited their loops (max 200ms).
|
||||
* After this returns, no more JNI calls to the engine will be made. */
|
||||
fun awaitDrain(): Boolean {
|
||||
val ok = drainLatch?.await(200, TimeUnit.MILLISECONDS) ?: true
|
||||
if (!ok) Log.w(TAG, "awaitDrain: audio threads did not drain in 200ms")
|
||||
captureThread = null
|
||||
playoutThread = null
|
||||
Log.i(TAG, "audio pipeline stopped")
|
||||
drainLatch = null
|
||||
return ok
|
||||
}
|
||||
|
||||
private fun applyGain(pcm: ShortArray, count: Int, db: Float) {
|
||||
@@ -89,6 +134,15 @@ class AudioPipeline(private val context: Context) {
|
||||
}
|
||||
}
|
||||
|
||||
private fun computeRms(pcm: ShortArray, count: Int): Int {
|
||||
var sumSq = 0.0
|
||||
for (i in 0 until count) {
|
||||
val s = pcm[i].toDouble()
|
||||
sumSq += s * s
|
||||
}
|
||||
return sqrt(sumSq / count).toInt()
|
||||
}
|
||||
|
||||
private fun parkThread() {
|
||||
try {
|
||||
Thread.sleep(Long.MAX_VALUE)
|
||||
@@ -127,25 +181,89 @@ class AudioPipeline(private val context: Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Attach hardware AEC if available and enabled in settings
|
||||
var aec: AcousticEchoCanceler? = null
|
||||
var ns: NoiseSuppressor? = null
|
||||
if (aecEnabled) {
|
||||
if (AcousticEchoCanceler.isAvailable()) {
|
||||
try {
|
||||
aec = AcousticEchoCanceler.create(recorder.audioSessionId)
|
||||
aec?.enabled = true
|
||||
Log.i(TAG, "AEC enabled (session=${recorder.audioSessionId})")
|
||||
} catch (e: Exception) {
|
||||
Log.w(TAG, "AEC init failed: ${e.message}")
|
||||
}
|
||||
} else {
|
||||
Log.w(TAG, "AEC not available on this device")
|
||||
}
|
||||
|
||||
// Attach hardware noise suppressor if available
|
||||
if (NoiseSuppressor.isAvailable()) {
|
||||
try {
|
||||
ns = NoiseSuppressor.create(recorder.audioSessionId)
|
||||
ns?.enabled = true
|
||||
Log.i(TAG, "NoiseSuppressor enabled")
|
||||
} catch (e: Exception) {
|
||||
Log.w(TAG, "NoiseSuppressor init failed: ${e.message}")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Log.i(TAG, "AEC disabled by user setting")
|
||||
}
|
||||
|
||||
recorder.startRecording()
|
||||
Log.i(TAG, "capture started: ${SAMPLE_RATE}Hz mono, buf=$bufSize")
|
||||
Log.i(TAG, "capture started: ${SAMPLE_RATE}Hz mono, buf=$bufSize, aec=${aec?.enabled}, ns=${ns?.enabled}")
|
||||
|
||||
val pcm = ShortArray(FRAME_SAMPLES)
|
||||
// Debug: PCM file + RMS CSV
|
||||
var pcmOut: BufferedOutputStream? = null
|
||||
var rmsCsv: OutputStreamWriter? = null
|
||||
val byteConv = ByteBuffer.allocate(FRAME_SAMPLES * 2).order(ByteOrder.LITTLE_ENDIAN)
|
||||
var frameIdx = 0L
|
||||
if (debugRecording) {
|
||||
try {
|
||||
pcmOut = BufferedOutputStream(FileOutputStream(File(debugDir, "capture.pcm")), 65536)
|
||||
rmsCsv = OutputStreamWriter(FileOutputStream(File(debugDir, "capture_rms.csv")))
|
||||
rmsCsv.write("frame,time_ms,rms\n")
|
||||
} catch (e: Exception) {
|
||||
Log.w(TAG, "debug recording init failed: ${e.message}")
|
||||
}
|
||||
}
|
||||
try {
|
||||
while (running) {
|
||||
val read = recorder.read(pcm, 0, FRAME_SAMPLES)
|
||||
if (read > 0) {
|
||||
applyGain(pcm, read, captureGainDb)
|
||||
engine.writeAudio(pcm)
|
||||
// Zero-copy write via DirectByteBuffer (class field, survives JIT OSR)
|
||||
captureDirectBuf.clear()
|
||||
captureDirectBuf.asShortBuffer().put(pcm, 0, read)
|
||||
engine.writeAudioDirect(captureDirectBuf, read)
|
||||
|
||||
// Debug: write raw PCM + RMS
|
||||
if (pcmOut != null) {
|
||||
byteConv.clear()
|
||||
for (i in 0 until read) byteConv.putShort(pcm[i])
|
||||
pcmOut.write(byteConv.array(), 0, read * 2)
|
||||
}
|
||||
if (rmsCsv != null) {
|
||||
val rms = computeRms(pcm, read)
|
||||
val timeMs = frameIdx * FRAME_SAMPLES * 1000L / SAMPLE_RATE
|
||||
rmsCsv.write("$frameIdx,$timeMs,$rms\n")
|
||||
}
|
||||
frameIdx++
|
||||
} else if (read < 0) {
|
||||
Log.e(TAG, "AudioRecord.read error: $read")
|
||||
break
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
pcmOut?.close()
|
||||
rmsCsv?.close()
|
||||
recorder.stop()
|
||||
aec?.release()
|
||||
ns?.release()
|
||||
recorder.release()
|
||||
Log.i(TAG, "capture stopped")
|
||||
Log.i(TAG, "capture stopped (frames=$frameIdx)")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -181,24 +299,61 @@ class AudioPipeline(private val context: Context) {
|
||||
Log.i(TAG, "playout started: ${SAMPLE_RATE}Hz mono, buf=$bufSize")
|
||||
|
||||
val pcm = ShortArray(FRAME_SAMPLES)
|
||||
val silence = ShortArray(FRAME_SAMPLES) // pre-allocated silence
|
||||
val silence = ShortArray(FRAME_SAMPLES)
|
||||
// Debug: PCM file + RMS CSV for playout
|
||||
var pcmOut: BufferedOutputStream? = null
|
||||
var rmsCsv: OutputStreamWriter? = null
|
||||
val byteConv = ByteBuffer.allocate(FRAME_SAMPLES * 2).order(ByteOrder.LITTLE_ENDIAN)
|
||||
var frameIdx = 0L
|
||||
if (debugRecording) {
|
||||
try {
|
||||
pcmOut = BufferedOutputStream(FileOutputStream(File(debugDir, "playout.pcm")), 65536)
|
||||
rmsCsv = OutputStreamWriter(FileOutputStream(File(debugDir, "playout_rms.csv")))
|
||||
rmsCsv.write("frame,time_ms,rms\n")
|
||||
} catch (e: Exception) {
|
||||
Log.w(TAG, "debug playout recording init failed: ${e.message}")
|
||||
}
|
||||
}
|
||||
try {
|
||||
while (running) {
|
||||
val read = engine.readAudio(pcm)
|
||||
// Zero-copy read via DirectByteBuffer (class field, survives JIT OSR)
|
||||
playoutDirectBuf.clear()
|
||||
val read = engine.readAudioDirect(playoutDirectBuf, FRAME_SAMPLES)
|
||||
if (read >= FRAME_SAMPLES) {
|
||||
playoutDirectBuf.rewind()
|
||||
playoutDirectBuf.asShortBuffer().get(pcm, 0, read)
|
||||
applyGain(pcm, read, playoutGainDb)
|
||||
track.write(pcm, 0, read)
|
||||
|
||||
// Debug: write raw PCM + RMS
|
||||
if (pcmOut != null) {
|
||||
byteConv.clear()
|
||||
for (i in 0 until read) byteConv.putShort(pcm[i])
|
||||
pcmOut.write(byteConv.array(), 0, read * 2)
|
||||
}
|
||||
if (rmsCsv != null) {
|
||||
val rms = computeRms(pcm, read)
|
||||
val timeMs = frameIdx * FRAME_SAMPLES * 1000L / SAMPLE_RATE
|
||||
rmsCsv.write("$frameIdx,$timeMs,$rms\n")
|
||||
}
|
||||
frameIdx++
|
||||
} else {
|
||||
// Not enough decoded audio — write silence to keep stream alive
|
||||
track.write(silence, 0, FRAME_SAMPLES)
|
||||
// Sleep briefly to avoid busy-spinning
|
||||
// Log silence frames to RMS as 0
|
||||
if (rmsCsv != null) {
|
||||
val timeMs = frameIdx * FRAME_SAMPLES * 1000L / SAMPLE_RATE
|
||||
rmsCsv.write("$frameIdx,$timeMs,0\n")
|
||||
}
|
||||
frameIdx++
|
||||
Thread.sleep(5)
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
pcmOut?.close()
|
||||
rmsCsv?.close()
|
||||
track.stop()
|
||||
track.release()
|
||||
Log.i(TAG, "playout stopped")
|
||||
Log.i(TAG, "playout stopped (frames=$frameIdx)")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
203
android/app/src/main/java/com/wzp/data/SettingsRepository.kt
Normal file
203
android/app/src/main/java/com/wzp/data/SettingsRepository.kt
Normal file
@@ -0,0 +1,203 @@
|
||||
package com.wzp.data
|
||||
|
||||
import android.content.Context
|
||||
import android.content.SharedPreferences
|
||||
import com.wzp.ui.call.ServerEntry
|
||||
import org.json.JSONArray
|
||||
import org.json.JSONObject
|
||||
import java.security.SecureRandom
|
||||
|
||||
/**
|
||||
* Persists user settings via SharedPreferences.
|
||||
*
|
||||
* Stores: servers, default server index, room name, alias, gain values,
|
||||
* IPv6 preference, and the identity seed (hex-encoded 32 bytes).
|
||||
*/
|
||||
class SettingsRepository(context: Context) {
|
||||
|
||||
private val prefs: SharedPreferences =
|
||||
context.applicationContext.getSharedPreferences("wzp_settings", Context.MODE_PRIVATE)
|
||||
|
||||
companion object {
|
||||
private const val KEY_SERVERS = "servers_json"
|
||||
private const val KEY_SELECTED_SERVER = "selected_server"
|
||||
private const val KEY_ROOM = "room_name"
|
||||
private const val KEY_ALIAS = "alias"
|
||||
private const val KEY_PLAYOUT_GAIN = "playout_gain_db"
|
||||
private const val KEY_CAPTURE_GAIN = "capture_gain_db"
|
||||
private const val KEY_PREFER_IPV6 = "prefer_ipv6"
|
||||
private const val KEY_IDENTITY_SEED = "identity_seed_hex"
|
||||
private const val KEY_AEC_ENABLED = "aec_enabled"
|
||||
private const val KEY_DEBUG_RECORDING = "debug_recording"
|
||||
private const val KEY_RECENT_ROOMS = "recent_rooms"
|
||||
private const val TOFU_PREFIX = "tofu_"
|
||||
}
|
||||
|
||||
// --- Servers ---
|
||||
|
||||
fun saveServers(servers: List<ServerEntry>) {
|
||||
val arr = JSONArray()
|
||||
servers.forEach { entry ->
|
||||
arr.put(JSONObject().apply {
|
||||
put("address", entry.address)
|
||||
put("label", entry.label)
|
||||
})
|
||||
}
|
||||
prefs.edit().putString(KEY_SERVERS, arr.toString()).apply()
|
||||
}
|
||||
|
||||
fun loadServers(): List<ServerEntry>? {
|
||||
val json = prefs.getString(KEY_SERVERS, null) ?: return null
|
||||
return try {
|
||||
val arr = JSONArray(json)
|
||||
(0 until arr.length()).map { i ->
|
||||
val obj = arr.getJSONObject(i)
|
||||
ServerEntry(obj.getString("address"), obj.getString("label"))
|
||||
}
|
||||
} catch (_: Exception) { null }
|
||||
}
|
||||
|
||||
fun saveSelectedServer(index: Int) {
|
||||
prefs.edit().putInt(KEY_SELECTED_SERVER, index).apply()
|
||||
}
|
||||
|
||||
fun loadSelectedServer(): Int = prefs.getInt(KEY_SELECTED_SERVER, 0)
|
||||
|
||||
// --- Room ---
|
||||
|
||||
fun saveRoom(name: String) { prefs.edit().putString(KEY_ROOM, name).apply() }
|
||||
fun loadRoom(): String = prefs.getString(KEY_ROOM, "android") ?: "android"
|
||||
|
||||
// --- Alias ---
|
||||
|
||||
fun saveAlias(alias: String) { prefs.edit().putString(KEY_ALIAS, alias).apply() }
|
||||
|
||||
/**
|
||||
* Load alias, generating a random name on first launch.
|
||||
*/
|
||||
fun getOrCreateAlias(): String {
|
||||
val existing = prefs.getString(KEY_ALIAS, null)
|
||||
if (!existing.isNullOrEmpty()) return existing
|
||||
val name = generateRandomName()
|
||||
prefs.edit().putString(KEY_ALIAS, name).apply()
|
||||
return name
|
||||
}
|
||||
|
||||
private fun generateRandomName(): String {
|
||||
val adjectives = listOf(
|
||||
"Swift", "Silent", "Brave", "Calm", "Dark", "Fierce", "Ghost",
|
||||
"Iron", "Lucky", "Noble", "Quick", "Sharp", "Storm", "Wild",
|
||||
"Cold", "Bright", "Lone", "Red", "Grey", "Frosty", "Dusty",
|
||||
"Rusty", "Neon", "Void", "Solar", "Lunar", "Cyber", "Pixel",
|
||||
"Sonic", "Hyper", "Turbo", "Nano", "Mega", "Ultra", "Zinc"
|
||||
)
|
||||
val nouns = listOf(
|
||||
"Wolf", "Hawk", "Fox", "Bear", "Lynx", "Crow", "Viper",
|
||||
"Cobra", "Tiger", "Eagle", "Shark", "Raven", "Falcon", "Otter",
|
||||
"Mantis", "Panda", "Jackal", "Badger", "Heron", "Bison",
|
||||
"Condor", "Coyote", "Gecko", "Hornet", "Marten", "Osprey",
|
||||
"Parrot", "Puma", "Raptor", "Stork", "Toucan", "Walrus"
|
||||
)
|
||||
val adj = adjectives.random()
|
||||
val noun = nouns.random()
|
||||
return "$adj $noun"
|
||||
}
|
||||
|
||||
// --- Gain ---
|
||||
|
||||
fun savePlayoutGain(db: Float) { prefs.edit().putFloat(KEY_PLAYOUT_GAIN, db).apply() }
|
||||
fun loadPlayoutGain(): Float = prefs.getFloat(KEY_PLAYOUT_GAIN, 0f)
|
||||
|
||||
fun saveCaptureGain(db: Float) { prefs.edit().putFloat(KEY_CAPTURE_GAIN, db).apply() }
|
||||
fun loadCaptureGain(): Float = prefs.getFloat(KEY_CAPTURE_GAIN, 0f)
|
||||
|
||||
// --- IPv6 ---
|
||||
|
||||
fun savePreferIPv6(prefer: Boolean) { prefs.edit().putBoolean(KEY_PREFER_IPV6, prefer).apply() }
|
||||
fun loadPreferIPv6(): Boolean = prefs.getBoolean(KEY_PREFER_IPV6, false)
|
||||
|
||||
// --- AEC ---
|
||||
|
||||
fun saveAecEnabled(enabled: Boolean) { prefs.edit().putBoolean(KEY_AEC_ENABLED, enabled).apply() }
|
||||
fun loadAecEnabled(): Boolean = prefs.getBoolean(KEY_AEC_ENABLED, true)
|
||||
|
||||
// --- Debug recording ---
|
||||
|
||||
fun saveDebugRecording(enabled: Boolean) { prefs.edit().putBoolean(KEY_DEBUG_RECORDING, enabled).apply() }
|
||||
fun loadDebugRecording(): Boolean = prefs.getBoolean(KEY_DEBUG_RECORDING, false)
|
||||
|
||||
// --- Codec choice ---
|
||||
// 0 = Opus (GOOD), 1 = Opus Low (DEGRADED), 2 = Codec2 (CATASTROPHIC)
|
||||
fun saveCodecChoice(choice: Int) { prefs.edit().putInt("codec_choice", choice).apply() }
|
||||
fun loadCodecChoice(): Int = prefs.getInt("codec_choice", 0)
|
||||
|
||||
// --- Identity seed ---
|
||||
|
||||
/**
|
||||
* Get or generate the identity seed. On first call, generates a random
|
||||
* 32-byte seed and persists it. Subsequent calls return the same seed.
|
||||
*/
|
||||
fun getOrCreateSeedHex(): String {
|
||||
val existing = prefs.getString(KEY_IDENTITY_SEED, null)
|
||||
if (!existing.isNullOrEmpty()) return existing
|
||||
val seed = ByteArray(32).also { SecureRandom().nextBytes(it) }
|
||||
val hex = seed.joinToString("") { "%02x".format(it) }
|
||||
prefs.edit().putString(KEY_IDENTITY_SEED, hex).apply()
|
||||
return hex
|
||||
}
|
||||
|
||||
fun loadSeedHex(): String = prefs.getString(KEY_IDENTITY_SEED, "") ?: ""
|
||||
|
||||
fun saveSeedHex(hex: String) {
|
||||
prefs.edit().putString(KEY_IDENTITY_SEED, hex).apply()
|
||||
}
|
||||
|
||||
// --- Recent rooms ---
|
||||
|
||||
data class RecentRoom(val relay: String, val room: String)
|
||||
|
||||
fun addRecentRoom(relay: String, room: String) {
|
||||
val rooms = loadRecentRooms().toMutableList()
|
||||
rooms.removeAll { it.relay == relay && it.room == room }
|
||||
rooms.add(0, RecentRoom(relay, room))
|
||||
if (rooms.size > 5) rooms.subList(5, rooms.size).clear()
|
||||
val arr = JSONArray()
|
||||
rooms.forEach { arr.put(JSONObject().apply { put("relay", it.relay); put("room", it.room) }) }
|
||||
prefs.edit().putString(KEY_RECENT_ROOMS, arr.toString()).apply()
|
||||
}
|
||||
|
||||
fun loadRecentRooms(): List<RecentRoom> {
|
||||
val json = prefs.getString(KEY_RECENT_ROOMS, null) ?: return emptyList()
|
||||
return try {
|
||||
val arr = JSONArray(json)
|
||||
(0 until arr.length()).map { i ->
|
||||
val o = arr.getJSONObject(i)
|
||||
RecentRoom(o.getString("relay"), o.getString("room"))
|
||||
}
|
||||
} catch (_: Exception) { emptyList() }
|
||||
}
|
||||
|
||||
fun clearRecentRooms() {
|
||||
prefs.edit().remove(KEY_RECENT_ROOMS).apply()
|
||||
}
|
||||
|
||||
// --- Server fingerprint TOFU ---
|
||||
|
||||
fun saveServerFingerprint(address: String, fingerprint: String) {
|
||||
prefs.edit().putString("$TOFU_PREFIX$address", fingerprint).apply()
|
||||
}
|
||||
|
||||
fun loadServerFingerprint(address: String): String? {
|
||||
return prefs.getString("$TOFU_PREFIX$address", null)
|
||||
}
|
||||
|
||||
// --- Ping RTT cache ---
|
||||
|
||||
fun savePingRtt(address: String, rttMs: Int) {
|
||||
prefs.edit().putInt("ping_rtt_$address", rttMs).apply()
|
||||
}
|
||||
|
||||
fun loadPingRtt(address: String): Int {
|
||||
return prefs.getInt("ping_rtt_$address", -1)
|
||||
}
|
||||
}
|
||||
242
android/app/src/main/java/com/wzp/debug/DebugReporter.kt
Normal file
242
android/app/src/main/java/com/wzp/debug/DebugReporter.kt
Normal file
@@ -0,0 +1,242 @@
|
||||
package com.wzp.debug
|
||||
|
||||
import android.content.Context
|
||||
import android.util.Log
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.withContext
|
||||
import java.io.BufferedOutputStream
|
||||
import java.io.ByteArrayOutputStream
|
||||
import java.io.File
|
||||
import java.io.FileInputStream
|
||||
import java.io.FileOutputStream
|
||||
import java.nio.ByteBuffer
|
||||
import java.nio.ByteOrder
|
||||
import java.text.SimpleDateFormat
|
||||
import java.util.Date
|
||||
import java.util.Locale
|
||||
import java.util.zip.ZipEntry
|
||||
import java.util.zip.ZipOutputStream
|
||||
|
||||
/**
|
||||
* Collects call debug data (audio recordings, logs, histograms, stats)
|
||||
* into a zip file for email sharing.
|
||||
*/
|
||||
class DebugReporter(private val context: Context) {
|
||||
|
||||
companion object {
|
||||
private const val TAG = "DebugReporter"
|
||||
private const val SAMPLE_RATE = 48000
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a zip with all debug data.
|
||||
* Returns the zip File on success, or null on failure.
|
||||
*/
|
||||
suspend fun collectZip(
|
||||
callDurationSecs: Double,
|
||||
finalStatsJson: String,
|
||||
aecEnabled: Boolean,
|
||||
alias: String,
|
||||
server: String,
|
||||
room: String
|
||||
): File? = withContext(Dispatchers.IO) {
|
||||
try {
|
||||
val debugDir = File(context.cacheDir, "wzp_debug")
|
||||
val timestamp = SimpleDateFormat("yyyyMMdd_HHmmss", Locale.US).format(Date())
|
||||
val zipFile = File(context.cacheDir, "wzp_debug_${timestamp}.zip")
|
||||
|
||||
ZipOutputStream(BufferedOutputStream(FileOutputStream(zipFile))).use { zos ->
|
||||
// Phase 4: extract DRED / classical PLC counters from the
|
||||
// stats JSON so they're visible in the meta preamble at a
|
||||
// glance, not buried in the trailing JSON dump.
|
||||
val dredReconstructions = extractLongField(finalStatsJson, "dred_reconstructions")
|
||||
val classicalPlc = extractLongField(finalStatsJson, "classical_plc_invocations")
|
||||
val framesDecoded = extractLongField(finalStatsJson, "frames_decoded")
|
||||
val fecRecovered = extractLongField(finalStatsJson, "fec_recovered")
|
||||
|
||||
// 1. Call metadata
|
||||
val meta = buildString {
|
||||
appendLine("=== WZ Phone Debug Report ===")
|
||||
appendLine("Timestamp: $timestamp")
|
||||
appendLine("Alias: $alias")
|
||||
appendLine("Server: $server")
|
||||
appendLine("Room: $room")
|
||||
appendLine("Duration: ${"%.1f".format(callDurationSecs)}s")
|
||||
appendLine("AEC: ${if (aecEnabled) "ON" else "OFF"}")
|
||||
appendLine("Device: ${android.os.Build.MANUFACTURER} ${android.os.Build.MODEL}")
|
||||
appendLine("Android: ${android.os.Build.VERSION.RELEASE} (API ${android.os.Build.VERSION.SDK_INT})")
|
||||
appendLine()
|
||||
appendLine("=== Loss Recovery ===")
|
||||
appendLine("Frames decoded: $framesDecoded")
|
||||
appendLine("DRED reconstructions: $dredReconstructions (Opus neural recovery)")
|
||||
appendLine("Classical PLC: $classicalPlc (fallback)")
|
||||
appendLine("RaptorQ FEC recovered: $fecRecovered (Codec2 only)")
|
||||
if (framesDecoded > 0) {
|
||||
val dredPct = 100.0 * dredReconstructions / framesDecoded
|
||||
val plcPct = 100.0 * classicalPlc / framesDecoded
|
||||
appendLine("DRED rate: ${"%.2f".format(dredPct)}%")
|
||||
appendLine("Classical PLC rate: ${"%.2f".format(plcPct)}%")
|
||||
}
|
||||
appendLine()
|
||||
appendLine("=== Final Stats ===")
|
||||
appendLine(finalStatsJson)
|
||||
}
|
||||
addTextEntry(zos, "meta.txt", meta)
|
||||
|
||||
// 2. Logcat — WZP-related tags
|
||||
val logcat = collectLogcat()
|
||||
addTextEntry(zos, "logcat.txt", logcat)
|
||||
|
||||
// 3. Capture audio (mic) → WAV
|
||||
val captureRaw = File(debugDir, "capture.pcm")
|
||||
if (captureRaw.exists() && captureRaw.length() > 0) {
|
||||
addWavEntry(zos, "capture.wav", captureRaw)
|
||||
Log.i(TAG, "capture.pcm: ${captureRaw.length()} bytes -> WAV")
|
||||
}
|
||||
|
||||
// 4. Playout audio (speaker) → WAV
|
||||
val playoutRaw = File(debugDir, "playout.pcm")
|
||||
if (playoutRaw.exists() && playoutRaw.length() > 0) {
|
||||
addWavEntry(zos, "playout.wav", playoutRaw)
|
||||
Log.i(TAG, "playout.pcm: ${playoutRaw.length()} bytes -> WAV")
|
||||
}
|
||||
|
||||
// 5. RMS histogram CSV
|
||||
val captureHist = File(debugDir, "capture_rms.csv")
|
||||
if (captureHist.exists()) addFileEntry(zos, "capture_rms.csv", captureHist)
|
||||
val playoutHist = File(debugDir, "playout_rms.csv")
|
||||
if (playoutHist.exists()) addFileEntry(zos, "playout_rms.csv", playoutHist)
|
||||
}
|
||||
|
||||
Log.i(TAG, "zip created: ${zipFile.length()} bytes (${zipFile.length() / 1024}KB)")
|
||||
|
||||
// Clean up raw debug files (keep zip)
|
||||
debugDir.listFiles()?.forEach { it.delete() }
|
||||
|
||||
zipFile
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "debug report failed", e)
|
||||
null
|
||||
}
|
||||
}
|
||||
|
||||
/** Clean up any leftover debug files from a previous session. */
|
||||
fun prepareForCall() {
|
||||
val debugDir = File(context.cacheDir, "wzp_debug")
|
||||
if (debugDir.exists()) {
|
||||
debugDir.listFiles()?.forEach { it.delete() }
|
||||
}
|
||||
debugDir.mkdirs()
|
||||
// Also clean up old zip files
|
||||
context.cacheDir.listFiles()?.filter { it.name.startsWith("wzp_debug_") }?.forEach { it.delete() }
|
||||
}
|
||||
|
||||
private fun collectLogcat(): String {
|
||||
return try {
|
||||
val process = Runtime.getRuntime().exec(
|
||||
arrayOf(
|
||||
"logcat", "-d",
|
||||
"-t", "5000",
|
||||
"--format", "threadtime"
|
||||
)
|
||||
)
|
||||
val output = process.inputStream.bufferedReader().readText()
|
||||
process.waitFor()
|
||||
output.lines()
|
||||
.filter { line ->
|
||||
line.contains("wzp", ignoreCase = true) ||
|
||||
line.contains("WzpEngine") ||
|
||||
line.contains("AudioPipeline") ||
|
||||
line.contains("WzpCall") ||
|
||||
line.contains("CallService") ||
|
||||
line.contains("AudioTrack") ||
|
||||
line.contains("AudioRecord") ||
|
||||
line.contains("AcousticEchoCanceler") ||
|
||||
line.contains("NoiseSuppressor") ||
|
||||
line.contains("FATAL") ||
|
||||
line.contains("ANR") ||
|
||||
line.contains("AudioFlinger") ||
|
||||
line.contains("DebugReporter") ||
|
||||
line.contains("QUIC") ||
|
||||
line.contains("quinn") ||
|
||||
line.contains("send task") ||
|
||||
line.contains("recv task") ||
|
||||
line.contains("send stats") ||
|
||||
line.contains("recv stats") ||
|
||||
line.contains("send_media") ||
|
||||
line.contains("FEC block") ||
|
||||
line.contains("recv gap") ||
|
||||
line.contains("frames_dropped") ||
|
||||
line.contains("opus")
|
||||
}
|
||||
.joinToString("\n")
|
||||
} catch (e: Exception) {
|
||||
"Failed to collect logcat: ${e.message}"
|
||||
}
|
||||
}
|
||||
|
||||
private fun addWavEntry(zos: ZipOutputStream, name: String, pcmFile: File) {
|
||||
val dataSize = pcmFile.length().toInt()
|
||||
val byteRate = SAMPLE_RATE * 1 * 16 / 8
|
||||
val blockAlign = 1 * 16 / 8
|
||||
|
||||
zos.putNextEntry(ZipEntry(name))
|
||||
|
||||
// Write WAV header (44 bytes)
|
||||
val header = ByteBuffer.allocate(44).order(ByteOrder.LITTLE_ENDIAN)
|
||||
header.put("RIFF".toByteArray())
|
||||
header.putInt(36 + dataSize)
|
||||
header.put("WAVE".toByteArray())
|
||||
header.put("fmt ".toByteArray())
|
||||
header.putInt(16)
|
||||
header.putShort(1) // PCM
|
||||
header.putShort(1) // mono
|
||||
header.putInt(SAMPLE_RATE)
|
||||
header.putInt(byteRate)
|
||||
header.putShort(blockAlign.toShort())
|
||||
header.putShort(16) // bits per sample
|
||||
header.put("data".toByteArray())
|
||||
header.putInt(dataSize)
|
||||
zos.write(header.array())
|
||||
|
||||
// Stream PCM data directly (avoids loading entire file into memory)
|
||||
FileInputStream(pcmFile).use { it.copyTo(zos) }
|
||||
zos.closeEntry()
|
||||
}
|
||||
|
||||
private fun addTextEntry(zos: ZipOutputStream, name: String, content: String) {
|
||||
zos.putNextEntry(ZipEntry(name))
|
||||
zos.write(content.toByteArray())
|
||||
zos.closeEntry()
|
||||
}
|
||||
|
||||
private fun addFileEntry(zos: ZipOutputStream, name: String, file: File) {
|
||||
zos.putNextEntry(ZipEntry(name))
|
||||
FileInputStream(file).use { it.copyTo(zos) }
|
||||
zos.closeEntry()
|
||||
}
|
||||
|
||||
/**
|
||||
* Tiny JSON field extractor — pulls an integer value for a top-level
|
||||
* field like `"dred_reconstructions":42`. We don't want to pull in a
|
||||
* full JSON parser just for the debug preamble, and the CallStats
|
||||
* output is a flat record with well-known field names.
|
||||
*
|
||||
* Returns 0 if the field is missing or unparseable.
|
||||
*/
|
||||
private fun extractLongField(json: String, field: String): Long {
|
||||
val key = "\"$field\":"
|
||||
val idx = json.indexOf(key)
|
||||
if (idx < 0) return 0
|
||||
var i = idx + key.length
|
||||
// Skip whitespace
|
||||
while (i < json.length && json[i].isWhitespace()) i++
|
||||
val start = i
|
||||
while (i < json.length && (json[i].isDigit() || json[i] == '-')) i++
|
||||
return try {
|
||||
json.substring(start, i).toLong()
|
||||
} catch (_: NumberFormatException) {
|
||||
0
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -33,10 +33,24 @@ data class CallStats(
|
||||
val fecRecovered: Long = 0,
|
||||
/** Current mic audio level (RMS, 0-32767). */
|
||||
val audioLevel: Int = 0,
|
||||
/** Our current outgoing codec (e.g. "Opus24k"). */
|
||||
val currentCodec: String = "",
|
||||
/** Last seen incoming codec from peers. */
|
||||
val peerCodec: String = "",
|
||||
/** Whether auto quality mode is active. */
|
||||
val autoMode: Boolean = false,
|
||||
/** Number of participants in the room. */
|
||||
val roomParticipantCount: Int = 0,
|
||||
/** Participants in the room (fingerprint + optional alias). */
|
||||
val roomParticipants: List<RoomMember> = emptyList(),
|
||||
/** SAS verification code (4-digit, null if not in a call). */
|
||||
val sasCode: Int? = null,
|
||||
/** Incoming call ID (or "relay|room" for CallSetup). */
|
||||
val incomingCallId: String? = null,
|
||||
/** Incoming caller's fingerprint. */
|
||||
val incomingCallerFp: String? = null,
|
||||
/** Incoming caller's alias. */
|
||||
val incomingCallerAlias: String? = null,
|
||||
) {
|
||||
/** Human-readable quality label. */
|
||||
val qualityLabel: String
|
||||
@@ -54,7 +68,8 @@ data class CallStats(
|
||||
val o = arr.getJSONObject(i)
|
||||
RoomMember(
|
||||
fingerprint = o.optString("fingerprint", ""),
|
||||
alias = o.optString("alias", null)
|
||||
alias = if (o.isNull("alias")) null else o.optString("alias", null),
|
||||
relayLabel = if (o.isNull("relay_label")) null else o.optString("relay_label", null)
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -76,8 +91,15 @@ data class CallStats(
|
||||
underruns = obj.optLong("underruns", 0),
|
||||
fecRecovered = obj.optLong("fec_recovered", 0),
|
||||
audioLevel = obj.optInt("audio_level", 0),
|
||||
currentCodec = obj.optString("current_codec", ""),
|
||||
peerCodec = obj.optString("peer_codec", ""),
|
||||
autoMode = obj.optBoolean("auto_mode", false),
|
||||
roomParticipantCount = obj.optInt("room_participant_count", 0),
|
||||
roomParticipants = parseParticipants(obj.optJSONArray("room_participants"))
|
||||
roomParticipants = parseParticipants(obj.optJSONArray("room_participants")),
|
||||
sasCode = if (obj.has("sas_code")) obj.optInt("sas_code") else null,
|
||||
incomingCallId = if (obj.isNull("incoming_call_id")) null else obj.optString("incoming_call_id", null),
|
||||
incomingCallerFp = if (obj.isNull("incoming_caller_fp")) null else obj.optString("incoming_caller_fp", null),
|
||||
incomingCallerAlias = if (obj.isNull("incoming_caller_alias")) null else obj.optString("incoming_caller_alias", null),
|
||||
)
|
||||
} catch (e: Exception) {
|
||||
CallStats()
|
||||
@@ -88,9 +110,11 @@ data class CallStats(
|
||||
|
||||
data class RoomMember(
|
||||
val fingerprint: String,
|
||||
val alias: String? = null
|
||||
val alias: String? = null,
|
||||
val relayLabel: String? = null
|
||||
) {
|
||||
/** Short display name: alias if set, otherwise first 8 chars of fingerprint. */
|
||||
val displayName: String
|
||||
get() = alias ?: fingerprint.take(8)
|
||||
get() = alias?.takeIf { it.isNotBlank() }
|
||||
?: fingerprint.take(8).ifEmpty { "unknown" }
|
||||
}
|
||||
|
||||
@@ -35,11 +35,15 @@ class WzpEngine(private val callback: WzpCallback) {
|
||||
* @param room room identifier (used as QUIC SNI)
|
||||
* @param seedHex 64-char hex-encoded 32-byte identity seed (empty = random)
|
||||
* @param token authentication token (empty = no auth)
|
||||
* @param alias display name sent to relay for room participant list
|
||||
* @return 0 on success, negative error code on failure
|
||||
*/
|
||||
fun startCall(relayAddr: String, room: String, seedHex: String = "", token: String = ""): Int {
|
||||
/**
|
||||
* @param profile 0 = Opus GOOD, 1 = Opus DEGRADED, 2 = Codec2 CATASTROPHIC
|
||||
*/
|
||||
fun startCall(relayAddr: String, room: String, seedHex: String = "", token: String = "", alias: String = "", profile: Int = 0): Int {
|
||||
check(nativeHandle != 0L) { "Engine not initialized" }
|
||||
val result = nativeStartCall(nativeHandle, relayAddr, room, seedHex, token)
|
||||
val result = nativeStartCall(nativeHandle, relayAddr, room, seedHex, token, alias, profile)
|
||||
if (result == 0) {
|
||||
callback.onCallStateChanged(CallStateConstants.CONNECTING)
|
||||
} else {
|
||||
@@ -49,6 +53,7 @@ class WzpEngine(private val callback: WzpCallback) {
|
||||
}
|
||||
|
||||
/** Stop the active call. Safe to call when no call is active. */
|
||||
@Synchronized
|
||||
fun stopCall() {
|
||||
if (nativeHandle != 0L) {
|
||||
nativeStopCall(nativeHandle)
|
||||
@@ -72,6 +77,7 @@ class WzpEngine(private val callback: WzpCallback) {
|
||||
*
|
||||
* @return JSON-serialised [CallStats], or `"{}"` if the engine is not initialised.
|
||||
*/
|
||||
@Synchronized
|
||||
fun getStats(): String {
|
||||
if (nativeHandle == 0L) return "{}"
|
||||
return try {
|
||||
@@ -90,7 +96,19 @@ class WzpEngine(private val callback: WzpCallback) {
|
||||
if (nativeHandle != 0L) nativeForceProfile(nativeHandle, profile)
|
||||
}
|
||||
|
||||
/**
|
||||
* Signal a network transport change (e.g. WiFi → LTE handoff).
|
||||
*
|
||||
* @param networkType matches Rust `NetworkContext` ordinals:
|
||||
* 0=WiFi, 1=LTE, 2=5G, 3=3G, 4=Unknown, 5=None
|
||||
* @param bandwidthKbps reported downstream bandwidth in kbps
|
||||
*/
|
||||
fun onNetworkChanged(networkType: Int, bandwidthKbps: Int) {
|
||||
if (nativeHandle != 0L) nativeOnNetworkChanged(nativeHandle, networkType, bandwidthKbps)
|
||||
}
|
||||
|
||||
/** Destroy the native engine and free all resources. The instance must not be reused. */
|
||||
@Synchronized
|
||||
fun destroy() {
|
||||
if (nativeHandle != 0L) {
|
||||
nativeDestroy(nativeHandle)
|
||||
@@ -116,11 +134,31 @@ class WzpEngine(private val callback: WzpCallback) {
|
||||
return nativeReadAudio(nativeHandle, pcm)
|
||||
}
|
||||
|
||||
/**
|
||||
* Write captured PCM from a DirectByteBuffer — zero JNI array copy.
|
||||
* The buffer must be a direct ByteBuffer with native byte order containing i16 samples.
|
||||
* Called from the AudioRecord capture thread.
|
||||
*/
|
||||
fun writeAudioDirect(buffer: java.nio.ByteBuffer, sampleCount: Int): Int {
|
||||
if (nativeHandle == 0L) return 0
|
||||
return nativeWriteAudioDirect(nativeHandle, buffer, sampleCount)
|
||||
}
|
||||
|
||||
/**
|
||||
* Read decoded PCM into a DirectByteBuffer — zero JNI array copy.
|
||||
* The buffer must be a direct ByteBuffer with native byte order.
|
||||
* Called from the AudioTrack playout thread.
|
||||
*/
|
||||
fun readAudioDirect(buffer: java.nio.ByteBuffer, maxSamples: Int): Int {
|
||||
if (nativeHandle == 0L) return 0
|
||||
return nativeReadAudioDirect(nativeHandle, buffer, maxSamples)
|
||||
}
|
||||
|
||||
// -- JNI native methods --------------------------------------------------
|
||||
|
||||
private external fun nativeInit(): Long
|
||||
private external fun nativeStartCall(
|
||||
handle: Long, relay: String, room: String, seed: String, token: String
|
||||
handle: Long, relay: String, room: String, seed: String, token: String, alias: String, profile: Int
|
||||
): Int
|
||||
private external fun nativeStopCall(handle: Long)
|
||||
private external fun nativeSetMute(handle: Long, muted: Boolean)
|
||||
@@ -129,7 +167,58 @@ class WzpEngine(private val callback: WzpCallback) {
|
||||
private external fun nativeForceProfile(handle: Long, profile: Int)
|
||||
private external fun nativeWriteAudio(handle: Long, pcm: ShortArray): Int
|
||||
private external fun nativeReadAudio(handle: Long, pcm: ShortArray): Int
|
||||
private external fun nativeWriteAudioDirect(handle: Long, buffer: java.nio.ByteBuffer, sampleCount: Int): Int
|
||||
private external fun nativeReadAudioDirect(handle: Long, buffer: java.nio.ByteBuffer, maxSamples: Int): Int
|
||||
private external fun nativeDestroy(handle: Long)
|
||||
private external fun nativePingRelay(handle: Long, relay: String): String?
|
||||
private external fun nativeStartSignaling(handle: Long, relay: String, seed: String, token: String, alias: String): Int
|
||||
private external fun nativePlaceCall(handle: Long, targetFp: String): Int
|
||||
private external fun nativeAnswerCall(handle: Long, callId: String, mode: Int): Int
|
||||
private external fun nativeOnNetworkChanged(handle: Long, networkType: Int, bandwidthKbps: Int)
|
||||
|
||||
/**
|
||||
* Ping a relay server. Requires engine to be initialized.
|
||||
* Returns JSON `{"rtt_ms":N,"server_fingerprint":"hex"}` or null.
|
||||
*/
|
||||
fun pingRelay(address: String): String? {
|
||||
if (nativeHandle == 0L) return null
|
||||
return nativePingRelay(nativeHandle, address)
|
||||
}
|
||||
|
||||
/**
|
||||
* Start persistent signaling connection for direct 1:1 calls.
|
||||
* The engine registers on the relay and listens for incoming calls.
|
||||
* Call state updates are available via [getStats].
|
||||
*
|
||||
* @return 0 on success, -1 on error
|
||||
*/
|
||||
fun startSignaling(relay: String, seed: String = "", token: String = "", alias: String = ""): Int {
|
||||
check(nativeHandle != 0L) { "Engine not initialized" }
|
||||
return nativeStartSignaling(nativeHandle, relay, seed, token, alias)
|
||||
}
|
||||
|
||||
/**
|
||||
* Place a direct call to a peer by fingerprint.
|
||||
* Requires [startSignaling] to have been called first.
|
||||
*
|
||||
* @return 0 on success, -1 on error
|
||||
*/
|
||||
fun placeCall(targetFingerprint: String): Int {
|
||||
check(nativeHandle != 0L) { "Engine not initialized" }
|
||||
return nativePlaceCall(nativeHandle, targetFingerprint)
|
||||
}
|
||||
|
||||
/**
|
||||
* Answer an incoming direct call.
|
||||
*
|
||||
* @param callId The call ID from the incoming call (available in stats.incoming_call_id)
|
||||
* @param mode 0=Reject, 1=AcceptTrusted (P2P in Phase 2), 2=AcceptGeneric (relay-mediated)
|
||||
* @return 0 on success, -1 on error
|
||||
*/
|
||||
fun answerCall(callId: String, mode: Int = 2): Int {
|
||||
check(nativeHandle != 0L) { "Engine not initialized" }
|
||||
return nativeAnswerCall(nativeHandle, callId, mode)
|
||||
}
|
||||
|
||||
companion object {
|
||||
init {
|
||||
|
||||
141
android/app/src/main/java/com/wzp/net/NetworkMonitor.kt
Normal file
141
android/app/src/main/java/com/wzp/net/NetworkMonitor.kt
Normal file
@@ -0,0 +1,141 @@
|
||||
package com.wzp.net
|
||||
|
||||
import android.content.Context
|
||||
import android.net.ConnectivityManager
|
||||
import android.net.Network
|
||||
import android.net.NetworkCapabilities
|
||||
import android.net.NetworkRequest
|
||||
import android.os.Handler
|
||||
import android.os.Looper
|
||||
|
||||
/**
|
||||
* Monitors network connectivity changes via [ConnectivityManager.NetworkCallback]
|
||||
* and classifies the active transport (WiFi, LTE, 5G, 3G).
|
||||
*
|
||||
* Callbacks fire on the main looper so callers can safely update UI state or
|
||||
* dispatch to a native engine from any callback.
|
||||
*
|
||||
* Usage:
|
||||
* 1. Set [onNetworkChanged] to receive `(type: Int, downlinkKbps: Int)` events
|
||||
* 2. Optionally set [onIpChanged] for IP address change events (mid-call ICE refresh)
|
||||
* 3. Call [register] when the call starts
|
||||
* 4. Call [unregister] when the call ends
|
||||
*/
|
||||
class NetworkMonitor(context: Context) {
|
||||
|
||||
private val cm = context.getSystemService(Context.CONNECTIVITY_SERVICE) as ConnectivityManager
|
||||
private val mainHandler = Handler(Looper.getMainLooper())
|
||||
|
||||
/**
|
||||
* Called when the network transport type or bandwidth changes.
|
||||
* `type` constants match the Rust `NetworkContext` enum ordinals.
|
||||
*/
|
||||
var onNetworkChanged: ((type: Int, downlinkKbps: Int) -> Unit)? = null
|
||||
|
||||
/**
|
||||
* Called when the device's IP address changes (link properties changed).
|
||||
* Useful for triggering mid-call ICE candidate re-gathering.
|
||||
*/
|
||||
var onIpChanged: (() -> Unit)? = null
|
||||
|
||||
// Track the last emitted type to avoid redundant callbacks
|
||||
@Volatile
|
||||
private var lastEmittedType: Int = TYPE_UNKNOWN
|
||||
|
||||
private val callback = object : ConnectivityManager.NetworkCallback() {
|
||||
override fun onAvailable(network: Network) {
|
||||
classifyAndEmit(network)
|
||||
}
|
||||
|
||||
override fun onCapabilitiesChanged(network: Network, caps: NetworkCapabilities) {
|
||||
classifyFromCaps(caps)
|
||||
}
|
||||
|
||||
override fun onLinkPropertiesChanged(
|
||||
network: Network,
|
||||
linkProperties: android.net.LinkProperties
|
||||
) {
|
||||
// IP address may have changed — notify for ICE refresh
|
||||
onIpChanged?.invoke()
|
||||
// Also re-classify in case the transport changed simultaneously
|
||||
classifyAndEmit(network)
|
||||
}
|
||||
|
||||
override fun onLost(network: Network) {
|
||||
lastEmittedType = TYPE_NONE
|
||||
onNetworkChanged?.invoke(TYPE_NONE, 0)
|
||||
}
|
||||
}
|
||||
|
||||
// -- Public API -----------------------------------------------------------
|
||||
|
||||
/** Register the network callback. Call when a call starts. */
|
||||
fun register() {
|
||||
val request = NetworkRequest.Builder()
|
||||
.addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
|
||||
.build()
|
||||
cm.registerNetworkCallback(request, callback, mainHandler)
|
||||
}
|
||||
|
||||
/** Unregister the network callback. Call when the call ends. */
|
||||
fun unregister() {
|
||||
try {
|
||||
cm.unregisterNetworkCallback(callback)
|
||||
} catch (_: IllegalArgumentException) {
|
||||
// Already unregistered — safe to ignore
|
||||
}
|
||||
}
|
||||
|
||||
// -- Classification -------------------------------------------------------
|
||||
|
||||
private fun classifyAndEmit(network: Network) {
|
||||
val caps = cm.getNetworkCapabilities(network) ?: return
|
||||
classifyFromCaps(caps)
|
||||
}
|
||||
|
||||
private fun classifyFromCaps(caps: NetworkCapabilities) {
|
||||
val type = when {
|
||||
caps.hasTransport(NetworkCapabilities.TRANSPORT_WIFI) -> TYPE_WIFI
|
||||
caps.hasTransport(NetworkCapabilities.TRANSPORT_ETHERNET) -> TYPE_WIFI // treat as WiFi
|
||||
caps.hasTransport(NetworkCapabilities.TRANSPORT_CELLULAR) -> classifyCellular(caps)
|
||||
else -> TYPE_UNKNOWN
|
||||
}
|
||||
val bw = caps.getLinkDownstreamBandwidthKbps()
|
||||
|
||||
// Deduplicate: only emit when the transport type actually changes
|
||||
if (type != lastEmittedType) {
|
||||
lastEmittedType = type
|
||||
onNetworkChanged?.invoke(type, bw)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Approximate cellular generation from reported downstream bandwidth.
|
||||
* This avoids requiring READ_PHONE_STATE permission (needed for
|
||||
* TelephonyManager.getNetworkType on API 30+).
|
||||
*
|
||||
* Thresholds are conservative — carriers over-report bandwidth, so we
|
||||
* classify based on what's actually usable for VoIP:
|
||||
* - >= 100 Mbps → 5G NR
|
||||
* - >= 10 Mbps → LTE
|
||||
* - < 10 Mbps → 3G or worse
|
||||
*/
|
||||
private fun classifyCellular(caps: NetworkCapabilities): Int {
|
||||
val bw = caps.getLinkDownstreamBandwidthKbps()
|
||||
return when {
|
||||
bw >= 100_000 -> TYPE_CELLULAR_5G
|
||||
bw >= 10_000 -> TYPE_CELLULAR_LTE
|
||||
else -> TYPE_CELLULAR_3G
|
||||
}
|
||||
}
|
||||
|
||||
companion object {
|
||||
/** Constants matching Rust `NetworkContext` enum ordinals. */
|
||||
const val TYPE_WIFI = 0
|
||||
const val TYPE_CELLULAR_LTE = 1
|
||||
const val TYPE_CELLULAR_5G = 2
|
||||
const val TYPE_CELLULAR_3G = 3
|
||||
const val TYPE_UNKNOWN = 4
|
||||
const val TYPE_NONE = 5
|
||||
}
|
||||
}
|
||||
12
android/app/src/main/java/com/wzp/net/RelayPinger.kt
Normal file
12
android/app/src/main/java/com/wzp/net/RelayPinger.kt
Normal file
@@ -0,0 +1,12 @@
|
||||
package com.wzp.net
|
||||
|
||||
// Relay pinging is now done via WzpEngine.pingRelay() (instance method).
|
||||
// This file kept for the data class only.
|
||||
|
||||
object RelayPinger {
|
||||
data class PingResult(
|
||||
val rttMs: Int,
|
||||
val reachable: Boolean,
|
||||
val serverFingerprint: String = "",
|
||||
)
|
||||
}
|
||||
@@ -1,8 +1,10 @@
|
||||
package com.wzp.ui.call
|
||||
|
||||
import android.Manifest
|
||||
import android.content.Intent
|
||||
import android.content.pm.PackageManager
|
||||
import android.os.Bundle
|
||||
import android.util.Log
|
||||
import android.widget.Toast
|
||||
import androidx.activity.ComponentActivity
|
||||
import androidx.activity.compose.setContent
|
||||
@@ -15,8 +17,18 @@ import androidx.compose.material3.dynamicLightColorScheme
|
||||
import androidx.compose.material3.lightColorScheme
|
||||
import androidx.compose.foundation.isSystemInDarkTheme
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.core.content.ContextCompat
|
||||
import androidx.core.content.FileProvider
|
||||
import androidx.lifecycle.Lifecycle
|
||||
import androidx.lifecycle.lifecycleScope
|
||||
import androidx.lifecycle.repeatOnLifecycle
|
||||
import com.wzp.ui.settings.SettingsScreen
|
||||
import kotlinx.coroutines.launch
|
||||
|
||||
/**
|
||||
* Main activity hosting the in-call Compose UI.
|
||||
@@ -26,6 +38,10 @@ import androidx.core.content.ContextCompat
|
||||
*/
|
||||
class CallActivity : ComponentActivity() {
|
||||
|
||||
companion object {
|
||||
private const val TAG = "CallActivity"
|
||||
}
|
||||
|
||||
private val viewModel: CallViewModel by viewModels()
|
||||
|
||||
private val audioPermissionLauncher = registerForActivityResult(
|
||||
@@ -43,12 +59,19 @@ class CallActivity : ComponentActivity() {
|
||||
|
||||
setContent {
|
||||
WzpTheme {
|
||||
InCallScreen(
|
||||
viewModel = viewModel,
|
||||
onHangUp = {
|
||||
viewModel.stopCall()
|
||||
}
|
||||
)
|
||||
var showSettings by remember { mutableStateOf(false) }
|
||||
if (showSettings) {
|
||||
SettingsScreen(
|
||||
viewModel = viewModel,
|
||||
onBack = { showSettings = false }
|
||||
)
|
||||
} else {
|
||||
InCallScreen(
|
||||
viewModel = viewModel,
|
||||
onHangUp = { viewModel.stopCall() },
|
||||
onOpenSettings = { showSettings = true }
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,6 +80,45 @@ class CallActivity : ComponentActivity() {
|
||||
) {
|
||||
audioPermissionLauncher.launch(Manifest.permission.RECORD_AUDIO)
|
||||
}
|
||||
|
||||
// Watch for debug zip ready → launch email intent
|
||||
lifecycleScope.launch {
|
||||
repeatOnLifecycle(Lifecycle.State.STARTED) {
|
||||
viewModel.debugZipReady.collect { zipFile ->
|
||||
if (zipFile != null && zipFile.exists()) {
|
||||
Log.i(TAG, "debug zip ready: ${zipFile.absolutePath} (${zipFile.length()} bytes)")
|
||||
launchEmailIntent(zipFile)
|
||||
viewModel.onDebugReportSent()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun launchEmailIntent(zipFile: java.io.File) {
|
||||
try {
|
||||
val authority = "${applicationContext.packageName}.fileprovider"
|
||||
Log.i(TAG, "FileProvider authority: $authority, file: ${zipFile.absolutePath}")
|
||||
val uri = FileProvider.getUriForFile(this, authority, zipFile)
|
||||
Log.i(TAG, "FileProvider URI: $uri")
|
||||
|
||||
val intent = Intent(Intent.ACTION_SEND).apply {
|
||||
type = "message/rfc822"
|
||||
putExtra(Intent.EXTRA_EMAIL, arrayOf("manwefarm@gmail.com"))
|
||||
putExtra(Intent.EXTRA_SUBJECT, "WZ Phone Debug Report - ${zipFile.name}")
|
||||
putExtra(
|
||||
Intent.EXTRA_TEXT,
|
||||
"Debug report attached.\n\nContains: call recordings (WAV), RMS histograms (CSV), logcat, stats."
|
||||
)
|
||||
putExtra(Intent.EXTRA_STREAM, uri)
|
||||
addFlags(Intent.FLAG_GRANT_READ_URI_PERMISSION)
|
||||
}
|
||||
startActivity(Intent.createChooser(intent, "Send debug report"))
|
||||
Log.i(TAG, "email intent launched")
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "email intent failed", e)
|
||||
Toast.makeText(this, "Failed to launch email: ${e.message}", Toast.LENGTH_LONG).show()
|
||||
}
|
||||
}
|
||||
|
||||
override fun onDestroy() {
|
||||
|
||||
@@ -5,11 +5,16 @@ import android.util.Log
|
||||
import androidx.lifecycle.ViewModel
|
||||
import androidx.lifecycle.viewModelScope
|
||||
import com.wzp.audio.AudioPipeline
|
||||
import com.wzp.audio.AudioRoute
|
||||
import com.wzp.audio.AudioRouteManager
|
||||
import com.wzp.data.SettingsRepository
|
||||
import com.wzp.debug.DebugReporter
|
||||
import com.wzp.engine.CallStats
|
||||
import com.wzp.service.CallService
|
||||
import com.wzp.engine.WzpCallback
|
||||
import com.wzp.engine.WzpEngine
|
||||
import com.wzp.net.NetworkMonitor
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.Job
|
||||
import kotlinx.coroutines.delay
|
||||
import kotlinx.coroutines.flow.MutableStateFlow
|
||||
@@ -17,20 +22,37 @@ import kotlinx.coroutines.flow.StateFlow
|
||||
import kotlinx.coroutines.flow.asStateFlow
|
||||
import kotlinx.coroutines.isActive
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlinx.coroutines.withContext
|
||||
import org.json.JSONObject
|
||||
import java.io.File
|
||||
import java.net.Inet4Address
|
||||
import java.net.Inet6Address
|
||||
import java.net.InetAddress
|
||||
|
||||
data class ServerEntry(val address: String, val label: String)
|
||||
|
||||
data class PingResult(
|
||||
val rttMs: Int,
|
||||
val serverFingerprint: String = "",
|
||||
val reachable: Boolean = rttMs > 0,
|
||||
)
|
||||
|
||||
enum class LockStatus { UNKNOWN, OFFLINE, NEW, VERIFIED, CHANGED }
|
||||
|
||||
class CallViewModel : ViewModel(), WzpCallback {
|
||||
|
||||
private var engine: WzpEngine? = null
|
||||
private var engineInitialized = false
|
||||
private var audioPipeline: AudioPipeline? = null
|
||||
private var audioRouteManager: AudioRouteManager? = null
|
||||
private var networkMonitor: NetworkMonitor? = null
|
||||
private var audioStarted = false
|
||||
private var appContext: Context? = null
|
||||
private var settings: SettingsRepository? = null
|
||||
private var debugReporter: DebugReporter? = null
|
||||
private var lastStatsJson: String = "{}"
|
||||
private var lastCallDuration: Double = 0.0
|
||||
private var lastCallServer: String = ""
|
||||
|
||||
private val _callState = MutableStateFlow(0)
|
||||
val callState: StateFlow<Int> get() = _callState.asStateFlow()
|
||||
@@ -41,6 +63,9 @@ class CallViewModel : ViewModel(), WzpCallback {
|
||||
private val _isSpeaker = MutableStateFlow(false)
|
||||
val isSpeaker: StateFlow<Boolean> = _isSpeaker.asStateFlow()
|
||||
|
||||
private val _audioRoute = MutableStateFlow(AudioRoute.EARPIECE)
|
||||
val audioRoute: StateFlow<AudioRoute> = _audioRoute.asStateFlow()
|
||||
|
||||
private val _stats = MutableStateFlow(CallStats())
|
||||
val stats: StateFlow<CallStats> = _stats.asStateFlow()
|
||||
|
||||
@@ -62,21 +87,142 @@ class CallViewModel : ViewModel(), WzpCallback {
|
||||
private val _preferIPv6 = MutableStateFlow(false)
|
||||
val preferIPv6: StateFlow<Boolean> = _preferIPv6.asStateFlow()
|
||||
|
||||
private val _recentRooms = MutableStateFlow<List<com.wzp.data.SettingsRepository.RecentRoom>>(emptyList())
|
||||
val recentRooms: StateFlow<List<com.wzp.data.SettingsRepository.RecentRoom>> = _recentRooms.asStateFlow()
|
||||
|
||||
/** Ping results keyed by server address. */
|
||||
private val _pingResults = MutableStateFlow<Map<String, PingResult>>(emptyMap())
|
||||
val pingResults: StateFlow<Map<String, PingResult>> = _pingResults.asStateFlow()
|
||||
|
||||
/** Known server fingerprints (TOFU). */
|
||||
private val _knownFingerprints = MutableStateFlow<Map<String, String>>(emptyMap())
|
||||
|
||||
private val _playoutGainDb = MutableStateFlow(0f)
|
||||
val playoutGainDb: StateFlow<Float> = _playoutGainDb.asStateFlow()
|
||||
|
||||
private val _captureGainDb = MutableStateFlow(0f)
|
||||
val captureGainDb: StateFlow<Float> = _captureGainDb.asStateFlow()
|
||||
|
||||
private val _alias = MutableStateFlow("")
|
||||
val alias: StateFlow<String> = _alias.asStateFlow()
|
||||
|
||||
private val _seedHex = MutableStateFlow("")
|
||||
val seedHex: StateFlow<String> = _seedHex.asStateFlow()
|
||||
|
||||
private val _aecEnabled = MutableStateFlow(true)
|
||||
val aecEnabled: StateFlow<Boolean> = _aecEnabled.asStateFlow()
|
||||
|
||||
private val _debugRecording = MutableStateFlow(false)
|
||||
val debugRecording: StateFlow<Boolean> = _debugRecording.asStateFlow()
|
||||
|
||||
// Quality profile index (matches JNI bridge profile_from_int)
|
||||
private val _codecChoice = MutableStateFlow(0)
|
||||
val codecChoice: StateFlow<Int> = _codecChoice.asStateFlow()
|
||||
|
||||
/** Key-change warning dialog state. */
|
||||
data class KeyWarningInfo(val address: String, val oldFp: String, val newFp: String)
|
||||
private val _keyWarning = MutableStateFlow<KeyWarningInfo?>(null)
|
||||
val keyWarning: StateFlow<KeyWarningInfo?> = _keyWarning.asStateFlow()
|
||||
|
||||
/** True when a call just ended and debug report can be sent. */
|
||||
private val _debugReportAvailable = MutableStateFlow(false)
|
||||
val debugReportAvailable: StateFlow<Boolean> = _debugReportAvailable.asStateFlow()
|
||||
|
||||
/** Status: null=idle, "Preparing..."=in progress, "ready"=zip ready, "Error:..."=failed */
|
||||
private val _debugReportStatus = MutableStateFlow<String?>(null)
|
||||
val debugReportStatus: StateFlow<String?> = _debugReportStatus.asStateFlow()
|
||||
|
||||
/** The zip file ready to be emailed. Set by sendDebugReport, consumed by Activity. */
|
||||
private val _debugZipReady = MutableStateFlow<File?>(null)
|
||||
val debugZipReady: StateFlow<File?> = _debugZipReady.asStateFlow()
|
||||
|
||||
private var statsJob: Job? = null
|
||||
|
||||
// ── Direct calling state ──
|
||||
/** 0=room mode, 1=direct call mode */
|
||||
private val _callMode = MutableStateFlow(0)
|
||||
val callMode: StateFlow<Int> = _callMode.asStateFlow()
|
||||
|
||||
/** Target fingerprint for direct call */
|
||||
private val _targetFingerprint = MutableStateFlow("")
|
||||
val targetFingerprint: StateFlow<String> = _targetFingerprint.asStateFlow()
|
||||
|
||||
/** Signal connection state: 0=idle, 5=registered, 6=ringing, 7=incoming */
|
||||
private val _signalState = MutableStateFlow(0)
|
||||
val signalState: StateFlow<Int> = _signalState.asStateFlow()
|
||||
|
||||
/** Incoming call info */
|
||||
private val _incomingCallId = MutableStateFlow<String?>(null)
|
||||
val incomingCallId: StateFlow<String?> = _incomingCallId.asStateFlow()
|
||||
|
||||
private val _incomingCallerFp = MutableStateFlow<String?>(null)
|
||||
val incomingCallerFp: StateFlow<String?> = _incomingCallerFp.asStateFlow()
|
||||
|
||||
private val _incomingCallerAlias = MutableStateFlow<String?>(null)
|
||||
val incomingCallerAlias: StateFlow<String?> = _incomingCallerAlias.asStateFlow()
|
||||
|
||||
fun setCallMode(mode: Int) { _callMode.value = mode }
|
||||
fun setTargetFingerprint(fp: String) { _targetFingerprint.value = fp }
|
||||
|
||||
/** Register on relay for direct calls */
|
||||
fun registerForCalls() {
|
||||
if (engine == null) {
|
||||
engine = WzpEngine(this).also { it.init() }
|
||||
}
|
||||
val serverIdx = _selectedServer.value
|
||||
val serverList = _servers.value
|
||||
if (serverIdx >= serverList.size) return
|
||||
|
||||
val relay = serverList[serverIdx].address
|
||||
val seed = _seedHex.value
|
||||
val alias = _alias.value
|
||||
|
||||
viewModelScope.launch(Dispatchers.IO) {
|
||||
val resolvedRelay = resolveToIp(relay) ?: relay
|
||||
val result = engine?.startSignaling(resolvedRelay, seed, "", alias)
|
||||
if (result == 0) {
|
||||
_signalState.value = 5 // Registered
|
||||
startStatsPolling()
|
||||
} else {
|
||||
_errorMessage.value = "Failed to register on relay"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Place a direct call to the target fingerprint */
|
||||
fun placeDirectCall() {
|
||||
val target = _targetFingerprint.value.trim()
|
||||
if (target.isEmpty()) {
|
||||
_errorMessage.value = "Enter a fingerprint to call"
|
||||
return
|
||||
}
|
||||
engine?.placeCall(target)
|
||||
_signalState.value = 6 // Ringing
|
||||
}
|
||||
|
||||
/** Answer an incoming direct call */
|
||||
fun answerIncomingCall(mode: Int = 2) {
|
||||
val callId = _incomingCallId.value ?: return
|
||||
engine?.answerCall(callId, mode)
|
||||
}
|
||||
|
||||
/** Reject an incoming direct call */
|
||||
fun rejectIncomingCall() {
|
||||
val callId = _incomingCallId.value ?: return
|
||||
engine?.answerCall(callId, 0) // 0 = Reject
|
||||
_signalState.value = 5 // Back to registered
|
||||
_incomingCallId.value = null
|
||||
_incomingCallerFp.value = null
|
||||
_incomingCallerAlias.value = null
|
||||
}
|
||||
|
||||
companion object {
|
||||
private const val TAG = "WzpCall"
|
||||
val DEFAULT_SERVERS = listOf(
|
||||
ServerEntry("172.16.81.175:4433", "LAN (172.16.81.175)"),
|
||||
ServerEntry("193.180.213.68:4433", "Pangolin (IP)"),
|
||||
)
|
||||
const val DEFAULT_ROOM = "android"
|
||||
const val DEFAULT_ROOM = "general"
|
||||
}
|
||||
|
||||
fun setContext(context: Context) {
|
||||
@@ -86,22 +232,64 @@ class CallViewModel : ViewModel(), WzpCallback {
|
||||
audioPipeline = AudioPipeline(appCtx)
|
||||
}
|
||||
if (audioRouteManager == null) {
|
||||
audioRouteManager = AudioRouteManager(appCtx)
|
||||
audioRouteManager = AudioRouteManager(appCtx).also { arm ->
|
||||
arm.onRouteChanged = { route ->
|
||||
_audioRoute.value = route
|
||||
_isSpeaker.value = (route == AudioRoute.SPEAKER)
|
||||
}
|
||||
}
|
||||
}
|
||||
if (networkMonitor == null) {
|
||||
networkMonitor = NetworkMonitor(appCtx).also { nm ->
|
||||
nm.onNetworkChanged = { type, bw ->
|
||||
engine?.onNetworkChanged(type, bw)
|
||||
}
|
||||
}
|
||||
}
|
||||
if (debugReporter == null) {
|
||||
debugReporter = DebugReporter(appCtx)
|
||||
}
|
||||
if (settings == null) {
|
||||
settings = SettingsRepository(appCtx)
|
||||
loadSettings()
|
||||
}
|
||||
}
|
||||
|
||||
private fun loadSettings() {
|
||||
val s = settings ?: return
|
||||
s.loadServers()?.let { saved ->
|
||||
if (saved.isNotEmpty()) _servers.value = saved
|
||||
}
|
||||
_selectedServer.value = s.loadSelectedServer().coerceIn(0, _servers.value.lastIndex)
|
||||
_roomName.value = s.loadRoom()
|
||||
_alias.value = s.getOrCreateAlias()
|
||||
_preferIPv6.value = s.loadPreferIPv6()
|
||||
_playoutGainDb.value = s.loadPlayoutGain()
|
||||
_captureGainDb.value = s.loadCaptureGain()
|
||||
_seedHex.value = s.getOrCreateSeedHex()
|
||||
_aecEnabled.value = s.loadAecEnabled()
|
||||
_debugRecording.value = s.loadDebugRecording()
|
||||
_codecChoice.value = s.loadCodecChoice()
|
||||
_recentRooms.value = s.loadRecentRooms()
|
||||
}
|
||||
|
||||
fun selectServer(index: Int) {
|
||||
if (index in _servers.value.indices) {
|
||||
_selectedServer.value = index
|
||||
settings?.saveSelectedServer(index)
|
||||
}
|
||||
}
|
||||
|
||||
fun setPreferIPv6(prefer: Boolean) { _preferIPv6.value = prefer }
|
||||
fun setPreferIPv6(prefer: Boolean) {
|
||||
_preferIPv6.value = prefer
|
||||
settings?.savePreferIPv6(prefer)
|
||||
}
|
||||
|
||||
fun addServer(hostPort: String, label: String) {
|
||||
val current = _servers.value.toMutableList()
|
||||
current.add(ServerEntry(hostPort, label))
|
||||
_servers.value = current
|
||||
settings?.saveServers(current)
|
||||
}
|
||||
|
||||
fun removeServer(index: Int) {
|
||||
@@ -113,19 +301,123 @@ class CallViewModel : ViewModel(), WzpCallback {
|
||||
if (_selectedServer.value >= current.size) {
|
||||
_selectedServer.value = 0
|
||||
}
|
||||
settings?.saveServers(current)
|
||||
settings?.saveSelectedServer(_selectedServer.value)
|
||||
}
|
||||
}
|
||||
|
||||
fun setRoomName(name: String) { _roomName.value = name }
|
||||
/** Batch-apply servers and selection from Settings draft state. */
|
||||
fun applyServers(servers: List<ServerEntry>, selected: Int) {
|
||||
_servers.value = servers
|
||||
_selectedServer.value = selected.coerceIn(0, servers.lastIndex)
|
||||
settings?.saveServers(servers)
|
||||
settings?.saveSelectedServer(_selectedServer.value)
|
||||
}
|
||||
|
||||
/**
|
||||
* Ping all servers via native QUIC. Requires engine to be initialized.
|
||||
* Creates engine if needed, pings, keeps engine alive for subsequent Connect.
|
||||
*/
|
||||
fun pingAllServers() {
|
||||
viewModelScope.launch {
|
||||
// Ensure engine exists
|
||||
if (engine == null || engine?.isInitialized != true) {
|
||||
try {
|
||||
engine = WzpEngine(this@CallViewModel).also { it.init() }
|
||||
engineInitialized = true
|
||||
} catch (e: Exception) {
|
||||
Log.w(TAG, "engine init for ping failed: $e")
|
||||
return@launch
|
||||
}
|
||||
}
|
||||
val eng = engine ?: return@launch
|
||||
|
||||
val results = mutableMapOf<String, PingResult>()
|
||||
val known = mutableMapOf<String, String>()
|
||||
_servers.value.forEach { server ->
|
||||
val json = withContext(Dispatchers.IO) {
|
||||
eng.pingRelay(server.address)
|
||||
}
|
||||
if (json != null) {
|
||||
try {
|
||||
val obj = JSONObject(json)
|
||||
val rtt = obj.getInt("rtt_ms")
|
||||
val fp = obj.optString("server_fingerprint", "")
|
||||
results[server.address] = PingResult(rttMs = rtt, serverFingerprint = fp)
|
||||
// TOFU
|
||||
if (fp.isNotEmpty()) {
|
||||
val saved = settings?.loadServerFingerprint(server.address)
|
||||
if (saved == null) settings?.saveServerFingerprint(server.address, fp)
|
||||
known[server.address] = saved ?: fp
|
||||
}
|
||||
} catch (_: Exception) {}
|
||||
}
|
||||
}
|
||||
_pingResults.value = results
|
||||
_knownFingerprints.value = known
|
||||
}
|
||||
}
|
||||
|
||||
/** Load saved TOFU fingerprints. */
|
||||
fun loadSavedFingerprints() {
|
||||
val known = mutableMapOf<String, String>()
|
||||
_servers.value.forEach { server ->
|
||||
settings?.loadServerFingerprint(server.address)?.let {
|
||||
known[server.address] = it
|
||||
}
|
||||
}
|
||||
_knownFingerprints.value = known
|
||||
}
|
||||
|
||||
/** Get lock status for a server. */
|
||||
fun lockStatus(address: String): LockStatus {
|
||||
val pr = _pingResults.value[address] ?: return LockStatus.UNKNOWN
|
||||
if (!pr.reachable) return LockStatus.OFFLINE
|
||||
val known = _knownFingerprints.value[address] ?: return LockStatus.NEW
|
||||
if (pr.serverFingerprint.isEmpty()) return LockStatus.NEW
|
||||
return if (pr.serverFingerprint == known) LockStatus.VERIFIED else LockStatus.CHANGED
|
||||
}
|
||||
|
||||
fun setRoomName(name: String) {
|
||||
_roomName.value = name
|
||||
settings?.saveRoom(name)
|
||||
}
|
||||
|
||||
fun setPlayoutGainDb(db: Float) {
|
||||
_playoutGainDb.value = db
|
||||
audioPipeline?.playoutGainDb = db
|
||||
settings?.savePlayoutGain(db)
|
||||
}
|
||||
|
||||
fun setCaptureGainDb(db: Float) {
|
||||
_captureGainDb.value = db
|
||||
audioPipeline?.captureGainDb = db
|
||||
settings?.saveCaptureGain(db)
|
||||
}
|
||||
|
||||
fun setAlias(alias: String) {
|
||||
_alias.value = alias
|
||||
settings?.saveAlias(alias)
|
||||
}
|
||||
|
||||
fun restoreSeed(hex: String) {
|
||||
_seedHex.value = hex
|
||||
settings?.saveSeedHex(hex)
|
||||
}
|
||||
|
||||
fun setAecEnabled(enabled: Boolean) {
|
||||
_aecEnabled.value = enabled
|
||||
settings?.saveAecEnabled(enabled)
|
||||
}
|
||||
|
||||
fun setDebugRecording(enabled: Boolean) {
|
||||
_debugRecording.value = enabled
|
||||
settings?.saveDebugRecording(enabled)
|
||||
}
|
||||
|
||||
fun setCodecChoice(choice: Int) {
|
||||
_codecChoice.value = choice
|
||||
settings?.saveCodecChoice(choice)
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -166,25 +458,111 @@ class CallViewModel : ViewModel(), WzpCallback {
|
||||
/** Tear down engine and audio. Pass stopService=true to also stop the foreground service. */
|
||||
private fun teardown(stopService: Boolean = true) {
|
||||
Log.i(TAG, "teardown: stopping audio, stopService=$stopService")
|
||||
val hadCall = audioStarted
|
||||
CallService.onStopFromNotification = null
|
||||
stopAudio()
|
||||
stopAudio() // sets running=false (non-blocking)
|
||||
stopStatsPolling()
|
||||
|
||||
// Wait for audio threads to exit their loops before destroying the engine.
|
||||
// This guarantees no in-flight JNI calls to writeAudio/readAudio.
|
||||
val drained = audioPipeline?.awaitDrain() ?: true
|
||||
if (!drained) {
|
||||
Log.w(TAG, "teardown: audio threads did not drain in time")
|
||||
}
|
||||
audioPipeline = null
|
||||
|
||||
Log.i(TAG, "teardown: stopping engine")
|
||||
try { engine?.stopCall() } catch (e: Exception) { Log.w(TAG, "stopCall err: $e") }
|
||||
try { engine?.destroy() } catch (e: Exception) { Log.w(TAG, "destroy err: $e") }
|
||||
engine = null
|
||||
engineInitialized = false
|
||||
_callState.value = 0
|
||||
if (hadCall) {
|
||||
_debugReportAvailable.value = true
|
||||
}
|
||||
if (stopService) {
|
||||
try { appContext?.let { CallService.stop(it) } } catch (_: Exception) {}
|
||||
}
|
||||
Log.i(TAG, "teardown: done")
|
||||
}
|
||||
|
||||
/** Accept the new server key and proceed with the call. */
|
||||
fun acceptNewFingerprint() {
|
||||
val info = _keyWarning.value ?: return
|
||||
_knownFingerprints.value = _knownFingerprints.value.toMutableMap().also {
|
||||
it[info.address] = info.newFp
|
||||
}
|
||||
settings?.saveServerFingerprint(info.address, info.newFp)
|
||||
_keyWarning.value = null
|
||||
startCallInternal()
|
||||
}
|
||||
|
||||
fun dismissKeyWarning() {
|
||||
_keyWarning.value = null
|
||||
}
|
||||
|
||||
fun startCall() {
|
||||
val serverEntry = _servers.value[_selectedServer.value]
|
||||
// Check for key change before connecting
|
||||
val ls = lockStatus(serverEntry.address)
|
||||
if (ls == LockStatus.CHANGED) {
|
||||
val known = _knownFingerprints.value[serverEntry.address] ?: ""
|
||||
val current = _pingResults.value[serverEntry.address]?.serverFingerprint ?: ""
|
||||
_keyWarning.value = KeyWarningInfo(serverEntry.address, known, current)
|
||||
return
|
||||
}
|
||||
startCallInternal()
|
||||
}
|
||||
|
||||
/** Start a call to a specific relay + room (used by direct call setup). */
|
||||
private fun startCallInternal(relay: String, room: String) {
|
||||
Log.i(TAG, "startCallDirect: relay=$relay room=$room")
|
||||
try {
|
||||
// Don't teardown — keep the signal connection alive
|
||||
engine = WzpEngine(this)
|
||||
engine!!.init()
|
||||
engineInitialized = true
|
||||
_callState.value = 1
|
||||
_errorMessage.value = null
|
||||
try { appContext?.let { CallService.start(it) } } catch (e: Exception) {
|
||||
Log.w(TAG, "service start err: $e")
|
||||
}
|
||||
startStatsPolling()
|
||||
viewModelScope.launch(kotlinx.coroutines.Dispatchers.IO) {
|
||||
try {
|
||||
val seed = _seedHex.value
|
||||
val name = _alias.value
|
||||
val result = engine?.startCall(relay, room, seedHex = seed, alias = name, profile = _codecChoice.value) ?: -1
|
||||
CallService.onStopFromNotification = { stopCall() }
|
||||
if (result != 0) {
|
||||
_callState.value = 0
|
||||
_errorMessage.value = "Failed to connect to call room (code $result)"
|
||||
appContext?.let { CallService.stop(it) }
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "startCallDirect error", e)
|
||||
_callState.value = 0
|
||||
_errorMessage.value = "Engine error: ${e.message}"
|
||||
appContext?.let { CallService.stop(it) }
|
||||
}
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "startCallDirect error", e)
|
||||
_callState.value = 0
|
||||
_errorMessage.value = "Engine error: ${e.message}"
|
||||
}
|
||||
}
|
||||
|
||||
private fun startCallInternal() {
|
||||
val serverEntry = _servers.value[_selectedServer.value]
|
||||
val room = _roomName.value
|
||||
Log.i(TAG, "startCall: server=${serverEntry.address} room=$room")
|
||||
_debugReportAvailable.value = false
|
||||
_debugReportStatus.value = null
|
||||
lastCallServer = serverEntry.address
|
||||
settings?.addRecentRoom(serverEntry.address, room)
|
||||
_recentRooms.value = settings?.loadRecentRooms() ?: emptyList()
|
||||
debugReporter?.prepareForCall()
|
||||
try {
|
||||
// Teardown previous call but don't stop the service (we're about to restart it)
|
||||
teardown(stopService = false)
|
||||
@@ -203,8 +581,10 @@ class CallViewModel : ViewModel(), WzpCallback {
|
||||
viewModelScope.launch(kotlinx.coroutines.Dispatchers.IO) {
|
||||
try {
|
||||
val relay = resolveToIp(serverEntry.address)
|
||||
Log.i(TAG, "startCall: resolved=$relay, calling engine.startCall")
|
||||
val result = engine?.startCall(relay, room) ?: -1
|
||||
val seed = _seedHex.value
|
||||
val name = _alias.value
|
||||
Log.i(TAG, "startCall: resolved=$relay, alias=$name, calling engine.startCall")
|
||||
val result = engine?.startCall(relay, room, seedHex = seed, alias = name, profile = _codecChoice.value) ?: -1
|
||||
Log.i(TAG, "startCall: engine returned $result")
|
||||
// Only wire up notification callback after engine is running
|
||||
CallService.onStopFromNotification = { stopCall() }
|
||||
@@ -245,8 +625,63 @@ class CallViewModel : ViewModel(), WzpCallback {
|
||||
audioRouteManager?.setSpeaker(newSpeaker)
|
||||
}
|
||||
|
||||
/** Cycle audio output: Earpiece → Speaker → Bluetooth (if available) → Earpiece. */
|
||||
fun cycleAudioRoute() {
|
||||
val routes = audioRouteManager?.availableRoutes() ?: return
|
||||
val currentIdx = routes.indexOf(_audioRoute.value)
|
||||
val next = routes[(currentIdx + 1) % routes.size]
|
||||
when (next) {
|
||||
AudioRoute.EARPIECE -> {
|
||||
audioRouteManager?.setBluetoothSco(false)
|
||||
audioRouteManager?.setSpeaker(false)
|
||||
}
|
||||
AudioRoute.SPEAKER -> {
|
||||
audioRouteManager?.setSpeaker(true)
|
||||
}
|
||||
AudioRoute.BLUETOOTH -> {
|
||||
audioRouteManager?.setBluetoothSco(true)
|
||||
}
|
||||
}
|
||||
_audioRoute.value = next
|
||||
_isSpeaker.value = (next == AudioRoute.SPEAKER)
|
||||
}
|
||||
|
||||
fun clearError() { _errorMessage.value = null }
|
||||
|
||||
fun sendDebugReport() {
|
||||
val reporter = debugReporter ?: return
|
||||
_debugReportStatus.value = "Preparing debug report..."
|
||||
viewModelScope.launch(kotlinx.coroutines.Dispatchers.IO) {
|
||||
val zipFile = reporter.collectZip(
|
||||
callDurationSecs = lastCallDuration,
|
||||
finalStatsJson = lastStatsJson,
|
||||
aecEnabled = _aecEnabled.value,
|
||||
alias = _alias.value,
|
||||
server = lastCallServer,
|
||||
room = _roomName.value
|
||||
)
|
||||
if (zipFile != null) {
|
||||
_debugZipReady.value = zipFile
|
||||
_debugReportStatus.value = "ready"
|
||||
} else {
|
||||
_debugReportStatus.value = "Error: failed to create zip"
|
||||
}
|
||||
_debugReportAvailable.value = false
|
||||
}
|
||||
}
|
||||
|
||||
/** Called by Activity after email intent is launched. */
|
||||
fun onDebugReportSent() {
|
||||
_debugZipReady.value = null
|
||||
_debugReportStatus.value = null
|
||||
}
|
||||
|
||||
fun dismissDebugReport() {
|
||||
_debugReportAvailable.value = false
|
||||
_debugReportStatus.value = null
|
||||
_debugZipReady.value = null
|
||||
}
|
||||
|
||||
// WzpCallback
|
||||
override fun onCallStateChanged(state: Int) { _callState.value = state }
|
||||
override fun onQualityTierChanged(tier: Int) { _qualityTier.value = tier }
|
||||
@@ -260,19 +695,23 @@ class CallViewModel : ViewModel(), WzpCallback {
|
||||
audioPipeline = AudioPipeline(ctx).also {
|
||||
it.playoutGainDb = _playoutGainDb.value
|
||||
it.captureGainDb = _captureGainDb.value
|
||||
it.aecEnabled = _aecEnabled.value
|
||||
it.debugRecording = _debugRecording.value
|
||||
it.start(e)
|
||||
}
|
||||
audioRouteManager?.register()
|
||||
networkMonitor?.register()
|
||||
audioStarted = true
|
||||
}
|
||||
|
||||
private fun stopAudio() {
|
||||
if (!audioStarted) return
|
||||
audioPipeline?.stop()
|
||||
audioPipeline = null
|
||||
audioPipeline?.stop() // sets running=false; DON'T null — teardown needs awaitDrain()
|
||||
audioRouteManager?.unregister()
|
||||
networkMonitor?.unregister()
|
||||
audioRouteManager?.setSpeaker(false)
|
||||
_isSpeaker.value = false
|
||||
_audioRoute.value = AudioRoute.EARPIECE
|
||||
audioStarted = false
|
||||
}
|
||||
|
||||
@@ -284,11 +723,34 @@ class CallViewModel : ViewModel(), WzpCallback {
|
||||
val json = engine?.getStats() ?: "{}"
|
||||
if (json.isNotEmpty()) {
|
||||
Log.d(TAG, "raw: $json")
|
||||
lastStatsJson = json
|
||||
val s = CallStats.fromJson(json)
|
||||
lastCallDuration = s.durationSecs
|
||||
_stats.value = s
|
||||
if (s.state != 0) {
|
||||
_callState.value = s.state
|
||||
}
|
||||
// Track signal state changes for direct calling
|
||||
if (s.state in 5..7) {
|
||||
_signalState.value = s.state
|
||||
}
|
||||
// Incoming call detection
|
||||
if (s.state == 7) { // IncomingCall
|
||||
_incomingCallId.value = s.incomingCallId
|
||||
_incomingCallerFp.value = s.incomingCallerFp
|
||||
_incomingCallerAlias.value = s.incomingCallerAlias
|
||||
}
|
||||
// CallSetup: auto-connect to media room
|
||||
if (s.state == 1 && s.incomingCallId != null && s.incomingCallId.contains("|")) {
|
||||
// Format: "relay_addr|room_name"
|
||||
val parts = s.incomingCallId.split("|", limit = 2)
|
||||
if (parts.size == 2) {
|
||||
val mediaRelay = parts[0]
|
||||
val mediaRoom = parts[1]
|
||||
Log.i(TAG, "CallSetup: connecting to $mediaRelay room $mediaRoom")
|
||||
startCallInternal(mediaRelay, mediaRoom)
|
||||
}
|
||||
}
|
||||
if (s.state == 2 && !audioStarted) {
|
||||
startAudio()
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
141
android/app/src/main/java/com/wzp/ui/components/Identicon.kt
Normal file
141
android/app/src/main/java/com/wzp/ui/components/Identicon.kt
Normal file
@@ -0,0 +1,141 @@
|
||||
package com.wzp.ui.components
|
||||
|
||||
import android.widget.Toast
|
||||
import androidx.compose.foundation.Canvas
|
||||
import androidx.compose.foundation.clickable
|
||||
import androidx.compose.foundation.layout.size
|
||||
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.draw.clip
|
||||
import androidx.compose.ui.geometry.Offset
|
||||
import androidx.compose.ui.geometry.Size
|
||||
import androidx.compose.ui.graphics.Color
|
||||
import androidx.compose.ui.platform.LocalClipboardManager
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.text.AnnotatedString
|
||||
import androidx.compose.ui.unit.Dp
|
||||
import androidx.compose.ui.unit.dp
|
||||
import kotlin.math.min
|
||||
|
||||
/**
|
||||
* Deterministic identicon — generates a unique 5x5 symmetric pattern
|
||||
* from a hex fingerprint string. Identical algorithm to the desktop
|
||||
* TypeScript implementation in identicon.ts.
|
||||
*/
|
||||
@Composable
|
||||
fun Identicon(
|
||||
fingerprint: String,
|
||||
size: Dp = 36.dp,
|
||||
clickToCopy: Boolean = true,
|
||||
modifier: Modifier = Modifier,
|
||||
) {
|
||||
val clipboard = LocalClipboardManager.current
|
||||
val context = LocalContext.current
|
||||
val bytes = hashBytes(fingerprint)
|
||||
val (bg, fg) = deriveColors(bytes)
|
||||
val grid = buildGrid(bytes)
|
||||
|
||||
Canvas(
|
||||
modifier = modifier
|
||||
.size(size)
|
||||
.clip(RoundedCornerShape(size * 0.12f))
|
||||
.then(
|
||||
if (clickToCopy && fingerprint.isNotEmpty()) {
|
||||
Modifier.clickable {
|
||||
clipboard.setText(AnnotatedString(fingerprint))
|
||||
Toast.makeText(context, "Copied", Toast.LENGTH_SHORT).show()
|
||||
}
|
||||
} else Modifier
|
||||
)
|
||||
) {
|
||||
val cellW = this.size.width / 5f
|
||||
val cellH = this.size.height / 5f
|
||||
|
||||
// Background
|
||||
drawRect(color = bg, size = this.size)
|
||||
|
||||
// Foreground cells
|
||||
for (y in 0 until 5) {
|
||||
for (x in 0 until 5) {
|
||||
if (grid[y][x]) {
|
||||
drawRect(
|
||||
color = fg,
|
||||
topLeft = Offset(x * cellW, y * cellH),
|
||||
size = Size(cellW, cellH),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Fingerprint text that copies to clipboard on tap.
|
||||
*/
|
||||
@Composable
|
||||
fun CopyableFingerprint(
|
||||
fingerprint: String,
|
||||
modifier: Modifier = Modifier,
|
||||
style: androidx.compose.ui.text.TextStyle = androidx.compose.material3.MaterialTheme.typography.bodySmall,
|
||||
color: Color = Color.Unspecified,
|
||||
) {
|
||||
val clipboard = LocalClipboardManager.current
|
||||
val context = LocalContext.current
|
||||
|
||||
androidx.compose.material3.Text(
|
||||
text = fingerprint,
|
||||
style = style,
|
||||
color = color,
|
||||
modifier = modifier.clickable {
|
||||
if (fingerprint.isNotEmpty()) {
|
||||
clipboard.setText(AnnotatedString(fingerprint))
|
||||
Toast.makeText(context, "Fingerprint copied", Toast.LENGTH_SHORT).show()
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
// --- Internal helpers (matching desktop identicon.ts) ---
|
||||
|
||||
private fun hashBytes(hex: String): List<Int> {
|
||||
val clean = hex.filter { it.isLetterOrDigit() }
|
||||
val bytes = mutableListOf<Int>()
|
||||
var i = 0
|
||||
while (i + 1 < clean.length) {
|
||||
val b = clean.substring(i, i + 2).toIntOrNull(16) ?: 0
|
||||
bytes.add(b)
|
||||
i += 2
|
||||
}
|
||||
// Pad to at least 16 bytes
|
||||
while (bytes.size < 16) bytes.add(0)
|
||||
return bytes
|
||||
}
|
||||
|
||||
private fun deriveColors(bytes: List<Int>): Pair<Color, Color> {
|
||||
val hue1 = bytes[0] * 360f / 256f
|
||||
val hue2 = (bytes[1] * 360f / 256f + 120f) % 360f
|
||||
val bg = hslToColor(hue1, 0.65f, 0.35f)
|
||||
val fg = hslToColor(hue2, 0.70f, 0.55f)
|
||||
return bg to fg
|
||||
}
|
||||
|
||||
private fun buildGrid(bytes: List<Int>): List<List<Boolean>> {
|
||||
return (0 until 5).map { y ->
|
||||
val left = (0 until 3).map { x ->
|
||||
val idx = 2 + y * 3 + x
|
||||
bytes[idx % bytes.size] > 128
|
||||
}
|
||||
// Mirror: col3 = col1, col4 = col0
|
||||
listOf(left[0], left[1], left[2], left[1], left[0])
|
||||
}
|
||||
}
|
||||
|
||||
private fun hslToColor(h: Float, s: Float, l: Float): Color {
|
||||
val k = { n: Float -> (n + h / 30f) % 12f }
|
||||
val a = s * min(l, 1f - l)
|
||||
val f = { n: Float ->
|
||||
l - a * maxOf(-1f, minOf(k(n) - 3f, minOf(9f - k(n), 1f)))
|
||||
}
|
||||
return Color(f(0f), f(8f), f(4f))
|
||||
}
|
||||
567
android/app/src/main/java/com/wzp/ui/settings/SettingsScreen.kt
Normal file
567
android/app/src/main/java/com/wzp/ui/settings/SettingsScreen.kt
Normal file
@@ -0,0 +1,567 @@
|
||||
package com.wzp.ui.settings
|
||||
|
||||
import androidx.compose.foundation.clickable
|
||||
import android.content.ClipData
|
||||
import android.content.ClipboardManager
|
||||
import android.content.Context
|
||||
import android.widget.Toast
|
||||
import androidx.compose.foundation.layout.Arrangement
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.ExperimentalLayoutApi
|
||||
import androidx.compose.foundation.layout.FlowRow
|
||||
import androidx.compose.foundation.layout.Row
|
||||
import androidx.compose.foundation.layout.Spacer
|
||||
import androidx.compose.foundation.layout.fillMaxSize
|
||||
import androidx.compose.foundation.layout.fillMaxWidth
|
||||
import androidx.compose.foundation.layout.height
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.foundation.layout.width
|
||||
import androidx.compose.foundation.rememberScrollState
|
||||
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||
import androidx.compose.foundation.verticalScroll
|
||||
import androidx.compose.material3.AlertDialog
|
||||
import androidx.compose.material3.Button
|
||||
import androidx.compose.material3.ButtonDefaults
|
||||
import androidx.compose.material3.Divider
|
||||
import androidx.compose.material3.RadioButton
|
||||
import androidx.compose.material3.FilledTonalButton
|
||||
import androidx.compose.material3.FilledTonalIconButton
|
||||
import androidx.compose.material3.IconButtonDefaults
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.OutlinedButton
|
||||
import androidx.compose.material3.OutlinedTextField
|
||||
import androidx.compose.material3.Slider
|
||||
import androidx.compose.material3.Surface
|
||||
import androidx.compose.material3.Switch
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.material3.TextButton
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.collectAsState
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableFloatStateOf
|
||||
import androidx.compose.runtime.mutableIntStateOf
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.runtime.toMutableStateList
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.graphics.Color
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.text.font.FontFamily
|
||||
import androidx.compose.ui.text.font.FontWeight
|
||||
import androidx.compose.ui.unit.dp
|
||||
import com.wzp.ui.call.CallViewModel
|
||||
import com.wzp.ui.call.ServerEntry
|
||||
|
||||
@OptIn(ExperimentalLayoutApi::class)
|
||||
@Composable
|
||||
fun SettingsScreen(
|
||||
viewModel: CallViewModel,
|
||||
onBack: () -> Unit
|
||||
) {
|
||||
val context = LocalContext.current
|
||||
|
||||
// Snapshot current values into local draft state
|
||||
val currentAlias by viewModel.alias.collectAsState()
|
||||
val currentSeedHex by viewModel.seedHex.collectAsState()
|
||||
val currentServers by viewModel.servers.collectAsState()
|
||||
val currentSelectedServer by viewModel.selectedServer.collectAsState()
|
||||
val currentRoomName by viewModel.roomName.collectAsState()
|
||||
val currentPreferIPv6 by viewModel.preferIPv6.collectAsState()
|
||||
val currentPlayoutGain by viewModel.playoutGainDb.collectAsState()
|
||||
val currentCaptureGain by viewModel.captureGainDb.collectAsState()
|
||||
val currentAecEnabled by viewModel.aecEnabled.collectAsState()
|
||||
|
||||
// Draft state — initialized from current values
|
||||
var draftAlias by remember { mutableStateOf(currentAlias) }
|
||||
var draftSeedHex by remember { mutableStateOf(currentSeedHex) }
|
||||
val draftServers = remember { currentServers.toMutableStateList() }
|
||||
var draftSelectedServer by remember { mutableIntStateOf(currentSelectedServer) }
|
||||
var draftRoomName by remember { mutableStateOf(currentRoomName) }
|
||||
var draftPreferIPv6 by remember { mutableStateOf(currentPreferIPv6) }
|
||||
var draftPlayoutGain by remember { mutableFloatStateOf(currentPlayoutGain) }
|
||||
var draftCaptureGain by remember { mutableFloatStateOf(currentCaptureGain) }
|
||||
var draftAecEnabled by remember { mutableStateOf(currentAecEnabled) }
|
||||
|
||||
// Track if anything changed
|
||||
val hasChanges = draftAlias != currentAlias ||
|
||||
draftSeedHex != currentSeedHex ||
|
||||
draftServers.toList() != currentServers ||
|
||||
draftSelectedServer != currentSelectedServer ||
|
||||
draftRoomName != currentRoomName ||
|
||||
draftPreferIPv6 != currentPreferIPv6 ||
|
||||
draftPlayoutGain != currentPlayoutGain ||
|
||||
draftCaptureGain != currentCaptureGain ||
|
||||
draftAecEnabled != currentAecEnabled
|
||||
|
||||
var showAddServerDialog by remember { mutableStateOf(false) }
|
||||
var showRestoreKeyDialog by remember { mutableStateOf(false) }
|
||||
|
||||
Surface(
|
||||
modifier = Modifier.fillMaxSize(),
|
||||
color = MaterialTheme.colorScheme.background
|
||||
) {
|
||||
Column(
|
||||
modifier = Modifier
|
||||
.fillMaxSize()
|
||||
.padding(24.dp)
|
||||
.verticalScroll(rememberScrollState())
|
||||
) {
|
||||
// Header
|
||||
Row(
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
verticalAlignment = Alignment.CenterVertically
|
||||
) {
|
||||
TextButton(onClick = onBack) {
|
||||
Text("< Back")
|
||||
}
|
||||
Spacer(modifier = Modifier.weight(1f))
|
||||
Text(
|
||||
text = "Settings",
|
||||
style = MaterialTheme.typography.headlineSmall.copy(
|
||||
fontWeight = FontWeight.Bold
|
||||
),
|
||||
color = MaterialTheme.colorScheme.primary
|
||||
)
|
||||
Spacer(modifier = Modifier.weight(1f))
|
||||
// Save button — only enabled when changes exist
|
||||
Button(
|
||||
onClick = {
|
||||
viewModel.setAlias(draftAlias)
|
||||
if (draftSeedHex != currentSeedHex) viewModel.restoreSeed(draftSeedHex)
|
||||
viewModel.applyServers(draftServers.toList(), draftSelectedServer)
|
||||
viewModel.setRoomName(draftRoomName)
|
||||
viewModel.setPreferIPv6(draftPreferIPv6)
|
||||
viewModel.setPlayoutGainDb(draftPlayoutGain)
|
||||
viewModel.setCaptureGainDb(draftCaptureGain)
|
||||
viewModel.setAecEnabled(draftAecEnabled)
|
||||
Toast.makeText(context, "Settings saved", Toast.LENGTH_SHORT).show()
|
||||
onBack()
|
||||
},
|
||||
enabled = hasChanges
|
||||
) {
|
||||
Text("Save")
|
||||
}
|
||||
}
|
||||
|
||||
Spacer(modifier = Modifier.height(24.dp))
|
||||
|
||||
// --- Identity ---
|
||||
SectionHeader("Identity")
|
||||
|
||||
OutlinedTextField(
|
||||
value = draftAlias,
|
||||
onValueChange = { draftAlias = it },
|
||||
label = { Text("Display Name") },
|
||||
singleLine = true,
|
||||
modifier = Modifier.fillMaxWidth()
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.height(16.dp))
|
||||
|
||||
// Fingerprint display with identicon
|
||||
val fingerprint = if (draftSeedHex.length >= 16) draftSeedHex.take(16).uppercase() else "Not generated"
|
||||
Text(
|
||||
text = "Fingerprint",
|
||||
style = MaterialTheme.typography.labelSmall,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
Row(
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
modifier = Modifier.padding(vertical = 4.dp)
|
||||
) {
|
||||
com.wzp.ui.components.Identicon(
|
||||
fingerprint = draftSeedHex,
|
||||
size = 40.dp,
|
||||
)
|
||||
Spacer(modifier = Modifier.width(12.dp))
|
||||
com.wzp.ui.components.CopyableFingerprint(
|
||||
fingerprint = fingerprint.chunked(4).joinToString(" "),
|
||||
style = MaterialTheme.typography.bodyMedium.copy(
|
||||
fontFamily = FontFamily.Monospace
|
||||
),
|
||||
color = MaterialTheme.colorScheme.onSurface,
|
||||
)
|
||||
}
|
||||
|
||||
Spacer(modifier = Modifier.height(12.dp))
|
||||
|
||||
// Key backup/restore
|
||||
Row(horizontalArrangement = Arrangement.spacedBy(8.dp)) {
|
||||
FilledTonalButton(onClick = {
|
||||
val clipboard = context.getSystemService(Context.CLIPBOARD_SERVICE) as ClipboardManager
|
||||
clipboard.setPrimaryClip(ClipData.newPlainText("WZP Key", draftSeedHex))
|
||||
Toast.makeText(context, "Key copied to clipboard", Toast.LENGTH_SHORT).show()
|
||||
}) {
|
||||
Text("Copy Key")
|
||||
}
|
||||
OutlinedButton(onClick = { showRestoreKeyDialog = true }) {
|
||||
Text("Restore Key")
|
||||
}
|
||||
}
|
||||
|
||||
Spacer(modifier = Modifier.height(24.dp))
|
||||
Divider()
|
||||
Spacer(modifier = Modifier.height(16.dp))
|
||||
|
||||
// --- Audio ---
|
||||
SectionHeader("Audio Defaults")
|
||||
|
||||
GainSlider(
|
||||
label = "Voice Volume",
|
||||
gainDb = draftPlayoutGain,
|
||||
onGainChange = { draftPlayoutGain = Math.round(it).toFloat() }
|
||||
)
|
||||
Spacer(modifier = Modifier.height(4.dp))
|
||||
GainSlider(
|
||||
label = "Mic Gain",
|
||||
gainDb = draftCaptureGain,
|
||||
onGainChange = { draftCaptureGain = Math.round(it).toFloat() }
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.height(12.dp))
|
||||
|
||||
Row(
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
modifier = Modifier.fillMaxWidth()
|
||||
) {
|
||||
Column(modifier = Modifier.weight(1f)) {
|
||||
Text(
|
||||
text = "Echo Cancellation (AEC)",
|
||||
style = MaterialTheme.typography.bodyMedium
|
||||
)
|
||||
Text(
|
||||
text = "Disable if audio sounds distorted",
|
||||
style = MaterialTheme.typography.bodySmall,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
}
|
||||
Switch(
|
||||
checked = draftAecEnabled,
|
||||
onCheckedChange = { draftAecEnabled = it }
|
||||
)
|
||||
}
|
||||
|
||||
Spacer(modifier = Modifier.height(12.dp))
|
||||
|
||||
// Quality selection — slider from best (studio 64k) to worst (codec2 1.2k) + auto
|
||||
val qualityLabels = listOf(
|
||||
"Studio 64k", "Studio 48k", "Studio 32k", "Auto",
|
||||
"Opus 24k", "Opus 6k", "Codec2 3.2k", "Codec2 1.2k"
|
||||
)
|
||||
// Map slider position to JNI profile int:
|
||||
// 0=Studio64k(6), 1=Studio48k(5), 2=Studio32k(4), 3=Auto(7),
|
||||
// 4=Opus24k(0), 5=Opus6k(1), 6=Codec2_3.2k(3), 7=Codec2_1.2k(2)
|
||||
val sliderToProfile = intArrayOf(6, 5, 4, 7, 0, 1, 3, 2)
|
||||
val profileToSlider = mapOf(6 to 0, 5 to 1, 4 to 2, 7 to 3, 0 to 4, 1 to 5, 3 to 6, 2 to 7)
|
||||
val qualityColors = listOf(
|
||||
Color(0xFF22C55E), Color(0xFF4ADE80), Color(0xFF86EFAC), Color(0xFFA3E635),
|
||||
Color(0xFFA3E635), Color(0xFFFACC15), Color(0xFFE97320), Color(0xFF991B1B)
|
||||
)
|
||||
val currentCodec by viewModel.codecChoice.collectAsState()
|
||||
val sliderPos = profileToSlider[currentCodec] ?: 3
|
||||
Text("Quality", style = MaterialTheme.typography.bodyMedium)
|
||||
Text(
|
||||
text = "Decode always accepts all codecs",
|
||||
style = MaterialTheme.typography.bodySmall,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
Spacer(modifier = Modifier.height(4.dp))
|
||||
Text(
|
||||
text = qualityLabels[sliderPos],
|
||||
style = MaterialTheme.typography.titleMedium.copy(fontWeight = FontWeight.Bold),
|
||||
color = qualityColors[sliderPos]
|
||||
)
|
||||
Slider(
|
||||
value = sliderPos.toFloat(),
|
||||
onValueChange = { viewModel.setCodecChoice(sliderToProfile[it.toInt()]) },
|
||||
valueRange = 0f..7f,
|
||||
steps = 6,
|
||||
modifier = Modifier.fillMaxWidth()
|
||||
)
|
||||
Row(
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
horizontalArrangement = Arrangement.SpaceBetween
|
||||
) {
|
||||
Text("Best", style = MaterialTheme.typography.labelSmall, color = Color(0xFF22C55E))
|
||||
Text("Lowest", style = MaterialTheme.typography.labelSmall, color = Color(0xFF991B1B))
|
||||
}
|
||||
|
||||
Spacer(modifier = Modifier.height(24.dp))
|
||||
Divider()
|
||||
Spacer(modifier = Modifier.height(16.dp))
|
||||
|
||||
// --- Servers ---
|
||||
SectionHeader("Servers")
|
||||
|
||||
FlowRow(
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
horizontalArrangement = Arrangement.Start,
|
||||
verticalArrangement = Arrangement.spacedBy(4.dp)
|
||||
) {
|
||||
draftServers.forEachIndexed { idx, entry ->
|
||||
val isSelected = draftSelectedServer == idx
|
||||
Row(verticalAlignment = Alignment.CenterVertically) {
|
||||
FilledTonalIconButton(
|
||||
onClick = { draftSelectedServer = idx },
|
||||
modifier = Modifier
|
||||
.padding(end = 2.dp)
|
||||
.height(36.dp)
|
||||
.width(140.dp),
|
||||
shape = RoundedCornerShape(8.dp),
|
||||
colors = if (isSelected) {
|
||||
IconButtonDefaults.filledTonalIconButtonColors(
|
||||
containerColor = MaterialTheme.colorScheme.primaryContainer,
|
||||
contentColor = MaterialTheme.colorScheme.onPrimaryContainer
|
||||
)
|
||||
} else {
|
||||
IconButtonDefaults.filledTonalIconButtonColors()
|
||||
}
|
||||
) {
|
||||
Text(
|
||||
text = entry.label,
|
||||
style = MaterialTheme.typography.labelSmall,
|
||||
maxLines = 1
|
||||
)
|
||||
}
|
||||
// Show remove button for non-default servers
|
||||
if (idx >= 2) {
|
||||
TextButton(
|
||||
onClick = {
|
||||
draftServers.removeAt(idx)
|
||||
if (draftSelectedServer >= draftServers.size) {
|
||||
draftSelectedServer = 0
|
||||
}
|
||||
},
|
||||
modifier = Modifier.height(36.dp)
|
||||
) {
|
||||
Text("X", color = MaterialTheme.colorScheme.error)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Spacer(modifier = Modifier.height(8.dp))
|
||||
OutlinedButton(
|
||||
onClick = { showAddServerDialog = true },
|
||||
shape = RoundedCornerShape(8.dp)
|
||||
) {
|
||||
Text("+ Add Server")
|
||||
}
|
||||
|
||||
// Show selected server address
|
||||
Spacer(modifier = Modifier.height(8.dp))
|
||||
Text(
|
||||
text = "Default: ${draftServers.getOrNull(draftSelectedServer)?.address ?: "none"}",
|
||||
style = MaterialTheme.typography.bodySmall,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.height(24.dp))
|
||||
Divider()
|
||||
Spacer(modifier = Modifier.height(16.dp))
|
||||
|
||||
// --- Network ---
|
||||
SectionHeader("Network")
|
||||
|
||||
Row(
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
modifier = Modifier.fillMaxWidth()
|
||||
) {
|
||||
Text(
|
||||
text = "Prefer IPv6",
|
||||
style = MaterialTheme.typography.bodyMedium,
|
||||
modifier = Modifier.weight(1f)
|
||||
)
|
||||
Switch(
|
||||
checked = draftPreferIPv6,
|
||||
onCheckedChange = { draftPreferIPv6 = it }
|
||||
)
|
||||
}
|
||||
|
||||
Spacer(modifier = Modifier.height(24.dp))
|
||||
Divider()
|
||||
Spacer(modifier = Modifier.height(16.dp))
|
||||
|
||||
// --- Room ---
|
||||
SectionHeader("Room")
|
||||
|
||||
OutlinedTextField(
|
||||
value = draftRoomName,
|
||||
onValueChange = { draftRoomName = it },
|
||||
label = { Text("Default Room") },
|
||||
singleLine = true,
|
||||
modifier = Modifier.fillMaxWidth()
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.height(32.dp))
|
||||
}
|
||||
}
|
||||
|
||||
if (showAddServerDialog) {
|
||||
AddServerDialog(
|
||||
onDismiss = { showAddServerDialog = false },
|
||||
onAdd = { host, port, label ->
|
||||
draftServers.add(ServerEntry("$host:$port", label))
|
||||
showAddServerDialog = false
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
if (showRestoreKeyDialog) {
|
||||
RestoreKeyDialog(
|
||||
onDismiss = { showRestoreKeyDialog = false },
|
||||
onRestore = { hex ->
|
||||
draftSeedHex = hex
|
||||
showRestoreKeyDialog = false
|
||||
Toast.makeText(context, "Key staged — press Save to apply", Toast.LENGTH_SHORT).show()
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@Composable
|
||||
private fun SectionHeader(title: String) {
|
||||
Text(
|
||||
text = title,
|
||||
style = MaterialTheme.typography.titleMedium.copy(fontWeight = FontWeight.Bold),
|
||||
color = MaterialTheme.colorScheme.primary
|
||||
)
|
||||
Spacer(modifier = Modifier.height(8.dp))
|
||||
}
|
||||
|
||||
@Composable
|
||||
private fun GainSlider(label: String, gainDb: Float, onGainChange: (Float) -> Unit) {
|
||||
Column(
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
horizontalAlignment = Alignment.CenterHorizontally
|
||||
) {
|
||||
val sign = if (gainDb >= 0) "+" else ""
|
||||
Text(
|
||||
text = "$label: ${sign}${"%.0f".format(gainDb)} dB",
|
||||
style = MaterialTheme.typography.labelSmall,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
Slider(
|
||||
value = gainDb,
|
||||
onValueChange = onGainChange,
|
||||
valueRange = -20f..20f,
|
||||
steps = 0,
|
||||
modifier = Modifier.fillMaxWidth()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@Composable
|
||||
private fun AddServerDialog(
|
||||
onDismiss: () -> Unit,
|
||||
onAdd: (host: String, port: String, label: String) -> Unit
|
||||
) {
|
||||
var host by remember { mutableStateOf("") }
|
||||
var port by remember { mutableStateOf("4433") }
|
||||
var label by remember { mutableStateOf("") }
|
||||
|
||||
AlertDialog(
|
||||
onDismissRequest = onDismiss,
|
||||
title = { Text("Add Server") },
|
||||
text = {
|
||||
Column {
|
||||
OutlinedTextField(
|
||||
value = host,
|
||||
onValueChange = { host = it },
|
||||
label = { Text("Host (IP or domain)") },
|
||||
singleLine = true,
|
||||
modifier = Modifier.fillMaxWidth()
|
||||
)
|
||||
Spacer(modifier = Modifier.height(8.dp))
|
||||
OutlinedTextField(
|
||||
value = port,
|
||||
onValueChange = { port = it },
|
||||
label = { Text("Port") },
|
||||
singleLine = true,
|
||||
modifier = Modifier.fillMaxWidth()
|
||||
)
|
||||
Spacer(modifier = Modifier.height(8.dp))
|
||||
OutlinedTextField(
|
||||
value = label,
|
||||
onValueChange = { label = it },
|
||||
label = { Text("Label (optional)") },
|
||||
singleLine = true,
|
||||
modifier = Modifier.fillMaxWidth()
|
||||
)
|
||||
}
|
||||
},
|
||||
confirmButton = {
|
||||
TextButton(
|
||||
onClick = {
|
||||
if (host.isNotBlank()) {
|
||||
val displayLabel = label.ifBlank { host }
|
||||
onAdd(host.trim(), port.trim(), displayLabel)
|
||||
}
|
||||
}
|
||||
) { Text("Add") }
|
||||
},
|
||||
dismissButton = {
|
||||
TextButton(onClick = onDismiss) { Text("Cancel") }
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
@Composable
|
||||
private fun RestoreKeyDialog(
|
||||
onDismiss: () -> Unit,
|
||||
onRestore: (hex: String) -> Unit
|
||||
) {
|
||||
var keyInput by remember { mutableStateOf("") }
|
||||
var error by remember { mutableStateOf<String?>(null) }
|
||||
|
||||
AlertDialog(
|
||||
onDismissRequest = onDismiss,
|
||||
title = { Text("Restore Identity Key") },
|
||||
text = {
|
||||
Column {
|
||||
Text(
|
||||
text = "Paste your 64-character hex key below. This will replace your current identity.",
|
||||
style = MaterialTheme.typography.bodySmall,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
Spacer(modifier = Modifier.height(8.dp))
|
||||
OutlinedTextField(
|
||||
value = keyInput,
|
||||
onValueChange = {
|
||||
keyInput = it.trim().lowercase()
|
||||
error = null
|
||||
},
|
||||
label = { Text("Identity Key (hex)") },
|
||||
singleLine = true,
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
isError = error != null
|
||||
)
|
||||
error?.let {
|
||||
Text(
|
||||
text = it,
|
||||
style = MaterialTheme.typography.bodySmall,
|
||||
color = MaterialTheme.colorScheme.error
|
||||
)
|
||||
}
|
||||
}
|
||||
},
|
||||
confirmButton = {
|
||||
TextButton(
|
||||
onClick = {
|
||||
val cleaned = keyInput.replace("\\s".toRegex(), "")
|
||||
if (cleaned.length != 64 || !cleaned.all { it in '0'..'9' || it in 'a'..'f' }) {
|
||||
error = "Key must be exactly 64 hex characters"
|
||||
} else {
|
||||
onRestore(cleaned)
|
||||
}
|
||||
}
|
||||
) { Text("Restore") }
|
||||
},
|
||||
dismissButton = {
|
||||
TextButton(onClick = onDismiss) { Text("Cancel") }
|
||||
}
|
||||
)
|
||||
}
|
||||
4
android/app/src/main/res/xml/file_paths.xml
Normal file
4
android/app/src/main/res/xml/file_paths.xml
Normal file
@@ -0,0 +1,4 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<paths>
|
||||
<cache-path name="debug" path="." />
|
||||
</paths>
|
||||
@@ -17,7 +17,7 @@ wzp-crypto = { workspace = true }
|
||||
wzp-transport = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
tracing-subscriber = { workspace = true, features = ["env-filter"] }
|
||||
bytes = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = "1"
|
||||
@@ -28,6 +28,8 @@ libc = "0.2"
|
||||
jni = { version = "0.21", default-features = false }
|
||||
rand = { workspace = true }
|
||||
rustls = { version = "0.23", default-features = false, features = ["ring"] }
|
||||
[target.'cfg(target_os = "android")'.dependencies]
|
||||
tracing-android = "0.2"
|
||||
|
||||
[build-dependencies]
|
||||
cc = "1"
|
||||
|
||||
@@ -65,9 +65,8 @@ fn main() {
|
||||
} else {
|
||||
"aarch64-linux-android"
|
||||
};
|
||||
let lib_dir = format!(
|
||||
"{ndk}/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/lib/{arch}"
|
||||
);
|
||||
let lib_dir =
|
||||
format!("{ndk}/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/lib/{arch}");
|
||||
println!("cargo:rustc-link-search=native={lib_dir}");
|
||||
|
||||
// Copy libc++_shared.so to the jniLibs directory
|
||||
@@ -82,9 +81,7 @@ fn main() {
|
||||
};
|
||||
// Try to copy to the Gradle jniLibs directory
|
||||
let manifest = std::env::var("CARGO_MANIFEST_DIR").unwrap_or_default();
|
||||
let jni_dir = format!(
|
||||
"{manifest}/../../android/app/src/main/jniLibs/{jni_abi}"
|
||||
);
|
||||
let jni_dir = format!("{manifest}/../../android/app/src/main/jniLibs/{jni_abi}");
|
||||
if let Ok(_) = std::fs::create_dir_all(&jni_dir) {
|
||||
let _ = std::fs::copy(&shared_so, format!("{jni_dir}/libc++_shared.so"));
|
||||
println!("cargo:warning=Copied libc++_shared.so to {jni_dir}");
|
||||
@@ -127,7 +124,12 @@ fn fetch_oboe() -> Option<PathBuf> {
|
||||
let out_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap());
|
||||
let oboe_dir = out_dir.join("oboe");
|
||||
|
||||
if oboe_dir.join("include").join("oboe").join("Oboe.h").exists() {
|
||||
if oboe_dir
|
||||
.join("include")
|
||||
.join("oboe")
|
||||
.join("Oboe.h")
|
||||
.exists()
|
||||
{
|
||||
return Some(oboe_dir);
|
||||
}
|
||||
|
||||
@@ -143,7 +145,12 @@ fn fetch_oboe() -> Option<PathBuf> {
|
||||
|
||||
match status {
|
||||
Ok(s) if s.success() => {
|
||||
if oboe_dir.join("include").join("oboe").join("Oboe.h").exists() {
|
||||
if oboe_dir
|
||||
.join("include")
|
||||
.join("oboe")
|
||||
.join("Oboe.h")
|
||||
.exists()
|
||||
{
|
||||
Some(oboe_dir)
|
||||
} else {
|
||||
None
|
||||
|
||||
@@ -326,7 +326,10 @@ pub fn pin_to_big_core() {
|
||||
&set,
|
||||
);
|
||||
if ret != 0 {
|
||||
warn!("sched_setaffinity failed: {}", std::io::Error::last_os_error());
|
||||
warn!(
|
||||
"sched_setaffinity failed: {}",
|
||||
std::io::Error::last_os_error()
|
||||
);
|
||||
} else {
|
||||
info!(start, num_cpus, "pinned to big cores");
|
||||
}
|
||||
|
||||
@@ -1,91 +1,130 @@
|
||||
//! Lock-free SPSC ring buffers for audio PCM transfer between
|
||||
//! Kotlin AudioRecord/AudioTrack threads and the Rust engine.
|
||||
//! Lock-free SPSC ring buffer — "Reader-Detects-Lap" architecture.
|
||||
//!
|
||||
//! These use a simple spin-free design: the producer writes and advances
|
||||
//! a write cursor, the consumer reads and advances a read cursor.
|
||||
//! Both cursors are atomic so no mutex is needed.
|
||||
//! SPSC invariant: the producer ONLY writes `write_pos`, the consumer
|
||||
//! ONLY writes `read_pos`. Neither thread touches the other's cursor.
|
||||
//!
|
||||
//! On overflow (writer laps the reader), the writer simply overwrites
|
||||
//! old buffer data. The reader detects the lap via `available() >
|
||||
//! RING_CAPACITY` and snaps its own `read_pos` forward.
|
||||
//!
|
||||
//! Capacity is a power of 2 for bitmask indexing (no modulo).
|
||||
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
|
||||
|
||||
/// Ring buffer capacity in i16 samples.
|
||||
/// 960 samples * 10 frames = ~200ms of audio at 48kHz mono.
|
||||
const RING_CAPACITY: usize = 960 * 10;
|
||||
/// Ring buffer capacity — power of 2 for bitmask indexing.
|
||||
/// 16384 samples = 341.3ms at 48kHz mono. 70% more headroom
|
||||
/// than the previous 9600 (200ms) for surviving Android GC pauses.
|
||||
const RING_CAPACITY: usize = 16384; // 2^14
|
||||
const RING_MASK: usize = RING_CAPACITY - 1;
|
||||
|
||||
/// Lock-free single-producer single-consumer ring buffer for i16 PCM samples.
|
||||
pub struct AudioRing {
|
||||
buf: Box<[i16; RING_CAPACITY]>,
|
||||
buf: Box<[i16]>,
|
||||
/// Monotonically increasing write cursor. ONLY written by producer.
|
||||
write_pos: AtomicUsize,
|
||||
/// Monotonically increasing read cursor. ONLY written by consumer.
|
||||
read_pos: AtomicUsize,
|
||||
/// Incremented by reader when it detects it was lapped (overflow).
|
||||
overflow_count: AtomicU64,
|
||||
/// Incremented by reader when ring is empty (underrun).
|
||||
underrun_count: AtomicU64,
|
||||
}
|
||||
|
||||
// SAFETY: AudioRing is designed for SPSC — one thread writes, one reads.
|
||||
// The atomics ensure visibility. The buffer itself is never accessed
|
||||
// from the same index by both threads simultaneously because the
|
||||
// producer only writes to positions between write_pos and read_pos,
|
||||
// and the consumer only reads from positions between read_pos and write_pos.
|
||||
// SAFETY: AudioRing is SPSC — one thread writes (producer), one reads (consumer).
|
||||
// The producer only writes write_pos. The consumer only writes read_pos.
|
||||
// Neither thread writes the other's cursor. Buffer indices are derived from
|
||||
// the owning thread's cursor, ensuring no concurrent access to the same index.
|
||||
unsafe impl Send for AudioRing {}
|
||||
unsafe impl Sync for AudioRing {}
|
||||
|
||||
impl AudioRing {
|
||||
pub fn new() -> Self {
|
||||
debug_assert!(RING_CAPACITY.is_power_of_two());
|
||||
Self {
|
||||
buf: Box::new([0i16; RING_CAPACITY]),
|
||||
buf: vec![0i16; RING_CAPACITY].into_boxed_slice(),
|
||||
write_pos: AtomicUsize::new(0),
|
||||
read_pos: AtomicUsize::new(0),
|
||||
overflow_count: AtomicU64::new(0),
|
||||
underrun_count: AtomicU64::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
/// Number of samples available to read.
|
||||
/// Number of samples available to read (clamped to capacity).
|
||||
pub fn available(&self) -> usize {
|
||||
let w = self.write_pos.load(Ordering::Acquire);
|
||||
let r = self.read_pos.load(Ordering::Acquire);
|
||||
w.wrapping_sub(r)
|
||||
let r = self.read_pos.load(Ordering::Relaxed);
|
||||
w.wrapping_sub(r).min(RING_CAPACITY)
|
||||
}
|
||||
|
||||
/// Number of samples that can be written without overwriting.
|
||||
/// Number of samples that can be written without overwriting unread data.
|
||||
pub fn free_space(&self) -> usize {
|
||||
RING_CAPACITY - self.available()
|
||||
RING_CAPACITY.saturating_sub(self.available())
|
||||
}
|
||||
|
||||
/// Write samples into the ring. Returns number of samples written.
|
||||
/// Drops oldest samples if the ring is full.
|
||||
///
|
||||
/// If the ring is full, old data is silently overwritten. The reader
|
||||
/// will detect the lap and self-correct. The writer NEVER touches
|
||||
/// `read_pos` — this is the key invariant that prevents cursor desync.
|
||||
pub fn write(&self, samples: &[i16]) -> usize {
|
||||
let w = self.write_pos.load(Ordering::Relaxed);
|
||||
let count = samples.len().min(RING_CAPACITY);
|
||||
let w = self.write_pos.load(Ordering::Relaxed);
|
||||
|
||||
for i in 0..count {
|
||||
let idx = (w + i) % RING_CAPACITY;
|
||||
// SAFETY: We're the only writer, and the reader won't read
|
||||
// past read_pos which we haven't advanced past yet.
|
||||
unsafe {
|
||||
let ptr = self.buf.as_ptr() as *mut i16;
|
||||
*ptr.add(idx) = samples[i];
|
||||
*ptr.add((w + i) & RING_MASK) = samples[i];
|
||||
}
|
||||
}
|
||||
|
||||
self.write_pos.store(w.wrapping_add(count), Ordering::Release);
|
||||
|
||||
// If we overwrote unread data, advance read_pos
|
||||
if self.available() > RING_CAPACITY {
|
||||
let new_read = self.write_pos.load(Ordering::Relaxed).wrapping_sub(RING_CAPACITY);
|
||||
self.read_pos.store(new_read, Ordering::Release);
|
||||
}
|
||||
|
||||
self.write_pos
|
||||
.store(w.wrapping_add(count), Ordering::Release);
|
||||
count
|
||||
}
|
||||
|
||||
/// Read samples from the ring into `out`. Returns number of samples read.
|
||||
///
|
||||
/// If the writer has lapped the reader (overflow), `read_pos` is snapped
|
||||
/// forward to the oldest valid data. This is safe because only the
|
||||
/// reader thread writes `read_pos`.
|
||||
pub fn read(&self, out: &mut [i16]) -> usize {
|
||||
let avail = self.available();
|
||||
let count = out.len().min(avail);
|
||||
let w = self.write_pos.load(Ordering::Acquire);
|
||||
let mut r = self.read_pos.load(Ordering::Relaxed);
|
||||
|
||||
let r = self.read_pos.load(Ordering::Relaxed);
|
||||
for i in 0..count {
|
||||
let idx = (r + i) % RING_CAPACITY;
|
||||
out[i] = unsafe { *self.buf.as_ptr().add(idx) };
|
||||
let mut avail = w.wrapping_sub(r);
|
||||
|
||||
// Lap detection: writer has overwritten our unread data.
|
||||
// Snap read_pos forward to oldest valid data in the buffer.
|
||||
if avail > RING_CAPACITY {
|
||||
r = w.wrapping_sub(RING_CAPACITY);
|
||||
avail = RING_CAPACITY;
|
||||
self.overflow_count.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
self.read_pos.store(r.wrapping_add(count), Ordering::Release);
|
||||
let count = out.len().min(avail);
|
||||
if count == 0 {
|
||||
if w == r {
|
||||
self.underrun_count.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
for i in 0..count {
|
||||
out[i] = unsafe { *self.buf.as_ptr().add((r + i) & RING_MASK) };
|
||||
}
|
||||
|
||||
self.read_pos
|
||||
.store(r.wrapping_add(count), Ordering::Release);
|
||||
count
|
||||
}
|
||||
|
||||
/// Number of overflow events (reader was lapped by writer).
|
||||
pub fn overflow_count(&self) -> u64 {
|
||||
self.overflow_count.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Number of underrun events (reader found empty buffer).
|
||||
pub fn underrun_count(&self) -> u64 {
|
||||
self.underrun_count.load(Ordering::Relaxed)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,4 +12,13 @@ pub enum EngineCommand {
|
||||
ForceProfile(QualityProfile),
|
||||
/// Stop the call and shut down the engine.
|
||||
Stop,
|
||||
/// Place a direct call to a fingerprint (requires signal connection).
|
||||
PlaceCall { target_fingerprint: String },
|
||||
/// Answer an incoming direct call.
|
||||
AnswerCall {
|
||||
call_id: String,
|
||||
accept_mode: wzp_proto::CallAcceptMode,
|
||||
},
|
||||
/// Reject an incoming direct call.
|
||||
RejectCall { call_id: String },
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,10 +1,11 @@
|
||||
//! JNI bridge for Android — thin layer between Kotlin and the WzpEngine.
|
||||
|
||||
use std::panic;
|
||||
use std::sync::Once;
|
||||
|
||||
use jni::JNIEnv;
|
||||
use jni::objects::{JClass, JObject, JString};
|
||||
use jni::sys::{jboolean, jint, jlong, jstring};
|
||||
use jni::JNIEnv;
|
||||
use tracing::{error, info};
|
||||
use wzp_proto::QualityProfile;
|
||||
|
||||
@@ -20,20 +21,72 @@ unsafe fn handle_ref(handle: jlong) -> &'static mut EngineHandle {
|
||||
unsafe { &mut *(handle as *mut EngineHandle) }
|
||||
}
|
||||
|
||||
/// 7 = auto (use relay's chosen profile)
|
||||
const PROFILE_AUTO: jint = 7;
|
||||
|
||||
fn profile_from_int(value: jint) -> QualityProfile {
|
||||
match value {
|
||||
1 => QualityProfile::DEGRADED,
|
||||
2 => QualityProfile::CATASTROPHIC,
|
||||
_ => QualityProfile::GOOD,
|
||||
0 => QualityProfile::GOOD, // Opus 24k
|
||||
1 => QualityProfile::DEGRADED, // Opus 6k
|
||||
2 => QualityProfile::CATASTROPHIC, // Codec2 1.2k
|
||||
3 => QualityProfile {
|
||||
// Codec2 3.2k
|
||||
codec: wzp_proto::CodecId::Codec2_3200,
|
||||
fec_ratio: 0.5,
|
||||
frame_duration_ms: 20,
|
||||
frames_per_block: 5,
|
||||
..QualityProfile::GOOD
|
||||
},
|
||||
4 => QualityProfile::STUDIO_32K, // Opus 32k
|
||||
5 => QualityProfile::STUDIO_48K, // Opus 48k
|
||||
6 => QualityProfile::STUDIO_64K, // Opus 64k
|
||||
_ => QualityProfile::GOOD, // auto falls back to GOOD
|
||||
}
|
||||
}
|
||||
|
||||
static INIT_LOGGING: Once = Once::new();
|
||||
|
||||
/// Initialize tracing → Android logcat (tag "wzp_android").
|
||||
/// Safe to call multiple times — only the first call takes effect.
|
||||
fn init_logging() {
|
||||
INIT_LOGGING.call_once(|| {
|
||||
#[cfg(target_os = "android")]
|
||||
{
|
||||
// Wrap in catch_unwind — sharded_slab allocation inside
|
||||
// tracing_subscriber::registry() can crash on some Android
|
||||
// devices if scudo malloc fails during early initialization.
|
||||
let _ = std::panic::catch_unwind(|| {
|
||||
use tracing_subscriber::layer::SubscriberExt;
|
||||
use tracing_subscriber::util::SubscriberInitExt;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
if let Ok(layer) = tracing_android::layer("wzp_android") {
|
||||
// Filter: INFO for our crates, WARN for everything else.
|
||||
// The jni crate emits VERBOSE logs for every method lookup
|
||||
// (~10 lines per JNI call, 100+ calls/sec) which floods logcat
|
||||
// and causes the system to kill the app.
|
||||
let filter = EnvFilter::new("warn,wzp_android=info,wzp_proto=info,wzp_transport=info,wzp_codec=info,wzp_fec=info,wzp_crypto=info");
|
||||
let _ = tracing_subscriber::registry()
|
||||
.with(layer)
|
||||
.with(filter)
|
||||
.try_init();
|
||||
}
|
||||
});
|
||||
}
|
||||
#[cfg(not(target_os = "android"))]
|
||||
{
|
||||
// On non-Android targets tracing-android is unavailable.
|
||||
let _ = tracing_subscriber::fmt::try_init();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#[unsafe(no_mangle)]
|
||||
pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeInit(
|
||||
_env: JNIEnv,
|
||||
_class: JClass,
|
||||
) -> jlong {
|
||||
let result = panic::catch_unwind(|| {
|
||||
init_logging();
|
||||
let handle = Box::new(EngineHandle {
|
||||
engine: WzpEngine::new(),
|
||||
});
|
||||
@@ -54,12 +107,30 @@ pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeStartCall(
|
||||
room_j: JString,
|
||||
seed_hex_j: JString,
|
||||
token_j: JString,
|
||||
alias_j: JString,
|
||||
profile_j: jint,
|
||||
) -> jint {
|
||||
let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
|
||||
let relay_addr: String = env.get_string(&relay_addr_j).map(|s| s.into()).unwrap_or_default();
|
||||
let room: String = env.get_string(&room_j).map(|s| s.into()).unwrap_or_default();
|
||||
let seed_hex: String = env.get_string(&seed_hex_j).map(|s| s.into()).unwrap_or_default();
|
||||
let token: String = env.get_string(&token_j).map(|s| s.into()).unwrap_or_default();
|
||||
let relay_addr: String = env
|
||||
.get_string(&relay_addr_j)
|
||||
.map(|s| s.into())
|
||||
.unwrap_or_default();
|
||||
let room: String = env
|
||||
.get_string(&room_j)
|
||||
.map(|s| s.into())
|
||||
.unwrap_or_default();
|
||||
let seed_hex: String = env
|
||||
.get_string(&seed_hex_j)
|
||||
.map(|s| s.into())
|
||||
.unwrap_or_default();
|
||||
let token: String = env
|
||||
.get_string(&token_j)
|
||||
.map(|s| s.into())
|
||||
.unwrap_or_default();
|
||||
let alias: String = env
|
||||
.get_string(&alias_j)
|
||||
.map(|s| s.into())
|
||||
.unwrap_or_default();
|
||||
|
||||
let h = unsafe { handle_ref(handle) };
|
||||
|
||||
@@ -78,11 +149,17 @@ pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeStartCall(
|
||||
}
|
||||
|
||||
let config = CallStartConfig {
|
||||
profile: QualityProfile::GOOD,
|
||||
profile: profile_from_int(profile_j),
|
||||
auto_profile: profile_j == PROFILE_AUTO,
|
||||
relay_addr,
|
||||
room,
|
||||
auth_token: if token.is_empty() { Vec::new() } else { token.into_bytes() },
|
||||
auth_token: if token.is_empty() {
|
||||
Vec::new()
|
||||
} else {
|
||||
token.into_bytes()
|
||||
},
|
||||
identity_seed,
|
||||
alias: if alias.is_empty() { None } else { Some(alias) },
|
||||
};
|
||||
|
||||
match h.engine.start_call(config) {
|
||||
@@ -174,6 +251,30 @@ pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeForceProfile(
|
||||
}));
|
||||
}
|
||||
|
||||
/// Signal a network transport change from the Android ConnectivityManager.
|
||||
///
|
||||
/// `network_type` matches the Rust `NetworkContext` enum:
|
||||
/// 0=WiFi, 1=CellularLte, 2=Cellular5g, 3=Cellular3g, 4=Unknown, 5=None
|
||||
///
|
||||
/// The engine forwards this to the `AdaptiveQualityController` which:
|
||||
/// - Preemptively downgrades one tier on WiFi→cellular
|
||||
/// - Activates a 10-second FEC boost
|
||||
/// - Uses faster downgrade thresholds on cellular
|
||||
#[unsafe(no_mangle)]
|
||||
pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeOnNetworkChanged(
|
||||
_env: JNIEnv,
|
||||
_class: JClass,
|
||||
handle: jlong,
|
||||
network_type: jint,
|
||||
bandwidth_kbps: jint,
|
||||
) {
|
||||
let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| {
|
||||
let h = unsafe { handle_ref(handle) };
|
||||
h.engine
|
||||
.on_network_changed(network_type as u8, bandwidth_kbps as u32);
|
||||
}));
|
||||
}
|
||||
|
||||
/// Write captured PCM samples from Kotlin AudioRecord into the engine's capture ring.
|
||||
/// pcm is a Java short[] array.
|
||||
#[unsafe(no_mangle)]
|
||||
@@ -190,7 +291,6 @@ pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeWriteAudio(
|
||||
return 0;
|
||||
}
|
||||
let mut buf = vec![0i16; len];
|
||||
// GetShortArrayRegion copies Java array into our buffer
|
||||
if env.get_short_array_region(&pcm, 0, &mut buf).is_err() {
|
||||
return 0;
|
||||
}
|
||||
@@ -224,6 +324,58 @@ pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeReadAudio(
|
||||
result.unwrap_or(0)
|
||||
}
|
||||
|
||||
/// Write captured PCM from a DirectByteBuffer — zero JNI array copies.
|
||||
/// The ByteBuffer must contain little-endian i16 samples.
|
||||
/// Called from the AudioRecord capture thread.
|
||||
#[unsafe(no_mangle)]
|
||||
pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeWriteAudioDirect(
|
||||
env: JNIEnv,
|
||||
_class: JClass,
|
||||
handle: jlong,
|
||||
buffer: jni::objects::JByteBuffer,
|
||||
sample_count: jint,
|
||||
) -> jint {
|
||||
let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
|
||||
let h = unsafe { handle_ref(handle) };
|
||||
let ptr = env
|
||||
.get_direct_buffer_address(&buffer)
|
||||
.unwrap_or(std::ptr::null_mut());
|
||||
if ptr.is_null() || sample_count <= 0 {
|
||||
return 0;
|
||||
}
|
||||
let samples =
|
||||
unsafe { std::slice::from_raw_parts(ptr as *const i16, sample_count as usize) };
|
||||
h.engine.write_audio(samples) as jint
|
||||
}));
|
||||
result.unwrap_or(0)
|
||||
}
|
||||
|
||||
/// Read decoded PCM into a DirectByteBuffer — zero JNI array copies.
|
||||
/// The ByteBuffer will be filled with little-endian i16 samples.
|
||||
/// Called from the AudioTrack playout thread.
|
||||
#[unsafe(no_mangle)]
|
||||
pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeReadAudioDirect(
|
||||
env: JNIEnv,
|
||||
_class: JClass,
|
||||
handle: jlong,
|
||||
buffer: jni::objects::JByteBuffer,
|
||||
max_samples: jint,
|
||||
) -> jint {
|
||||
let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
|
||||
let h = unsafe { handle_ref(handle) };
|
||||
let ptr = env
|
||||
.get_direct_buffer_address(&buffer)
|
||||
.unwrap_or(std::ptr::null_mut());
|
||||
if ptr.is_null() || max_samples <= 0 {
|
||||
return 0;
|
||||
}
|
||||
let samples =
|
||||
unsafe { std::slice::from_raw_parts_mut(ptr as *mut i16, max_samples as usize) };
|
||||
h.engine.read_audio(samples) as jint
|
||||
}));
|
||||
result.unwrap_or(0)
|
||||
}
|
||||
|
||||
#[unsafe(no_mangle)]
|
||||
pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeDestroy(
|
||||
_env: JNIEnv,
|
||||
@@ -235,3 +387,155 @@ pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeDestroy(
|
||||
drop(h);
|
||||
}));
|
||||
}
|
||||
|
||||
/// Ping a relay server — instance method, requires engine handle.
|
||||
/// Returns JSON `{"rtt_ms":N,"server_fingerprint":"hex"}` or null on failure.
|
||||
#[unsafe(no_mangle)]
|
||||
pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativePingRelay<'a>(
|
||||
mut env: JNIEnv<'a>,
|
||||
_class: JClass,
|
||||
handle: jlong,
|
||||
relay_j: JString,
|
||||
) -> jstring {
|
||||
let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
|
||||
let h = unsafe { handle_ref(handle) };
|
||||
let relay: String = env
|
||||
.get_string(&relay_j)
|
||||
.map(|s| s.into())
|
||||
.unwrap_or_default();
|
||||
match h.engine.ping_relay(&relay) {
|
||||
Ok(json) => Some(json),
|
||||
Err(_) => None,
|
||||
}
|
||||
}));
|
||||
|
||||
let json = match result {
|
||||
Ok(Some(s)) => s,
|
||||
_ => return JObject::null().into_raw(),
|
||||
};
|
||||
env.new_string(&json)
|
||||
.map(|s| s.into_raw())
|
||||
.unwrap_or(JObject::null().into_raw())
|
||||
}
|
||||
|
||||
// ── Direct calling JNI functions ──
|
||||
|
||||
/// Start persistent signaling connection to relay for direct calls.
|
||||
/// Returns 0 on success, -1 on error.
|
||||
#[unsafe(no_mangle)]
|
||||
pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeStartSignaling<'a>(
|
||||
mut env: JNIEnv<'a>,
|
||||
_class: JClass,
|
||||
handle: jlong,
|
||||
relay_addr_j: JString,
|
||||
seed_hex_j: JString,
|
||||
token_j: JString,
|
||||
alias_j: JString,
|
||||
) -> jint {
|
||||
let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
|
||||
let h = unsafe { handle_ref(handle) };
|
||||
let relay_addr: String = env
|
||||
.get_string(&relay_addr_j)
|
||||
.map(|s| s.into())
|
||||
.unwrap_or_default();
|
||||
let seed_hex: String = env
|
||||
.get_string(&seed_hex_j)
|
||||
.map(|s| s.into())
|
||||
.unwrap_or_default();
|
||||
let token: String = env
|
||||
.get_string(&token_j)
|
||||
.map(|s| s.into())
|
||||
.unwrap_or_default();
|
||||
let alias: String = env
|
||||
.get_string(&alias_j)
|
||||
.map(|s| s.into())
|
||||
.unwrap_or_default();
|
||||
|
||||
h.engine.start_signaling(
|
||||
&relay_addr,
|
||||
&seed_hex,
|
||||
if token.is_empty() { None } else { Some(&token) },
|
||||
if alias.is_empty() { None } else { Some(&alias) },
|
||||
)
|
||||
}));
|
||||
|
||||
match result {
|
||||
Ok(Ok(())) => 0,
|
||||
Ok(Err(e)) => {
|
||||
error!("start_signaling failed: {e}");
|
||||
-1
|
||||
}
|
||||
Err(_) => {
|
||||
error!("start_signaling panicked");
|
||||
-1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Place a direct call to a target fingerprint.
|
||||
/// Returns 0 on success, -1 on error.
|
||||
#[unsafe(no_mangle)]
|
||||
pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativePlaceCall<'a>(
|
||||
mut env: JNIEnv<'a>,
|
||||
_class: JClass,
|
||||
handle: jlong,
|
||||
target_fp_j: JString,
|
||||
) -> jint {
|
||||
let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
|
||||
let h = unsafe { handle_ref(handle) };
|
||||
let target: String = env
|
||||
.get_string(&target_fp_j)
|
||||
.map(|s| s.into())
|
||||
.unwrap_or_default();
|
||||
h.engine.place_call(&target)
|
||||
}));
|
||||
|
||||
match result {
|
||||
Ok(Ok(())) => 0,
|
||||
Ok(Err(e)) => {
|
||||
error!("place_call failed: {e}");
|
||||
-1
|
||||
}
|
||||
Err(_) => {
|
||||
error!("place_call panicked");
|
||||
-1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Answer an incoming direct call.
|
||||
/// mode: 0=Reject, 1=AcceptTrusted, 2=AcceptGeneric
|
||||
#[unsafe(no_mangle)]
|
||||
pub unsafe extern "system" fn Java_com_wzp_engine_WzpEngine_nativeAnswerCall<'a>(
|
||||
mut env: JNIEnv<'a>,
|
||||
_class: JClass,
|
||||
handle: jlong,
|
||||
call_id_j: JString,
|
||||
mode: jint,
|
||||
) -> jint {
|
||||
let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
|
||||
let h = unsafe { handle_ref(handle) };
|
||||
let call_id: String = env
|
||||
.get_string(&call_id_j)
|
||||
.map(|s| s.into())
|
||||
.unwrap_or_default();
|
||||
let accept_mode = match mode {
|
||||
0 => wzp_proto::CallAcceptMode::Reject,
|
||||
1 => wzp_proto::CallAcceptMode::AcceptTrusted,
|
||||
_ => wzp_proto::CallAcceptMode::AcceptGeneric,
|
||||
};
|
||||
h.engine.answer_call(&call_id, accept_mode)
|
||||
}));
|
||||
|
||||
match result {
|
||||
Ok(Ok(())) => 0,
|
||||
Ok(Err(e)) => {
|
||||
error!("answer_call failed: {e}");
|
||||
-1
|
||||
}
|
||||
Err(_) => {
|
||||
error!("answer_call panicked");
|
||||
-1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,11 +8,24 @@
|
||||
//!
|
||||
//! On non-Android targets, the Oboe C++ layer compiles as a stub,
|
||||
//! allowing `cargo check` and unit tests on the host.
|
||||
//!
|
||||
//! ## Status
|
||||
//!
|
||||
//! **Dead code as of the Tauri mobile rewrite.** The legacy Kotlin+JNI
|
||||
//! Android app that consumed this crate was replaced by a Tauri 2.x
|
||||
//! Mobile app (see `desktop/src-tauri/src/engine.rs` for the live
|
||||
//! Android audio recv path and `crates/wzp-native/` for the Oboe
|
||||
//! bridge). We keep this crate in the workspace for reference and to
|
||||
//! preserve the commit history, but it is not built by any shipping
|
||||
//! target. Allow the accumulated leftover warnings so CI/workspace
|
||||
//! checks stay clean — any real cleanup should happen as part of
|
||||
//! removing the crate entirely, not piecemeal.
|
||||
#![allow(dead_code, unused_imports, unused_variables, unused_mut)]
|
||||
|
||||
pub mod audio_android;
|
||||
pub mod audio_ring;
|
||||
pub mod commands;
|
||||
pub mod engine;
|
||||
pub mod jni_bridge;
|
||||
pub mod pipeline;
|
||||
pub mod stats;
|
||||
pub mod jni_bridge;
|
||||
|
||||
@@ -9,8 +9,8 @@ use wzp_codec::{AdaptiveDecoder, AdaptiveEncoder, AutoGainControl, EchoCanceller
|
||||
use wzp_fec::{RaptorQFecDecoder, RaptorQFecEncoder};
|
||||
use wzp_proto::jitter::{JitterBuffer, PlayoutResult};
|
||||
use wzp_proto::quality::AdaptiveQualityController;
|
||||
use wzp_proto::traits::{AudioDecoder, AudioEncoder, FecDecoder, FecEncoder};
|
||||
use wzp_proto::traits::QualityController;
|
||||
use wzp_proto::traits::{AudioDecoder, AudioEncoder, FecDecoder, FecEncoder};
|
||||
use wzp_proto::{MediaPacket, QualityProfile};
|
||||
|
||||
use crate::audio_android::FRAME_SAMPLES;
|
||||
@@ -58,14 +58,12 @@ pub struct Pipeline {
|
||||
impl Pipeline {
|
||||
/// Create a new pipeline configured for the given quality profile.
|
||||
pub fn new(profile: QualityProfile) -> Result<Self, anyhow::Error> {
|
||||
let encoder = AdaptiveEncoder::new(profile)
|
||||
.map_err(|e| anyhow::anyhow!("encoder init: {e}"))?;
|
||||
let decoder = AdaptiveDecoder::new(profile)
|
||||
.map_err(|e| anyhow::anyhow!("decoder init: {e}"))?;
|
||||
let fec_encoder =
|
||||
RaptorQFecEncoder::with_defaults(profile.frames_per_block as usize);
|
||||
let fec_decoder =
|
||||
RaptorQFecDecoder::with_defaults(profile.frames_per_block as usize);
|
||||
let encoder =
|
||||
AdaptiveEncoder::new(profile).map_err(|e| anyhow::anyhow!("encoder init: {e}"))?;
|
||||
let decoder =
|
||||
AdaptiveDecoder::new(profile).map_err(|e| anyhow::anyhow!("decoder init: {e}"))?;
|
||||
let fec_encoder = RaptorQFecEncoder::with_defaults(profile.frames_per_block as usize);
|
||||
let fec_decoder = RaptorQFecDecoder::with_defaults(profile.frames_per_block as usize);
|
||||
let jitter_buffer = JitterBuffer::new(10, 250, 3);
|
||||
let quality_ctrl = AdaptiveQualityController::new();
|
||||
|
||||
@@ -136,11 +134,11 @@ impl Pipeline {
|
||||
pub fn feed_packet(&mut self, packet: MediaPacket) {
|
||||
// Feed FEC symbols if present
|
||||
let header = &packet.header;
|
||||
if header.fec_block != 0 || header.fec_symbol != 0 {
|
||||
let is_repair = header.is_repair;
|
||||
if header.fec_block != 0 {
|
||||
let is_repair = header.is_repair();
|
||||
if let Err(e) = self.fec_decoder.add_symbol(
|
||||
header.fec_block,
|
||||
header.fec_symbol,
|
||||
header.fec_block as u8,
|
||||
header.fec_block >> 8,
|
||||
is_repair,
|
||||
&packet.payload,
|
||||
) {
|
||||
@@ -211,10 +209,7 @@ impl Pipeline {
|
||||
///
|
||||
/// Returns a new profile if a tier transition occurred.
|
||||
#[allow(unused)]
|
||||
pub fn observe_quality(
|
||||
&mut self,
|
||||
report: &wzp_proto::QualityReport,
|
||||
) -> Option<QualityProfile> {
|
||||
pub fn observe_quality(&mut self, report: &wzp_proto::QualityReport) -> Option<QualityProfile> {
|
||||
let new_profile = self.quality_ctrl.observe(report);
|
||||
if let Some(ref profile) = new_profile {
|
||||
if let Err(e) = self.encoder.set_profile(*profile) {
|
||||
|
||||
@@ -11,6 +11,12 @@ pub enum CallState {
|
||||
Active,
|
||||
Reconnecting,
|
||||
Closed,
|
||||
/// Connected to relay signal channel, registered for direct calls.
|
||||
Registered,
|
||||
/// Outgoing call ringing on callee's side.
|
||||
Ringing,
|
||||
/// Incoming call received, waiting for user to accept/reject.
|
||||
IncomingCall,
|
||||
}
|
||||
|
||||
impl serde::Serialize for CallState {
|
||||
@@ -21,6 +27,9 @@ impl serde::Serialize for CallState {
|
||||
CallState::Active => 2,
|
||||
CallState::Reconnecting => 3,
|
||||
CallState::Closed => 4,
|
||||
CallState::Registered => 5,
|
||||
CallState::Ringing => 6,
|
||||
CallState::IncomingCall => 7,
|
||||
};
|
||||
serializer.serialize_u8(n)
|
||||
}
|
||||
@@ -49,14 +58,46 @@ pub struct CallStats {
|
||||
pub frames_decoded: u64,
|
||||
/// Number of playout underruns (buffer empty when audio needed).
|
||||
pub underruns: u64,
|
||||
/// Frames recovered by FEC.
|
||||
/// Frames recovered by RaptorQ FEC (Codec2 tiers only; Opus bypasses
|
||||
/// RaptorQ per Phase 2).
|
||||
pub fec_recovered: u64,
|
||||
/// Phase 3c: Opus frames reconstructed via DRED side-channel data.
|
||||
/// Only increments on the Opus tiers; always zero for Codec2.
|
||||
pub dred_reconstructions: u64,
|
||||
/// Phase 3c: Opus frames filled via classical Opus PLC because no DRED
|
||||
/// state covered the gap, plus any decode-error fallbacks. Codec2 loss
|
||||
/// also increments this counter via the Codec2 PLC path.
|
||||
pub classical_plc_invocations: u64,
|
||||
/// Playout ring overflow count (reader was lapped by writer).
|
||||
pub playout_overflows: u64,
|
||||
/// Playout ring underrun count (reader found empty buffer).
|
||||
pub playout_underruns: u64,
|
||||
/// Capture ring overflow count.
|
||||
pub capture_overflows: u64,
|
||||
/// Current mic audio level (RMS of i16 samples, 0-32767).
|
||||
pub audio_level: u32,
|
||||
/// Our current outgoing codec name (e.g. "Opus24k", "Codec2_1200").
|
||||
pub current_codec: String,
|
||||
/// Last seen incoming codec from other participants.
|
||||
pub peer_codec: String,
|
||||
/// Whether auto quality mode is active.
|
||||
pub auto_mode: bool,
|
||||
/// Number of participants in the room (from last RoomUpdate).
|
||||
pub room_participant_count: u32,
|
||||
/// Participant list (fingerprint + optional alias) serialized as JSON array.
|
||||
pub room_participants: Vec<RoomMember>,
|
||||
/// SAS code for verbal verification (None if not in a call).
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub sas_code: Option<u32>,
|
||||
/// Incoming call info (present when state == IncomingCall).
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub incoming_call_id: Option<String>,
|
||||
/// Fingerprint of the caller (present when state == IncomingCall).
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub incoming_caller_fp: Option<String>,
|
||||
/// Alias of the caller (present when state == IncomingCall).
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub incoming_caller_alias: Option<String>,
|
||||
}
|
||||
|
||||
/// A room member entry, serialized into the stats JSON.
|
||||
@@ -64,4 +105,5 @@ pub struct CallStats {
|
||||
pub struct RoomMember {
|
||||
pub fingerprint: String,
|
||||
pub alias: Option<String>,
|
||||
pub relay_label: Option<String>,
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ wzp-codec = { workspace = true }
|
||||
wzp-fec = { workspace = true }
|
||||
wzp-crypto = { workspace = true }
|
||||
wzp-transport = { workspace = true }
|
||||
wzp-video = { path = "../wzp-video" }
|
||||
tokio = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
@@ -21,17 +22,93 @@ anyhow = "1"
|
||||
serde = { workspace = true }
|
||||
serde_json = "1"
|
||||
chrono = "0.4"
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
ratatui = "0.29"
|
||||
crossterm = "0.28"
|
||||
rustls = { version = "0.23", default-features = false, features = ["ring", "std"] }
|
||||
cpal = { version = "0.15", optional = true }
|
||||
libc = "0.2"
|
||||
# Phase 5.5 — LAN host-candidate ICE: enumerate local network
|
||||
# interface addresses for inclusion in DirectCallOffer/Answer so
|
||||
# peers on the same LAN can direct-connect without NAT hairpinning
|
||||
# through the WAN reflex addr (which many consumer NATs, including
|
||||
# MikroTik's default masquerade, don't support).
|
||||
if-addrs = "0.13"
|
||||
rand = { workspace = true }
|
||||
socket2 = "0.5"
|
||||
|
||||
# coreaudio-rs is Apple-framework-only; gate it to macOS so enabling
|
||||
# the `vpio` feature from a non-macOS target builds cleanly instead of
|
||||
# pulling in a crate that can only link against Apple frameworks.
|
||||
[target.'cfg(target_os = "macos")'.dependencies]
|
||||
coreaudio-rs = { version = "0.11", optional = true }
|
||||
|
||||
# Windows-only: direct WASAPI bindings for the `windows-aec` feature.
|
||||
# `windows` is Microsoft's official Rust COM bindings crate. We pull in
|
||||
# only the audio + COM subfeatures we need — the crate is organized as
|
||||
# a massive optional-feature tree, so enabling just these keeps compile
|
||||
# times reasonable (~5s for these features vs ~60s for the full crate).
|
||||
[target.'cfg(target_os = "windows")'.dependencies]
|
||||
windows = { version = "0.58", optional = true, features = [
|
||||
"Win32_Foundation",
|
||||
"Win32_Media_Audio",
|
||||
"Win32_Security",
|
||||
"Win32_System_Com",
|
||||
"Win32_System_Com_StructuredStorage",
|
||||
"Win32_System_Threading",
|
||||
"Win32_System_Variant",
|
||||
] }
|
||||
|
||||
# Linux-only: WebRTC AEC (Audio Processing Module) bindings for the
|
||||
# `linux-aec` feature. This is the 0.3.x line of the `tonarino/
|
||||
# webrtc-audio-processing` crate, which links against Debian's
|
||||
# `libwebrtc-audio-processing-dev` apt package (0.3-1+b1 on Bookworm).
|
||||
#
|
||||
# Note: we attempted the 2.x line with its `bundled` sub-feature first
|
||||
# (which would give us AEC3 instead of AEC2), but both the crates.io
|
||||
# tarball AND the upstream git `main` branch of webrtc-audio-processing-sys
|
||||
# 2.0.3 hit a `meson setup --reconfigure` bug where the build.rs passes
|
||||
# --reconfigure unconditionally even on first-run empty build dirs,
|
||||
# causing the bundled build to fail with "Directory does not contain a
|
||||
# valid build tree". The 0.x line doesn't use bundled mode and sidesteps
|
||||
# this entirely by linking the apt-provided library. AEC2 is older than
|
||||
# AEC3 but still the same algorithm family — this is what PulseAudio's
|
||||
# module-echo-cancel and PipeWire's filter-chain use by default on
|
||||
# current Debian-family distros.
|
||||
[target.'cfg(target_os = "linux")'.dependencies]
|
||||
webrtc-audio-processing = { version = "0.3", optional = true }
|
||||
|
||||
[features]
|
||||
default = []
|
||||
audio = ["cpal"]
|
||||
# vpio enables coreaudio-rs but that dep is itself gated to macOS above,
|
||||
# so enabling this feature on Windows/Linux is a no-op (the audio_vpio
|
||||
# module is also #[cfg(target_os = "macos")] in lib.rs).
|
||||
vpio = ["dep:coreaudio-rs"]
|
||||
# windows-aec enables a direct WASAPI capture backend that opens the
|
||||
# microphone under AudioCategory_Communications, turning on Windows's
|
||||
# OS-level communications audio processing (AEC + noise suppression +
|
||||
# AGC). The `windows` dep is itself target-gated to Windows above, so
|
||||
# enabling this feature on non-Windows targets is a no-op (the
|
||||
# audio_wasapi module is also #[cfg(target_os = "windows")] in lib.rs).
|
||||
windows-aec = ["dep:windows"]
|
||||
# linux-aec enables a CPAL + WebRTC AEC3 capture/playback backend that
|
||||
# runs the WebRTC Audio Processing Module (same algo as Chrome / Zoom /
|
||||
# Teams) in-process, using the playback PCM as the reference signal for
|
||||
# echo cancellation. The webrtc-audio-processing dep is target-gated to
|
||||
# Linux above, so enabling this feature on non-Linux targets is a no-op
|
||||
# (the audio_linux_aec module is also #[cfg(target_os = "linux")] in
|
||||
# lib.rs).
|
||||
linux-aec = ["dep:webrtc-audio-processing"]
|
||||
|
||||
[[bin]]
|
||||
name = "wzp-client"
|
||||
path = "src/cli.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "wzp-analyzer"
|
||||
path = "src/analyzer.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "wzp-bench"
|
||||
path = "src/bench_cli.rs"
|
||||
|
||||
973
crates/wzp-client/src/analyzer.rs
Normal file
973
crates/wzp-client/src/analyzer.rs
Normal file
@@ -0,0 +1,973 @@
|
||||
//! WarzonePhone Protocol Analyzer — passive call quality observer.
|
||||
//!
|
||||
//! Joins a relay room as a passive participant (no media sent) and displays
|
||||
//! real-time per-participant quality metrics in a terminal UI.
|
||||
//!
|
||||
//! Usage:
|
||||
//! wzp-analyzer 127.0.0.1:4433 --room test
|
||||
//! wzp-analyzer 1.2.3.4:4433 --room test --capture session.wzp
|
||||
//! wzp-analyzer 1.2.3.4:4433 --room test --no-tui --duration 60
|
||||
|
||||
use std::io::Write;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use clap::Parser;
|
||||
use tracing::info;
|
||||
|
||||
use wzp_proto::{CodecId, MediaPacket, MediaTransport, default_signal_version};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// CLI
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// WarzonePhone Protocol Analyzer — passive call quality observer
|
||||
#[derive(Parser)]
|
||||
#[command(name = "wzp-analyzer", version)]
|
||||
struct Args {
|
||||
/// Relay address (host:port) — required for live mode, ignored with --replay
|
||||
relay: Option<String>,
|
||||
|
||||
/// Room name to observe — required for live mode, ignored with --replay
|
||||
#[arg(short, long)]
|
||||
room: Option<String>,
|
||||
|
||||
/// Auth token for relay
|
||||
#[arg(long)]
|
||||
token: Option<String>,
|
||||
|
||||
/// Identity seed (64-char hex)
|
||||
#[arg(long)]
|
||||
seed: Option<String>,
|
||||
|
||||
/// Capture packets to file
|
||||
#[arg(long)]
|
||||
capture: Option<String>,
|
||||
|
||||
/// Auto-stop after N seconds
|
||||
#[arg(long)]
|
||||
duration: Option<u64>,
|
||||
|
||||
/// Disable TUI (print stats to stdout instead)
|
||||
#[arg(long)]
|
||||
no_tui: bool,
|
||||
|
||||
/// Replay a captured .wzp file (offline analysis)
|
||||
#[arg(long)]
|
||||
replay: Option<String>,
|
||||
|
||||
/// Generate HTML report (from live session or replay)
|
||||
#[arg(long)]
|
||||
html: Option<String>,
|
||||
|
||||
/// Session key hex for decrypting payloads (enables audio decode)
|
||||
// TODO(#17): Audio decode requires session key + nonce context.
|
||||
// In SFU mode, payloads are E2E encrypted. Decoding requires
|
||||
// either: (a) session key from both endpoints, or (b) running
|
||||
// the analyzer as a trusted participant with its own key exchange.
|
||||
// For now, header-only analysis provides loss%, jitter, codec stats.
|
||||
#[arg(long)]
|
||||
key: Option<String>,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Per-participant statistics
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
struct ParticipantStats {
|
||||
/// Stream identifier (index, assigned when we detect a new seq stream)
|
||||
stream_id: usize,
|
||||
/// Display name from RoomUpdate (if available)
|
||||
alias: Option<String>,
|
||||
/// Current codec
|
||||
codec: CodecId,
|
||||
/// Total packets received
|
||||
packets: u64,
|
||||
/// Detected lost packets (sequence gaps)
|
||||
lost: u64,
|
||||
/// Last seen sequence number
|
||||
last_seq: u32,
|
||||
/// Whether we've seen the first packet (for gap detection)
|
||||
seq_initialized: bool,
|
||||
/// EWMA jitter in ms
|
||||
jitter_ms: f64,
|
||||
/// Last packet arrival time
|
||||
last_arrival: Option<Instant>,
|
||||
/// Codec changes observed
|
||||
codec_switches: u32,
|
||||
/// First packet time
|
||||
first_seen: Instant,
|
||||
/// Last packet time
|
||||
last_seen: Instant,
|
||||
}
|
||||
|
||||
impl ParticipantStats {
|
||||
fn new(id: usize, codec: CodecId) -> Self {
|
||||
let now = Instant::now();
|
||||
Self {
|
||||
stream_id: id,
|
||||
alias: None,
|
||||
codec,
|
||||
packets: 0,
|
||||
lost: 0,
|
||||
last_seq: 0,
|
||||
seq_initialized: false,
|
||||
jitter_ms: 0.0,
|
||||
last_arrival: None,
|
||||
codec_switches: 0,
|
||||
first_seen: now,
|
||||
last_seen: now,
|
||||
}
|
||||
}
|
||||
|
||||
fn ingest(&mut self, pkt: &MediaPacket, now: Instant) {
|
||||
self.packets += 1;
|
||||
self.last_seen = now;
|
||||
|
||||
// Codec switch detection
|
||||
if pkt.header.codec_id != self.codec {
|
||||
self.codec_switches += 1;
|
||||
self.codec = pkt.header.codec_id;
|
||||
}
|
||||
|
||||
// Loss detection from sequence gaps
|
||||
if self.seq_initialized {
|
||||
let expected = self.last_seq.wrapping_add(1);
|
||||
let gap = pkt.header.seq.wrapping_sub(expected);
|
||||
if gap > 0 && gap < 100 {
|
||||
self.lost += gap as u64;
|
||||
}
|
||||
}
|
||||
self.last_seq = pkt.header.seq;
|
||||
self.seq_initialized = true;
|
||||
|
||||
// Jitter (inter-arrival time variance, EWMA)
|
||||
if let Some(last) = self.last_arrival {
|
||||
let interval_ms = now.duration_since(last).as_secs_f64() * 1000.0;
|
||||
let expected_ms = pkt.header.codec_id.frame_duration_ms() as f64;
|
||||
let diff = (interval_ms - expected_ms).abs();
|
||||
self.jitter_ms = 0.1 * diff + 0.9 * self.jitter_ms;
|
||||
}
|
||||
self.last_arrival = Some(now);
|
||||
}
|
||||
|
||||
fn loss_percent(&self) -> f64 {
|
||||
let total = self.packets + self.lost;
|
||||
if total == 0 {
|
||||
0.0
|
||||
} else {
|
||||
(self.lost as f64 / total as f64) * 100.0
|
||||
}
|
||||
}
|
||||
|
||||
fn duration(&self) -> Duration {
|
||||
self.last_seen.duration_since(self.first_seen)
|
||||
}
|
||||
|
||||
fn display_name(&self) -> String {
|
||||
self.alias
|
||||
.as_deref()
|
||||
.map(String::from)
|
||||
.unwrap_or_else(|| format!("Stream {}", self.stream_id))
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Participant identification by sequence stream
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Find the participant whose sequence counter is close to `seq`, or create a
|
||||
/// new one. Each sender has an independent wrapping u16 counter, so we can
|
||||
/// distinguish streams by proximity of consecutive sequence numbers.
|
||||
fn find_or_create_participant(
|
||||
participants: &mut Vec<ParticipantStats>,
|
||||
seq: u32,
|
||||
codec: CodecId,
|
||||
) -> usize {
|
||||
for (i, p) in participants.iter().enumerate() {
|
||||
if p.seq_initialized {
|
||||
let delta = seq.wrapping_sub(p.last_seq);
|
||||
if delta > 0 && delta < 50 {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
}
|
||||
// New stream detected
|
||||
let id = participants.len();
|
||||
participants.push(ParticipantStats::new(id, codec));
|
||||
id
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Capture writer (binary packet log for later replay)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
struct CaptureWriter {
|
||||
file: std::io::BufWriter<std::fs::File>,
|
||||
start: Instant,
|
||||
}
|
||||
|
||||
impl CaptureWriter {
|
||||
fn new(path: &str, room: &str, relay: &str) -> anyhow::Result<Self> {
|
||||
let file = std::fs::File::create(path)?;
|
||||
let mut writer = std::io::BufWriter::new(file);
|
||||
// Magic + version
|
||||
writer.write_all(b"WZP\x01")?;
|
||||
let header = serde_json::json!({
|
||||
"room": room,
|
||||
"relay": relay,
|
||||
"start_time": chrono::Utc::now().to_rfc3339(),
|
||||
"version": 1,
|
||||
});
|
||||
let header_bytes = serde_json::to_vec(&header)?;
|
||||
writer.write_all(&(header_bytes.len() as u32).to_le_bytes())?;
|
||||
writer.write_all(&header_bytes)?;
|
||||
Ok(Self {
|
||||
file: writer,
|
||||
start: Instant::now(),
|
||||
})
|
||||
}
|
||||
|
||||
fn write_packet(&mut self, pkt: &MediaPacket, now: Instant) -> anyhow::Result<()> {
|
||||
let elapsed_us = now.duration_since(self.start).as_micros() as u64;
|
||||
self.file.write_all(&elapsed_us.to_le_bytes())?;
|
||||
let raw = pkt.to_bytes();
|
||||
self.file.write_all(&(raw.len() as u32).to_le_bytes())?;
|
||||
self.file.write_all(&raw)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Capture reader (for replay mode)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
struct CaptureReader {
|
||||
reader: std::io::BufReader<std::fs::File>,
|
||||
header: serde_json::Value,
|
||||
}
|
||||
|
||||
impl CaptureReader {
|
||||
fn open(path: &str) -> anyhow::Result<Self> {
|
||||
use std::io::Read;
|
||||
let file = std::fs::File::open(path)?;
|
||||
let mut reader = std::io::BufReader::new(file);
|
||||
|
||||
// Read magic
|
||||
let mut magic = [0u8; 4];
|
||||
reader.read_exact(&mut magic)?;
|
||||
anyhow::ensure!(&magic == b"WZP\x01", "not a WZP capture file");
|
||||
|
||||
// Read header
|
||||
let mut len_buf = [0u8; 4];
|
||||
reader.read_exact(&mut len_buf)?;
|
||||
let header_len = u32::from_le_bytes(len_buf) as usize;
|
||||
let mut header_bytes = vec![0u8; header_len];
|
||||
reader.read_exact(&mut header_bytes)?;
|
||||
let header: serde_json::Value = serde_json::from_slice(&header_bytes)?;
|
||||
|
||||
Ok(Self { reader, header })
|
||||
}
|
||||
|
||||
fn next_packet(&mut self) -> anyhow::Result<Option<(u64, MediaPacket)>> {
|
||||
use std::io::Read;
|
||||
// Read timestamp
|
||||
let mut ts_buf = [0u8; 8];
|
||||
match self.reader.read_exact(&mut ts_buf) {
|
||||
Ok(()) => {}
|
||||
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
|
||||
Err(e) => return Err(e.into()),
|
||||
}
|
||||
let timestamp_us = u64::from_le_bytes(ts_buf);
|
||||
|
||||
// Read packet
|
||||
let mut len_buf = [0u8; 4];
|
||||
self.reader.read_exact(&mut len_buf)?;
|
||||
let pkt_len = u32::from_le_bytes(len_buf) as usize;
|
||||
let mut pkt_bytes = vec![0u8; pkt_len];
|
||||
self.reader.read_exact(&mut pkt_bytes)?;
|
||||
|
||||
let pkt = MediaPacket::from_bytes(bytes::Bytes::from(pkt_bytes))
|
||||
.ok_or_else(|| anyhow::anyhow!("malformed packet in capture"))?;
|
||||
|
||||
Ok(Some((timestamp_us, pkt)))
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Timeline entry (for HTML report generation)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
struct TimelineEntry {
|
||||
timestamp_us: u64,
|
||||
stream_id: usize,
|
||||
#[allow(dead_code)]
|
||||
codec: CodecId,
|
||||
#[allow(dead_code)]
|
||||
seq: u32,
|
||||
#[allow(dead_code)]
|
||||
payload_len: usize,
|
||||
loss_pct: f64,
|
||||
jitter_ms: f64,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Replay mode (#15)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
async fn run_replay(path: &str, args: &Args) -> anyhow::Result<()> {
|
||||
let mut reader = CaptureReader::open(path)?;
|
||||
eprintln!(
|
||||
"Replaying: {} (room: {})",
|
||||
path,
|
||||
reader
|
||||
.header
|
||||
.get("room")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("?")
|
||||
);
|
||||
|
||||
let mut participants: Vec<ParticipantStats> = Vec::new();
|
||||
let mut total_packets: u64 = 0;
|
||||
let start = Instant::now();
|
||||
let mut timeline: Vec<TimelineEntry> = Vec::new();
|
||||
|
||||
// Decrypt session from --key (optional)
|
||||
let mut decrypt_session: Option<wzp_crypto::ChaChaSession> =
|
||||
args.key.as_ref().and_then(|hex| {
|
||||
if hex.len() != 64 {
|
||||
return None;
|
||||
}
|
||||
let mut key = [0u8; 32];
|
||||
for (i, chunk) in hex.as_bytes().chunks(2).enumerate() {
|
||||
let s = std::str::from_utf8(chunk).unwrap_or("00");
|
||||
key[i] = u8::from_str_radix(s, 16).unwrap_or(0);
|
||||
}
|
||||
Some(wzp_crypto::ChaChaSession::new(key))
|
||||
});
|
||||
let mut decrypt_ok: u64 = 0;
|
||||
let mut decrypt_fail: u64 = 0;
|
||||
|
||||
while let Some((ts_us, pkt)) = reader.next_packet()? {
|
||||
let now = Instant::now();
|
||||
let idx =
|
||||
find_or_create_participant(&mut participants, pkt.header.seq, pkt.header.codec_id);
|
||||
participants[idx].ingest(&pkt, now);
|
||||
total_packets += 1;
|
||||
|
||||
// Attempt decryption if key provided
|
||||
if let Some(ref mut session) = decrypt_session {
|
||||
use wzp_proto::CryptoSession;
|
||||
let header_bytes = pkt.header.to_bytes();
|
||||
let mut plaintext = Vec::new();
|
||||
match session.decrypt(&header_bytes, &pkt.payload, &mut plaintext) {
|
||||
Ok(()) => {
|
||||
decrypt_ok += 1;
|
||||
if decrypt_ok <= 5 || decrypt_ok % 100 == 0 {
|
||||
eprintln!(
|
||||
" decrypt ok: seq={} codec={:?} payload={}B → plaintext={}B",
|
||||
pkt.header.seq,
|
||||
pkt.header.codec_id,
|
||||
pkt.payload.len(),
|
||||
plaintext.len()
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
decrypt_fail += 1;
|
||||
if decrypt_fail <= 3 {
|
||||
eprintln!(
|
||||
" decrypt FAIL: seq={} (key mismatch, wrong direction, or rekey boundary)",
|
||||
pkt.header.seq
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Record for HTML timeline
|
||||
timeline.push(TimelineEntry {
|
||||
timestamp_us: ts_us,
|
||||
stream_id: idx,
|
||||
codec: pkt.header.codec_id,
|
||||
seq: pkt.header.seq,
|
||||
payload_len: pkt.payload.len(),
|
||||
loss_pct: participants[idx].loss_percent(),
|
||||
jitter_ms: participants[idx].jitter_ms,
|
||||
});
|
||||
}
|
||||
|
||||
if decrypt_session.is_some() {
|
||||
eprintln!(
|
||||
"Decrypt stats: {} ok, {} failed (total {})",
|
||||
decrypt_ok, decrypt_fail, total_packets
|
||||
);
|
||||
}
|
||||
|
||||
print_summary(&participants, total_packets, start.elapsed());
|
||||
|
||||
// Generate HTML if requested
|
||||
if let Some(html_path) = &args.html {
|
||||
generate_html_report(
|
||||
html_path,
|
||||
&participants,
|
||||
&timeline,
|
||||
total_packets,
|
||||
&reader.header,
|
||||
)?;
|
||||
eprintln!("HTML report: {}", html_path);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HTML report generation (#16)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn generate_html_report(
|
||||
path: &str,
|
||||
participants: &[ParticipantStats],
|
||||
timeline: &[TimelineEntry],
|
||||
total_packets: u64,
|
||||
capture_header: &serde_json::Value,
|
||||
) -> anyhow::Result<()> {
|
||||
use std::io::Write as _;
|
||||
let mut f = std::fs::File::create(path)?;
|
||||
|
||||
let room = capture_header
|
||||
.get("room")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("unknown");
|
||||
let start_time = capture_header
|
||||
.get("start_time")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("?");
|
||||
|
||||
// Build per-stream loss/jitter timeline data for Chart.js
|
||||
// Sample every 1 second (group timeline entries by second)
|
||||
let max_ts = timeline.last().map(|e| e.timestamp_us).unwrap_or(0);
|
||||
let duration_secs = (max_ts / 1_000_000) + 1;
|
||||
|
||||
let mut loss_data: std::collections::HashMap<usize, Vec<f64>> =
|
||||
std::collections::HashMap::new();
|
||||
let mut jitter_data: std::collections::HashMap<usize, Vec<f64>> =
|
||||
std::collections::HashMap::new();
|
||||
|
||||
for stream_id in 0..participants.len() {
|
||||
loss_data.insert(stream_id, vec![0.0; duration_secs as usize]);
|
||||
jitter_data.insert(stream_id, vec![0.0; duration_secs as usize]);
|
||||
}
|
||||
|
||||
for entry in timeline {
|
||||
let sec = (entry.timestamp_us / 1_000_000) as usize;
|
||||
if sec < duration_secs as usize {
|
||||
if let Some(losses) = loss_data.get_mut(&entry.stream_id) {
|
||||
losses[sec] = entry.loss_pct;
|
||||
}
|
||||
if let Some(jitters) = jitter_data.get_mut(&entry.stream_id) {
|
||||
jitters[sec] = entry.jitter_ms;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let colors = [
|
||||
"#e74c3c", "#3498db", "#2ecc71", "#f39c12", "#9b59b6", "#1abc9c",
|
||||
];
|
||||
|
||||
// Build dataset JSON for charts
|
||||
let mut loss_datasets = String::new();
|
||||
let mut jitter_datasets = String::new();
|
||||
for (i, p) in participants.iter().enumerate() {
|
||||
let name = p.display_name();
|
||||
let color = colors[i % colors.len()];
|
||||
let loss_vals = loss_data
|
||||
.get(&i)
|
||||
.map(|v| format!("{:?}", v))
|
||||
.unwrap_or_default();
|
||||
let jitter_vals = jitter_data
|
||||
.get(&i)
|
||||
.map(|v| format!("{:?}", v))
|
||||
.unwrap_or_default();
|
||||
|
||||
loss_datasets.push_str(&format!(
|
||||
"{{ label: '{}', data: {}, borderColor: '{}', fill: false }},\n",
|
||||
name, loss_vals, color
|
||||
));
|
||||
jitter_datasets.push_str(&format!(
|
||||
"{{ label: '{}', data: {}, borderColor: '{}', fill: false }},\n",
|
||||
name, jitter_vals, color
|
||||
));
|
||||
}
|
||||
|
||||
let labels: Vec<String> = (0..duration_secs).map(|s| format!("{}s", s)).collect();
|
||||
let labels_json = format!("{:?}", labels);
|
||||
|
||||
// Summary table rows
|
||||
let mut summary_rows = String::new();
|
||||
for p in participants {
|
||||
summary_rows.push_str(&format!(
|
||||
"<tr><td>{}</td><td>{:?}</td><td>{}</td><td>{:.1}%</td><td>{:.0}ms</td><td>{}</td></tr>\n",
|
||||
p.display_name(),
|
||||
p.codec,
|
||||
p.packets,
|
||||
p.loss_percent(),
|
||||
p.jitter_ms,
|
||||
p.codec_switches
|
||||
));
|
||||
}
|
||||
|
||||
write!(
|
||||
f,
|
||||
r#"<!DOCTYPE html>
|
||||
<html><head>
|
||||
<meta charset="utf-8">
|
||||
<title>WZP Call Report — {room}</title>
|
||||
<script src="https://cdn.jsdelivr.net/npm/chart.js@4"></script>
|
||||
<style>
|
||||
body {{ font-family: -apple-system, sans-serif; max-width: 1200px; margin: 0 auto; padding: 20px; background: #1a1a2e; color: #e0e0e0; }}
|
||||
h1,h2 {{ color: #4a9eff; }}
|
||||
table {{ border-collapse: collapse; width: 100%; margin: 20px 0; }}
|
||||
th,td {{ border: 1px solid #333; padding: 8px 12px; text-align: left; }}
|
||||
th {{ background: #16213e; }}
|
||||
tr:nth-child(even) {{ background: #1a1a3e; }}
|
||||
.chart-container {{ background: #16213e; border-radius: 8px; padding: 16px; margin: 20px 0; }}
|
||||
canvas {{ max-height: 300px; }}
|
||||
.meta {{ color: #888; font-size: 0.9em; }}
|
||||
</style>
|
||||
</head><body>
|
||||
<h1>WZP Call Quality Report</h1>
|
||||
<p class="meta">Room: <b>{room}</b> | Start: {start_time} | Packets: {total_packets} | Duration: {duration_secs}s</p>
|
||||
|
||||
<h2>Participant Summary</h2>
|
||||
<table>
|
||||
<tr><th>Name</th><th>Codec</th><th>Packets</th><th>Loss</th><th>Jitter</th><th>Codec Switches</th></tr>
|
||||
{summary_rows}
|
||||
</table>
|
||||
|
||||
<h2>Packet Loss Over Time</h2>
|
||||
<div class="chart-container"><canvas id="lossChart"></canvas></div>
|
||||
|
||||
<h2>Jitter Over Time</h2>
|
||||
<div class="chart-container"><canvas id="jitterChart"></canvas></div>
|
||||
|
||||
<script>
|
||||
const labels = {labels_json};
|
||||
new Chart(document.getElementById('lossChart'), {{
|
||||
type: 'line',
|
||||
data: {{ labels, datasets: [{loss_datasets}] }},
|
||||
options: {{ responsive: true, scales: {{ y: {{ beginAtZero: true, title: {{ display: true, text: 'Loss %' }} }} }} }}
|
||||
}});
|
||||
new Chart(document.getElementById('jitterChart'), {{
|
||||
type: 'line',
|
||||
data: {{ labels, datasets: [{jitter_datasets}] }},
|
||||
options: {{ responsive: true, scales: {{ y: {{ beginAtZero: true, title: {{ display: true, text: 'Jitter (ms)' }} }} }} }}
|
||||
}});
|
||||
</script>
|
||||
</body></html>"#
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// No-TUI mode (print stats to stdout periodically)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
async fn run_no_tui(
|
||||
transport: &wzp_transport::QuinnTransport,
|
||||
participants: &mut Vec<ParticipantStats>,
|
||||
total_packets: &mut u64,
|
||||
deadline: Option<Instant>,
|
||||
mut capture_writer: Option<&mut CaptureWriter>,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut print_timer = Instant::now();
|
||||
loop {
|
||||
if let Some(dl) = deadline {
|
||||
if Instant::now() > dl {
|
||||
break;
|
||||
}
|
||||
}
|
||||
match tokio::time::timeout(Duration::from_millis(100), transport.recv_media()).await {
|
||||
Ok(Ok(Some(pkt))) => {
|
||||
let now = Instant::now();
|
||||
let idx =
|
||||
find_or_create_participant(participants, pkt.header.seq, pkt.header.codec_id);
|
||||
participants[idx].ingest(&pkt, now);
|
||||
*total_packets += 1;
|
||||
if let Some(ref mut w) = capture_writer {
|
||||
w.write_packet(&pkt, now)?;
|
||||
}
|
||||
}
|
||||
Ok(Ok(None)) => break, // connection closed
|
||||
Ok(Err(e)) => {
|
||||
tracing::warn!("recv error: {e}");
|
||||
break;
|
||||
}
|
||||
Err(_) => {} // timeout, loop again
|
||||
}
|
||||
if print_timer.elapsed() >= Duration::from_secs(2) {
|
||||
print_stats(participants, *total_packets);
|
||||
print_timer = Instant::now();
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn print_stats(participants: &[ParticipantStats], total: u64) {
|
||||
eprintln!(
|
||||
"--- {} participants | {} total packets ---",
|
||||
participants.len(),
|
||||
total
|
||||
);
|
||||
for p in participants {
|
||||
eprintln!(
|
||||
" {}: {} pkts, {:.1}% loss, {:.0}ms jitter, {:?}, {:.0}s",
|
||||
p.display_name(),
|
||||
p.packets,
|
||||
p.loss_percent(),
|
||||
p.jitter_ms,
|
||||
p.codec,
|
||||
p.duration().as_secs_f64(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TUI mode (ratatui + crossterm)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
async fn run_tui(
|
||||
transport: &wzp_transport::QuinnTransport,
|
||||
participants: &mut Vec<ParticipantStats>,
|
||||
total_packets: &mut u64,
|
||||
start_time: Instant,
|
||||
deadline: Option<Instant>,
|
||||
mut capture_writer: Option<&mut CaptureWriter>,
|
||||
) -> anyhow::Result<()> {
|
||||
crossterm::terminal::enable_raw_mode()?;
|
||||
let mut stdout = std::io::stdout();
|
||||
crossterm::execute!(stdout, crossterm::terminal::EnterAlternateScreen)?;
|
||||
let backend = ratatui::backend::CrosstermBackend::new(stdout);
|
||||
let mut terminal = ratatui::Terminal::new(backend)?;
|
||||
|
||||
let mut redraw_timer = Instant::now();
|
||||
|
||||
let result: anyhow::Result<()> = async {
|
||||
loop {
|
||||
// Check for quit key (q or Ctrl+C)
|
||||
if crossterm::event::poll(Duration::from_millis(0))? {
|
||||
if let crossterm::event::Event::Key(key) = crossterm::event::read()? {
|
||||
use crossterm::event::{KeyCode, KeyModifiers};
|
||||
if key.code == KeyCode::Char('q')
|
||||
|| (key.code == KeyCode::Char('c')
|
||||
&& key.modifiers.contains(KeyModifiers::CONTROL))
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(dl) = deadline {
|
||||
if Instant::now() > dl {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Receive packets (non-blocking with short timeout)
|
||||
match tokio::time::timeout(Duration::from_millis(20), transport.recv_media()).await {
|
||||
Ok(Ok(Some(pkt))) => {
|
||||
let now = Instant::now();
|
||||
let idx = find_or_create_participant(
|
||||
participants,
|
||||
pkt.header.seq,
|
||||
pkt.header.codec_id,
|
||||
);
|
||||
participants[idx].ingest(&pkt, now);
|
||||
*total_packets += 1;
|
||||
if let Some(ref mut w) = capture_writer {
|
||||
w.write_packet(&pkt, now)?;
|
||||
}
|
||||
}
|
||||
Ok(Ok(None)) => break,
|
||||
Ok(Err(e)) => {
|
||||
tracing::warn!("recv error: {e}");
|
||||
break;
|
||||
}
|
||||
Err(_) => {}
|
||||
}
|
||||
|
||||
// Redraw TUI at ~10 FPS
|
||||
if redraw_timer.elapsed() >= Duration::from_millis(100) {
|
||||
terminal.draw(|f| draw_ui(f, participants, *total_packets, start_time))?;
|
||||
redraw_timer = Instant::now();
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
.await;
|
||||
|
||||
// Always restore terminal, even on error
|
||||
crossterm::terminal::disable_raw_mode()?;
|
||||
crossterm::execute!(std::io::stdout(), crossterm::terminal::LeaveAlternateScreen)?;
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
fn draw_ui(
|
||||
f: &mut ratatui::Frame,
|
||||
participants: &[ParticipantStats],
|
||||
total_packets: u64,
|
||||
start_time: Instant,
|
||||
) {
|
||||
use ratatui::layout::{Constraint, Direction, Layout};
|
||||
use ratatui::style::{Color, Modifier, Style};
|
||||
use ratatui::widgets::{Block, Borders, Paragraph, Row, Table};
|
||||
|
||||
let elapsed = start_time.elapsed();
|
||||
let elapsed_str = format!(
|
||||
"{:02}:{:02}:{:02}",
|
||||
elapsed.as_secs() / 3600,
|
||||
(elapsed.as_secs() % 3600) / 60,
|
||||
elapsed.as_secs() % 60
|
||||
);
|
||||
|
||||
let chunks = Layout::default()
|
||||
.direction(Direction::Vertical)
|
||||
.constraints([
|
||||
Constraint::Length(3), // header
|
||||
Constraint::Min(5), // participant table
|
||||
Constraint::Length(3), // footer
|
||||
])
|
||||
.split(f.area());
|
||||
|
||||
// Header
|
||||
let header = Paragraph::new(format!(
|
||||
" WZP Analyzer | {} participants | {} packets | {}",
|
||||
participants.len(),
|
||||
total_packets,
|
||||
elapsed_str
|
||||
))
|
||||
.block(
|
||||
Block::default()
|
||||
.borders(Borders::ALL)
|
||||
.title(" Protocol Analyzer "),
|
||||
);
|
||||
f.render_widget(header, chunks[0]);
|
||||
|
||||
// Participant table
|
||||
let header_row = Row::new(vec![
|
||||
"#", "Name", "Codec", "Packets", "Loss%", "Jitter", "Switches", "Duration",
|
||||
])
|
||||
.style(Style::default().add_modifier(Modifier::BOLD));
|
||||
|
||||
let rows: Vec<Row> = participants
|
||||
.iter()
|
||||
.map(|p| {
|
||||
let loss_color = if p.loss_percent() > 5.0 {
|
||||
Color::Red
|
||||
} else if p.loss_percent() > 1.0 {
|
||||
Color::Yellow
|
||||
} else {
|
||||
Color::Green
|
||||
};
|
||||
|
||||
Row::new(vec![
|
||||
format!("{}", p.stream_id),
|
||||
p.display_name(),
|
||||
format!("{:?}", p.codec),
|
||||
format!("{}", p.packets),
|
||||
format!("{:.1}%", p.loss_percent()),
|
||||
format!("{:.0}ms", p.jitter_ms),
|
||||
format!("{}", p.codec_switches),
|
||||
format!("{:.0}s", p.duration().as_secs_f64()),
|
||||
])
|
||||
.style(Style::default().fg(loss_color))
|
||||
})
|
||||
.collect();
|
||||
|
||||
let widths = [
|
||||
Constraint::Length(3), // #
|
||||
Constraint::Length(20), // Name
|
||||
Constraint::Length(12), // Codec
|
||||
Constraint::Length(10), // Packets
|
||||
Constraint::Length(8), // Loss%
|
||||
Constraint::Length(10), // Jitter
|
||||
Constraint::Length(10), // Switches
|
||||
Constraint::Length(10), // Duration
|
||||
];
|
||||
|
||||
let table = Table::new(rows, widths).header(header_row).block(
|
||||
Block::default()
|
||||
.borders(Borders::ALL)
|
||||
.title(" Participants "),
|
||||
);
|
||||
f.render_widget(table, chunks[1]);
|
||||
|
||||
// Footer
|
||||
let footer =
|
||||
Paragraph::new(" Press 'q' to quit ").block(Block::default().borders(Borders::ALL));
|
||||
f.render_widget(footer, chunks[2]);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Summary (printed on exit)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn print_summary(participants: &[ParticipantStats], total: u64, elapsed: Duration) {
|
||||
eprintln!("\n=== Session Summary ===");
|
||||
eprintln!(
|
||||
"Duration: {:.1}s | Total packets: {} | Participants: {}",
|
||||
elapsed.as_secs_f64(),
|
||||
total,
|
||||
participants.len()
|
||||
);
|
||||
for p in participants {
|
||||
eprintln!(
|
||||
" {}: {} pkts, {:.1}% loss, {:.0}ms jitter, {:?}, {} codec switches",
|
||||
p.display_name(),
|
||||
p.packets,
|
||||
p.loss_percent(),
|
||||
p.jitter_ms,
|
||||
p.codec,
|
||||
p.codec_switches,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// main
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
// Only init tracing subscriber in no-tui mode (it would corrupt the TUI otherwise)
|
||||
if args.no_tui || args.replay.is_some() {
|
||||
tracing_subscriber::fmt().init();
|
||||
}
|
||||
|
||||
let _crypto_session: Option<std::sync::Mutex<wzp_crypto::ChaChaSession>> =
|
||||
if let Some(ref key_hex) = args.key {
|
||||
if key_hex.len() != 64 {
|
||||
eprintln!(
|
||||
"Error: --key must be 64 hex characters (32 bytes). Got {} chars.",
|
||||
key_hex.len()
|
||||
);
|
||||
std::process::exit(1);
|
||||
}
|
||||
let mut key_bytes = [0u8; 32];
|
||||
for (i, chunk) in key_hex.as_bytes().chunks(2).enumerate() {
|
||||
let hex_str = std::str::from_utf8(chunk).unwrap_or("00");
|
||||
key_bytes[i] = u8::from_str_radix(hex_str, 16).unwrap_or(0);
|
||||
}
|
||||
eprintln!("Encrypted payload decoding enabled (key loaded).");
|
||||
Some(std::sync::Mutex::new(wzp_crypto::ChaChaSession::new(
|
||||
key_bytes,
|
||||
)))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Replay mode: offline analysis of a .wzp capture file
|
||||
if let Some(ref replay_path) = args.replay {
|
||||
return run_replay(replay_path, &args).await;
|
||||
}
|
||||
|
||||
// Live mode requires relay and room
|
||||
let relay = args.relay.as_deref().ok_or_else(|| {
|
||||
anyhow::anyhow!("relay address required for live mode (use --replay for offline)")
|
||||
})?;
|
||||
let room = args.room.as_deref().ok_or_else(|| {
|
||||
anyhow::anyhow!("--room required for live mode (use --replay for offline)")
|
||||
})?;
|
||||
|
||||
// TLS crypto provider
|
||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||
|
||||
// Identity seed
|
||||
let seed = match &args.seed {
|
||||
Some(hex) => {
|
||||
let s = wzp_crypto::Seed::from_hex(hex).map_err(|e| anyhow::anyhow!(e))?;
|
||||
info!(fingerprint = %s.derive_identity().public_identity().fingerprint, "identity from --seed");
|
||||
s
|
||||
}
|
||||
None => {
|
||||
let s = wzp_crypto::Seed::generate();
|
||||
info!(fingerprint = %s.derive_identity().public_identity().fingerprint, "generated ephemeral identity");
|
||||
s
|
||||
}
|
||||
};
|
||||
|
||||
// Connect to relay
|
||||
let relay_addr: std::net::SocketAddr = relay.parse()?;
|
||||
let bind_addr: std::net::SocketAddr = if relay_addr.is_ipv6() {
|
||||
"[::]:0".parse()?
|
||||
} else {
|
||||
"0.0.0.0:0".parse()?
|
||||
};
|
||||
let endpoint = wzp_transport::create_endpoint(bind_addr, None)?;
|
||||
let client_config = wzp_transport::client_config();
|
||||
let conn = wzp_transport::connect(&endpoint, relay_addr, room, client_config).await?;
|
||||
let transport = Arc::new(wzp_transport::QuinnTransport::new(conn));
|
||||
|
||||
// Crypto handshake
|
||||
let _crypto_session =
|
||||
wzp_client::handshake::perform_handshake(&*transport, &seed.0, Some("analyzer")).await?;
|
||||
|
||||
// Auth if token provided
|
||||
if let Some(ref token) = args.token {
|
||||
let auth = wzp_proto::SignalMessage::AuthToken {
|
||||
version: default_signal_version(),
|
||||
token: token.clone(),
|
||||
};
|
||||
transport.send_signal(&auth).await?;
|
||||
}
|
||||
|
||||
// Capture file (optional)
|
||||
let mut capture_writer = args
|
||||
.capture
|
||||
.as_ref()
|
||||
.map(|path| CaptureWriter::new(path, room, relay))
|
||||
.transpose()?;
|
||||
|
||||
// Duration timeout
|
||||
let deadline = args
|
||||
.duration
|
||||
.map(|s| Instant::now() + Duration::from_secs(s));
|
||||
|
||||
// State
|
||||
let mut participants: Vec<ParticipantStats> = Vec::new();
|
||||
let mut total_packets: u64 = 0;
|
||||
let start_time = Instant::now();
|
||||
|
||||
if args.no_tui {
|
||||
run_no_tui(
|
||||
&transport,
|
||||
&mut participants,
|
||||
&mut total_packets,
|
||||
deadline,
|
||||
capture_writer.as_mut(),
|
||||
)
|
||||
.await?;
|
||||
} else {
|
||||
run_tui(
|
||||
&transport,
|
||||
&mut participants,
|
||||
&mut total_packets,
|
||||
start_time,
|
||||
deadline,
|
||||
capture_writer.as_mut(),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
// Print summary
|
||||
print_summary(&participants, total_packets, start_time.elapsed());
|
||||
|
||||
// Clean close
|
||||
transport.close().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -3,19 +3,19 @@
|
||||
//! Both structs use 48 kHz, mono, i16 format to match the WarzonePhone codec
|
||||
//! pipeline. Frames are 960 samples (20 ms at 48 kHz).
|
||||
//!
|
||||
//! The cpal `Stream` type is not `Send`, so each struct spawns a dedicated OS
|
||||
//! thread that owns the stream. The public API exposes only `Send + Sync`
|
||||
//! channel handles.
|
||||
//! Audio callbacks are **lock-free**: they read/write directly to an `AudioRing`
|
||||
//! (atomic SPSC ring buffer). No Mutex, no channel, no allocation on the hot path.
|
||||
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::mpsc;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
use anyhow::{Context, anyhow};
|
||||
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
|
||||
use cpal::{SampleFormat, SampleRate, StreamConfig};
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::audio_ring::AudioRing;
|
||||
|
||||
/// Number of samples per 20 ms frame at 48 kHz mono.
|
||||
pub const FRAME_SAMPLES: usize = 960;
|
||||
|
||||
@@ -23,22 +23,24 @@ pub const FRAME_SAMPLES: usize = 960;
|
||||
// AudioCapture
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Captures microphone input and yields 960-sample PCM frames.
|
||||
/// Captures microphone input via CPAL and writes PCM into a lock-free ring buffer.
|
||||
///
|
||||
/// The cpal stream lives on a dedicated OS thread; this handle is `Send + Sync`.
|
||||
pub struct AudioCapture {
|
||||
rx: mpsc::Receiver<Vec<i16>>,
|
||||
ring: Arc<AudioRing>,
|
||||
running: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl AudioCapture {
|
||||
/// Create and start capturing from the default input device at 48 kHz mono.
|
||||
pub fn start() -> Result<Self, anyhow::Error> {
|
||||
let (tx, rx) = mpsc::sync_channel::<Vec<i16>>(64);
|
||||
let ring = Arc::new(AudioRing::new());
|
||||
let running = Arc::new(AtomicBool::new(true));
|
||||
let running_clone = running.clone();
|
||||
|
||||
let (init_tx, init_rx) = mpsc::sync_channel::<Result<(), String>>(1);
|
||||
let (init_tx, init_rx) = std::sync::mpsc::sync_channel::<Result<(), String>>(1);
|
||||
|
||||
let ring_cb = ring.clone();
|
||||
let running_clone = running.clone();
|
||||
|
||||
std::thread::Builder::new()
|
||||
.name("wzp-audio-capture".into())
|
||||
@@ -59,53 +61,57 @@ impl AudioCapture {
|
||||
|
||||
let use_f32 = !supports_i16_input(&device)?;
|
||||
|
||||
let buf = Arc::new(std::sync::Mutex::new(
|
||||
Vec::<i16>::with_capacity(FRAME_SAMPLES),
|
||||
));
|
||||
let err_cb = |e: cpal::StreamError| {
|
||||
warn!("input stream error: {e}");
|
||||
};
|
||||
|
||||
let logged_cb_size = Arc::new(AtomicBool::new(false));
|
||||
|
||||
let stream = if use_f32 {
|
||||
let buf = buf.clone();
|
||||
let tx = tx.clone();
|
||||
let ring = ring_cb.clone();
|
||||
let running = running_clone.clone();
|
||||
let logged = logged_cb_size.clone();
|
||||
device.build_input_stream(
|
||||
&config,
|
||||
move |data: &[f32], _: &cpal::InputCallbackInfo| {
|
||||
if !running.load(Ordering::Relaxed) {
|
||||
return;
|
||||
}
|
||||
let mut lock = buf.lock().unwrap();
|
||||
for &s in data {
|
||||
lock.push(f32_to_i16(s));
|
||||
if lock.len() == FRAME_SAMPLES {
|
||||
let frame = lock.drain(..).collect();
|
||||
let _ = tx.try_send(frame);
|
||||
if !logged.swap(true, Ordering::Relaxed) {
|
||||
eprintln!(
|
||||
"[audio] capture callback: {} f32 samples",
|
||||
data.len()
|
||||
);
|
||||
}
|
||||
let mut tmp = [0i16; FRAME_SAMPLES];
|
||||
for chunk in data.chunks(FRAME_SAMPLES) {
|
||||
let n = chunk.len();
|
||||
for i in 0..n {
|
||||
tmp[i] = f32_to_i16(chunk[i]);
|
||||
}
|
||||
ring.write(&tmp[..n]);
|
||||
}
|
||||
},
|
||||
err_cb,
|
||||
None,
|
||||
)?
|
||||
} else {
|
||||
let buf = buf.clone();
|
||||
let tx = tx.clone();
|
||||
let ring = ring_cb.clone();
|
||||
let running = running_clone.clone();
|
||||
let logged = logged_cb_size.clone();
|
||||
device.build_input_stream(
|
||||
&config,
|
||||
move |data: &[i16], _: &cpal::InputCallbackInfo| {
|
||||
if !running.load(Ordering::Relaxed) {
|
||||
return;
|
||||
}
|
||||
let mut lock = buf.lock().unwrap();
|
||||
for &s in data {
|
||||
lock.push(s);
|
||||
if lock.len() == FRAME_SAMPLES {
|
||||
let frame = lock.drain(..).collect();
|
||||
let _ = tx.try_send(frame);
|
||||
}
|
||||
if !logged.swap(true, Ordering::Relaxed) {
|
||||
eprintln!(
|
||||
"[audio] capture callback: {} i16 samples",
|
||||
data.len()
|
||||
);
|
||||
}
|
||||
ring.write(data);
|
||||
},
|
||||
err_cb,
|
||||
None,
|
||||
@@ -114,7 +120,6 @@ impl AudioCapture {
|
||||
|
||||
stream.play().context("failed to start input stream")?;
|
||||
|
||||
// Signal success to the caller before parking.
|
||||
let _ = init_tx.send(Ok(()));
|
||||
|
||||
// Keep stream alive until stopped.
|
||||
@@ -135,15 +140,12 @@ impl AudioCapture {
|
||||
.map_err(|_| anyhow!("capture thread exited before signaling"))?
|
||||
.map_err(|e| anyhow!("{e}"))?;
|
||||
|
||||
Ok(Self { rx, running })
|
||||
Ok(Self { ring, running })
|
||||
}
|
||||
|
||||
/// Read the next frame of 960 PCM samples (blocking until available).
|
||||
///
|
||||
/// Returns `None` when the stream has been stopped or the channel is
|
||||
/// disconnected.
|
||||
pub fn read_frame(&self) -> Option<Vec<i16>> {
|
||||
self.rx.recv().ok()
|
||||
/// Get a reference to the capture ring buffer for direct polling.
|
||||
pub fn ring(&self) -> &Arc<AudioRing> {
|
||||
&self.ring
|
||||
}
|
||||
|
||||
/// Stop capturing.
|
||||
@@ -152,26 +154,34 @@ impl AudioCapture {
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for AudioCapture {
|
||||
fn drop(&mut self) {
|
||||
self.stop();
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// AudioPlayback
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Plays PCM frames through the default output device at 48 kHz mono.
|
||||
/// Plays PCM through the default output device, reading from a lock-free ring buffer.
|
||||
///
|
||||
/// The cpal stream lives on a dedicated OS thread; this handle is `Send + Sync`.
|
||||
pub struct AudioPlayback {
|
||||
tx: mpsc::SyncSender<Vec<i16>>,
|
||||
ring: Arc<AudioRing>,
|
||||
running: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl AudioPlayback {
|
||||
/// Create and start playback on the default output device at 48 kHz mono.
|
||||
pub fn start() -> Result<Self, anyhow::Error> {
|
||||
let (tx, rx) = mpsc::sync_channel::<Vec<i16>>(64);
|
||||
let ring = Arc::new(AudioRing::new());
|
||||
let running = Arc::new(AtomicBool::new(true));
|
||||
let running_clone = running.clone();
|
||||
|
||||
let (init_tx, init_rx) = mpsc::sync_channel::<Result<(), String>>(1);
|
||||
let (init_tx, init_rx) = std::sync::mpsc::sync_channel::<Result<(), String>>(1);
|
||||
|
||||
let ring_cb = ring.clone();
|
||||
let running_clone = running.clone();
|
||||
|
||||
std::thread::Builder::new()
|
||||
.name("wzp-audio-playback".into())
|
||||
@@ -192,62 +202,40 @@ impl AudioPlayback {
|
||||
|
||||
let use_f32 = !supports_i16_output(&device)?;
|
||||
|
||||
// Shared ring of samples the cpal callback drains from.
|
||||
let ring = Arc::new(std::sync::Mutex::new(
|
||||
std::collections::VecDeque::<i16>::with_capacity(FRAME_SAMPLES * 8),
|
||||
));
|
||||
|
||||
// Background drainer: moves frames from the mpsc channel into the ring.
|
||||
{
|
||||
let ring = ring.clone();
|
||||
let running = running_clone.clone();
|
||||
std::thread::Builder::new()
|
||||
.name("wzp-playback-drain".into())
|
||||
.spawn(move || {
|
||||
while running.load(Ordering::Relaxed) {
|
||||
match rx.recv_timeout(std::time::Duration::from_millis(100)) {
|
||||
Ok(frame) => {
|
||||
let mut lock = ring.lock().unwrap();
|
||||
lock.extend(frame);
|
||||
while lock.len() > FRAME_SAMPLES * 16 {
|
||||
lock.pop_front();
|
||||
}
|
||||
}
|
||||
Err(mpsc::RecvTimeoutError::Timeout) => {}
|
||||
Err(mpsc::RecvTimeoutError::Disconnected) => break,
|
||||
}
|
||||
}
|
||||
})?;
|
||||
}
|
||||
|
||||
let err_cb = |e: cpal::StreamError| {
|
||||
warn!("output stream error: {e}");
|
||||
};
|
||||
|
||||
let stream = if use_f32 {
|
||||
let ring = ring.clone();
|
||||
let ring = ring_cb.clone();
|
||||
device.build_output_stream(
|
||||
&config,
|
||||
move |data: &mut [f32], _: &cpal::OutputCallbackInfo| {
|
||||
let mut lock = ring.lock().unwrap();
|
||||
for sample in data.iter_mut() {
|
||||
*sample = match lock.pop_front() {
|
||||
Some(s) => i16_to_f32(s),
|
||||
None => 0.0,
|
||||
};
|
||||
let mut tmp = [0i16; FRAME_SAMPLES];
|
||||
for chunk in data.chunks_mut(FRAME_SAMPLES) {
|
||||
let n = chunk.len();
|
||||
let read = ring.read(&mut tmp[..n]);
|
||||
for i in 0..read {
|
||||
chunk[i] = i16_to_f32(tmp[i]);
|
||||
}
|
||||
// Fill remainder with silence if ring underran
|
||||
for i in read..n {
|
||||
chunk[i] = 0.0;
|
||||
}
|
||||
}
|
||||
},
|
||||
err_cb,
|
||||
None,
|
||||
)?
|
||||
} else {
|
||||
let ring = ring.clone();
|
||||
let ring = ring_cb.clone();
|
||||
device.build_output_stream(
|
||||
&config,
|
||||
move |data: &mut [i16], _: &cpal::OutputCallbackInfo| {
|
||||
let mut lock = ring.lock().unwrap();
|
||||
for sample in data.iter_mut() {
|
||||
*sample = lock.pop_front().unwrap_or(0);
|
||||
let read = ring.read(data);
|
||||
// Fill remainder with silence if ring underran
|
||||
for sample in &mut data[read..] {
|
||||
*sample = 0;
|
||||
}
|
||||
},
|
||||
err_cb,
|
||||
@@ -257,7 +245,6 @@ impl AudioPlayback {
|
||||
|
||||
stream.play().context("failed to start output stream")?;
|
||||
|
||||
// Signal success to the caller before parking.
|
||||
let _ = init_tx.send(Ok(()));
|
||||
|
||||
// Keep stream alive until stopped.
|
||||
@@ -278,12 +265,12 @@ impl AudioPlayback {
|
||||
.map_err(|_| anyhow!("playback thread exited before signaling"))?
|
||||
.map_err(|e| anyhow!("{e}"))?;
|
||||
|
||||
Ok(Self { tx, running })
|
||||
Ok(Self { ring, running })
|
||||
}
|
||||
|
||||
/// Write a frame of PCM samples for playback.
|
||||
pub fn write_frame(&self, pcm: &[i16]) {
|
||||
let _ = self.tx.try_send(pcm.to_vec());
|
||||
/// Get a reference to the playout ring buffer for direct writing.
|
||||
pub fn ring(&self) -> &Arc<AudioRing> {
|
||||
&self.ring
|
||||
}
|
||||
|
||||
/// Stop playback.
|
||||
@@ -292,11 +279,16 @@ impl AudioPlayback {
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for AudioPlayback {
|
||||
fn drop(&mut self) {
|
||||
self.stop();
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Check if the input device supports i16 at 48 kHz mono.
|
||||
fn supports_i16_input(device: &cpal::Device) -> Result<bool, anyhow::Error> {
|
||||
let supported = device
|
||||
.supported_input_configs()
|
||||
@@ -313,7 +305,6 @@ fn supports_i16_input(device: &cpal::Device) -> Result<bool, anyhow::Error> {
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
/// Check if the output device supports i16 at 48 kHz mono.
|
||||
fn supports_i16_output(device: &cpal::Device) -> Result<bool, anyhow::Error> {
|
||||
let supported = device
|
||||
.supported_output_configs()
|
||||
|
||||
537
crates/wzp-client/src/audio_linux_aec.rs
Normal file
537
crates/wzp-client/src/audio_linux_aec.rs
Normal file
@@ -0,0 +1,537 @@
|
||||
//! Linux AEC backend: CPAL capture + playback wired through the WebRTC Audio
|
||||
//! Processing Module (AEC3 + noise suppression + high-pass filter).
|
||||
//!
|
||||
//! This is the same algorithm used by Chrome WebRTC, Zoom, Teams, Jitsi, and
|
||||
//! any other "serious" Linux VoIP app. It runs in-process — no dependency on
|
||||
//! PulseAudio's module-echo-cancel or PipeWire's filter-chain, so it works
|
||||
//! identically on ALSA / PulseAudio / PipeWire systems.
|
||||
//!
|
||||
//! ## Architecture
|
||||
//!
|
||||
//! A single module-level `Arc<Mutex<Processor>>` is shared between the
|
||||
//! capture and playback paths. On each 20 ms frame (960 samples @ 48 kHz
|
||||
//! mono):
|
||||
//!
|
||||
//! - **Playback path**: `LinuxAecPlayback::start` spawns the usual CPAL
|
||||
//! output thread, but wraps each chunk in a call to
|
||||
//! `Processor::process_render_frame` **before** handing it to CPAL. That
|
||||
//! gives APM an authoritative reference of exactly what's going out to
|
||||
//! the speakers (same approach Zoom/Teams/Jitsi use). The AEC then knows
|
||||
//! what to cancel when it sees echo in the capture stream.
|
||||
//!
|
||||
//! - **Capture path**: `LinuxAecCapture::start` spawns the usual CPAL
|
||||
//! input thread, and runs `Processor::process_capture_frame` on each
|
||||
//! incoming mic chunk **in place** before pushing it into the ring
|
||||
//! buffer. The AEC subtracts the echo using the render reference it
|
||||
//! saw on the playback side.
|
||||
//!
|
||||
//! APM is strict about frame size: it requires exactly 10 ms = 480 samples
|
||||
//! per call at 48 kHz. Our pipeline uses 20 ms = 960 samples, so each 20 ms
|
||||
//! frame is split into two 480-sample halves, APM is called twice, and the
|
||||
//! halves are stitched back together.
|
||||
//!
|
||||
//! APM only accepts f32 samples in `[-1.0, 1.0]`, so we convert i16 → f32
|
||||
//! before the call and f32 → i16 after (with clamping on the return path).
|
||||
//!
|
||||
//! ## Stream delay
|
||||
//!
|
||||
//! AEC needs to know roughly how long it takes between a sample being passed
|
||||
//! to `process_render_frame` and its echo showing up at `process_capture_frame`
|
||||
//! — i.e. the round trip through CPAL playback → speaker → air → microphone
|
||||
//! → CPAL capture. AEC3's internal estimator tracks this within a window
|
||||
//! around whatever hint we give it. We hardcode 60 ms as a reasonable
|
||||
//! starting point for typical Linux audio stacks; the delay estimator does
|
||||
//! the fine-tuning automatically.
|
||||
//!
|
||||
//! ## Thread safety
|
||||
//!
|
||||
//! The 0.3.x line of `webrtc-audio-processing` takes `&mut self` on both
|
||||
//! `process_capture_frame` and `process_render_frame`, so the `Processor`
|
||||
//! needs a `Mutex` around it for cross-thread sharing. The capture and
|
||||
//! playback threads each acquire the lock briefly (sub-millisecond per
|
||||
//! 10 ms frame) so contention is minimal at our frame rates.
|
||||
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::{Arc, Mutex, OnceLock};
|
||||
|
||||
use anyhow::{Context, anyhow};
|
||||
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
|
||||
use cpal::{SampleFormat, SampleRate, StreamConfig};
|
||||
use tracing::{info, warn};
|
||||
use webrtc_audio_processing::{
|
||||
Config, EchoCancellation, EchoCancellationSuppressionLevel, InitializationConfig,
|
||||
NUM_SAMPLES_PER_FRAME, NoiseSuppression, NoiseSuppressionLevel, Processor,
|
||||
};
|
||||
|
||||
use crate::audio_ring::AudioRing;
|
||||
|
||||
/// 20 ms at 48 kHz, mono — matches the rest of the pipeline and the codec.
|
||||
pub const FRAME_SAMPLES: usize = 960;
|
||||
/// APM requires strict 10 ms frames at 48 kHz = 480 samples per call.
|
||||
/// Imported from the webrtc-audio-processing crate so we can't drift out
|
||||
/// of sync with whatever sample rate / frame length the C++ lib is using.
|
||||
const APM_FRAME_SAMPLES: usize = NUM_SAMPLES_PER_FRAME as usize;
|
||||
const APM_NUM_CHANNELS: usize = 1;
|
||||
/// Round-trip delay hint passed to APM; the estimator refines from here.
|
||||
/// 60 ms is a reasonable default for CPAL on ALSA / PulseAudio / PipeWire.
|
||||
#[allow(dead_code)]
|
||||
const STREAM_DELAY_MS: i32 = 60;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Shared APM instance
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Module-level lazily-initialized APM. Shared between capture and playback
|
||||
/// so they operate on the same echo-cancellation state — the render frames
|
||||
/// pushed by playback are what the capture path subtracts from the mic input.
|
||||
/// Wrapped in a Mutex because the 0.3.x Processor takes `&mut self` on both
|
||||
/// process_capture_frame and process_render_frame.
|
||||
static PROCESSOR: OnceLock<Arc<Mutex<Processor>>> = OnceLock::new();
|
||||
|
||||
fn get_or_init_processor() -> anyhow::Result<Arc<Mutex<Processor>>> {
|
||||
if let Some(p) = PROCESSOR.get() {
|
||||
return Ok(p.clone());
|
||||
}
|
||||
let init_config = InitializationConfig {
|
||||
num_capture_channels: APM_NUM_CHANNELS as i32,
|
||||
num_render_channels: APM_NUM_CHANNELS as i32,
|
||||
..Default::default()
|
||||
};
|
||||
let mut processor =
|
||||
Processor::new(&init_config).map_err(|e| anyhow!("webrtc APM init failed: {e:?}"))?;
|
||||
|
||||
let config = Config {
|
||||
echo_cancellation: Some(EchoCancellation {
|
||||
suppression_level: EchoCancellationSuppressionLevel::High,
|
||||
stream_delay_ms: Some(STREAM_DELAY_MS),
|
||||
enable_delay_agnostic: true,
|
||||
enable_extended_filter: true,
|
||||
}),
|
||||
noise_suppression: Some(NoiseSuppression {
|
||||
suppression_level: NoiseSuppressionLevel::High,
|
||||
}),
|
||||
enable_high_pass_filter: true,
|
||||
// AGC left off for now — it can fight the Opus encoder's own gain
|
||||
// staging and the adaptive-quality controller. Add later if users
|
||||
// report low mic levels.
|
||||
..Default::default()
|
||||
};
|
||||
processor.set_config(config);
|
||||
|
||||
let arc = Arc::new(Mutex::new(processor));
|
||||
let _ = PROCESSOR.set(arc.clone());
|
||||
info!(
|
||||
stream_delay_ms = STREAM_DELAY_MS,
|
||||
"webrtc APM initialized (AEC High + NS High + HPF, AGC off)"
|
||||
);
|
||||
Ok(arc)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers: i16 ↔ f32 and APM frame processing
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[inline]
|
||||
fn i16_to_f32(s: i16) -> f32 {
|
||||
s as f32 / 32768.0
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn f32_to_i16(s: f32) -> i16 {
|
||||
(s.clamp(-1.0, 1.0) * 32767.0) as i16
|
||||
}
|
||||
|
||||
/// Feed a 20 ms (960-sample) playback frame to APM as the render reference.
|
||||
/// Splits into two 10 ms halves because APM is strict about frame size.
|
||||
/// Takes the Mutex-wrapped Processor and locks briefly around each call.
|
||||
fn push_render_frame_20ms(apm: &Mutex<Processor>, pcm: &[i16]) {
|
||||
debug_assert_eq!(pcm.len(), FRAME_SAMPLES);
|
||||
let mut buf = [0f32; APM_FRAME_SAMPLES];
|
||||
for half in pcm.chunks_exact(APM_FRAME_SAMPLES) {
|
||||
for (i, &s) in half.iter().enumerate() {
|
||||
buf[i] = i16_to_f32(s);
|
||||
}
|
||||
match apm.lock() {
|
||||
Ok(mut p) => {
|
||||
if let Err(e) = p.process_render_frame(&mut buf) {
|
||||
warn!("webrtc APM process_render_frame failed: {e:?}");
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
warn!("webrtc APM mutex poisoned in render path");
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Run a 20 ms (960-sample) capture frame through APM's echo cancellation
|
||||
/// in place. Splits into two 10 ms halves, runs APM on each, stitches
|
||||
/// results back into the caller's buffer. Briefly holds the Mutex once
|
||||
/// per 10 ms half.
|
||||
fn process_capture_frame_20ms(apm: &Mutex<Processor>, pcm: &mut [i16]) {
|
||||
debug_assert_eq!(pcm.len(), FRAME_SAMPLES);
|
||||
let mut buf = [0f32; APM_FRAME_SAMPLES];
|
||||
for half in pcm.chunks_exact_mut(APM_FRAME_SAMPLES) {
|
||||
for (i, &s) in half.iter().enumerate() {
|
||||
buf[i] = i16_to_f32(s);
|
||||
}
|
||||
match apm.lock() {
|
||||
Ok(mut p) => {
|
||||
if let Err(e) = p.process_capture_frame(&mut buf) {
|
||||
warn!("webrtc APM process_capture_frame failed: {e:?}");
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
warn!("webrtc APM mutex poisoned in capture path");
|
||||
return;
|
||||
}
|
||||
}
|
||||
for (i, d) in half.iter_mut().enumerate() {
|
||||
*d = f32_to_i16(buf[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// LinuxAecCapture — CPAL mic + WebRTC AEC capture-side processing
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Microphone capture with WebRTC AEC3 applied in place before the codec
|
||||
/// sees the samples. Mirrors the public API of `audio_io::AudioCapture` so
|
||||
/// downstream code doesn't change.
|
||||
pub struct LinuxAecCapture {
|
||||
ring: Arc<AudioRing>,
|
||||
running: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl LinuxAecCapture {
|
||||
pub fn start() -> Result<Self, anyhow::Error> {
|
||||
// Eagerly init the APM so the playback side can find it already
|
||||
// configured, and so init errors surface on the caller thread
|
||||
// instead of silently failing inside the capture thread.
|
||||
let apm = get_or_init_processor()?;
|
||||
|
||||
let ring = Arc::new(AudioRing::new());
|
||||
let running = Arc::new(AtomicBool::new(true));
|
||||
|
||||
let (init_tx, init_rx) = std::sync::mpsc::sync_channel::<Result<(), String>>(1);
|
||||
|
||||
let ring_cb = ring.clone();
|
||||
let running_clone = running.clone();
|
||||
let apm_capture = apm.clone();
|
||||
|
||||
std::thread::Builder::new()
|
||||
.name("wzp-audio-capture-linuxaec".into())
|
||||
.spawn(move || {
|
||||
let result = (|| -> Result<(), anyhow::Error> {
|
||||
let host = cpal::default_host();
|
||||
let device = host
|
||||
.default_input_device()
|
||||
.ok_or_else(|| anyhow!("no default input audio device found"))?;
|
||||
info!(device = %device.name().unwrap_or_default(), "LinuxAEC: using input device");
|
||||
|
||||
let config = StreamConfig {
|
||||
channels: 1,
|
||||
sample_rate: SampleRate(48_000),
|
||||
buffer_size: cpal::BufferSize::Default,
|
||||
};
|
||||
|
||||
let use_f32 = !supports_i16_input(&device)?;
|
||||
|
||||
let err_cb = |e: cpal::StreamError| {
|
||||
warn!("LinuxAEC input stream error: {e}");
|
||||
};
|
||||
|
||||
// Leftover buffer for when CPAL gives us partial frames.
|
||||
// We need exactly 960-sample chunks to feed APM.
|
||||
let leftover = std::sync::Mutex::new(Vec::<i16>::with_capacity(FRAME_SAMPLES * 4));
|
||||
|
||||
let stream = if use_f32 {
|
||||
let ring = ring_cb.clone();
|
||||
let running = running_clone.clone();
|
||||
let apm = apm_capture.clone();
|
||||
device.build_input_stream(
|
||||
&config,
|
||||
move |data: &[f32], _: &cpal::InputCallbackInfo| {
|
||||
if !running.load(Ordering::Relaxed) {
|
||||
return;
|
||||
}
|
||||
let mut lv = leftover.lock().unwrap();
|
||||
lv.reserve(data.len());
|
||||
for &s in data {
|
||||
lv.push(f32_to_i16(s));
|
||||
}
|
||||
drain_frames_through_apm(&mut lv, &apm, &ring);
|
||||
},
|
||||
err_cb,
|
||||
None,
|
||||
)?
|
||||
} else {
|
||||
let ring = ring_cb.clone();
|
||||
let running = running_clone.clone();
|
||||
let apm = apm_capture.clone();
|
||||
device.build_input_stream(
|
||||
&config,
|
||||
move |data: &[i16], _: &cpal::InputCallbackInfo| {
|
||||
if !running.load(Ordering::Relaxed) {
|
||||
return;
|
||||
}
|
||||
let mut lv = leftover.lock().unwrap();
|
||||
lv.extend_from_slice(data);
|
||||
drain_frames_through_apm(&mut lv, &apm, &ring);
|
||||
},
|
||||
err_cb,
|
||||
None,
|
||||
)?
|
||||
};
|
||||
|
||||
stream.play().context("failed to start LinuxAEC input stream")?;
|
||||
let _ = init_tx.send(Ok(()));
|
||||
info!("LinuxAEC capture started (AEC3 active)");
|
||||
|
||||
while running_clone.load(Ordering::Relaxed) {
|
||||
std::thread::park_timeout(std::time::Duration::from_millis(200));
|
||||
}
|
||||
drop(stream);
|
||||
Ok(())
|
||||
})();
|
||||
|
||||
if let Err(e) = result {
|
||||
let _ = init_tx.send(Err(e.to_string()));
|
||||
}
|
||||
})?;
|
||||
|
||||
init_rx
|
||||
.recv()
|
||||
.map_err(|_| anyhow!("LinuxAEC capture thread exited before signaling"))?
|
||||
.map_err(|e| anyhow!("{e}"))?;
|
||||
|
||||
Ok(Self { ring, running })
|
||||
}
|
||||
|
||||
pub fn ring(&self) -> &Arc<AudioRing> {
|
||||
&self.ring
|
||||
}
|
||||
|
||||
pub fn stop(&self) {
|
||||
self.running.store(false, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for LinuxAecCapture {
|
||||
fn drop(&mut self) {
|
||||
self.stop();
|
||||
}
|
||||
}
|
||||
|
||||
/// Pull whole 960-sample frames out of the leftover buffer, run them through
|
||||
/// APM's capture-side processing, and push to the ring. Leaves any partial
|
||||
/// sub-960 remainder in `leftover` for the next callback.
|
||||
fn drain_frames_through_apm(leftover: &mut Vec<i16>, apm: &Mutex<Processor>, ring: &AudioRing) {
|
||||
let mut frame = [0i16; FRAME_SAMPLES];
|
||||
while leftover.len() >= FRAME_SAMPLES {
|
||||
frame.copy_from_slice(&leftover[..FRAME_SAMPLES]);
|
||||
process_capture_frame_20ms(apm, &mut frame);
|
||||
ring.write(&frame);
|
||||
leftover.drain(..FRAME_SAMPLES);
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// LinuxAecPlayback — CPAL speaker output + WebRTC AEC render-side tee
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Speaker playback with a render-side tee: each frame written to CPAL is
|
||||
/// ALSO fed to APM via `process_render_frame` as the echo-cancellation
|
||||
/// reference signal. This is the "tee the playback ring" approach (Zoom,
|
||||
/// Teams, Jitsi) — deterministic, does not depend on PulseAudio loopback or
|
||||
/// PipeWire monitor sources.
|
||||
pub struct LinuxAecPlayback {
|
||||
ring: Arc<AudioRing>,
|
||||
running: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl LinuxAecPlayback {
|
||||
pub fn start() -> Result<Self, anyhow::Error> {
|
||||
let apm = get_or_init_processor()?;
|
||||
|
||||
let ring = Arc::new(AudioRing::new());
|
||||
let running = Arc::new(AtomicBool::new(true));
|
||||
|
||||
let (init_tx, init_rx) = std::sync::mpsc::sync_channel::<Result<(), String>>(1);
|
||||
|
||||
let ring_cb = ring.clone();
|
||||
let running_clone = running.clone();
|
||||
let apm_render = apm.clone();
|
||||
|
||||
std::thread::Builder::new()
|
||||
.name("wzp-audio-playback-linuxaec".into())
|
||||
.spawn(move || {
|
||||
let result = (|| -> Result<(), anyhow::Error> {
|
||||
let host = cpal::default_host();
|
||||
let device = host
|
||||
.default_output_device()
|
||||
.ok_or_else(|| anyhow!("no default output audio device found"))?;
|
||||
info!(device = %device.name().unwrap_or_default(), "LinuxAEC: using output device");
|
||||
|
||||
let config = StreamConfig {
|
||||
channels: 1,
|
||||
sample_rate: SampleRate(48_000),
|
||||
buffer_size: cpal::BufferSize::Default,
|
||||
};
|
||||
|
||||
let use_f32 = !supports_i16_output(&device)?;
|
||||
|
||||
let err_cb = |e: cpal::StreamError| {
|
||||
warn!("LinuxAEC output stream error: {e}");
|
||||
};
|
||||
|
||||
// Same 960-sample batching approach as the capture side:
|
||||
// CPAL may ask for N samples in a callback where N doesn't
|
||||
// divide 960. We accumulate partial frames in a Vec and
|
||||
// feed APM as soon as we have a whole 20 ms frame.
|
||||
let carry = std::sync::Mutex::new(Vec::<i16>::with_capacity(FRAME_SAMPLES * 4));
|
||||
|
||||
let stream = if use_f32 {
|
||||
let ring = ring_cb.clone();
|
||||
let apm = apm_render.clone();
|
||||
device.build_output_stream(
|
||||
&config,
|
||||
move |data: &mut [f32], _: &cpal::OutputCallbackInfo| {
|
||||
fill_output_and_tee_f32(data, &ring, &apm, &carry);
|
||||
},
|
||||
err_cb,
|
||||
None,
|
||||
)?
|
||||
} else {
|
||||
let ring = ring_cb.clone();
|
||||
let apm = apm_render.clone();
|
||||
device.build_output_stream(
|
||||
&config,
|
||||
move |data: &mut [i16], _: &cpal::OutputCallbackInfo| {
|
||||
fill_output_and_tee_i16(data, &ring, &apm, &carry);
|
||||
},
|
||||
err_cb,
|
||||
None,
|
||||
)?
|
||||
};
|
||||
|
||||
stream.play().context("failed to start LinuxAEC output stream")?;
|
||||
let _ = init_tx.send(Ok(()));
|
||||
info!("LinuxAEC playback started (render tee active)");
|
||||
|
||||
while running_clone.load(Ordering::Relaxed) {
|
||||
std::thread::park_timeout(std::time::Duration::from_millis(200));
|
||||
}
|
||||
drop(stream);
|
||||
Ok(())
|
||||
})();
|
||||
|
||||
if let Err(e) = result {
|
||||
let _ = init_tx.send(Err(e.to_string()));
|
||||
}
|
||||
})?;
|
||||
|
||||
init_rx
|
||||
.recv()
|
||||
.map_err(|_| anyhow!("LinuxAEC playback thread exited before signaling"))?
|
||||
.map_err(|e| anyhow!("{e}"))?;
|
||||
|
||||
Ok(Self { ring, running })
|
||||
}
|
||||
|
||||
pub fn ring(&self) -> &Arc<AudioRing> {
|
||||
&self.ring
|
||||
}
|
||||
|
||||
pub fn stop(&self) {
|
||||
self.running.store(false, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for LinuxAecPlayback {
|
||||
fn drop(&mut self) {
|
||||
self.stop();
|
||||
}
|
||||
}
|
||||
|
||||
fn fill_output_and_tee_i16(
|
||||
data: &mut [i16],
|
||||
ring: &AudioRing,
|
||||
apm: &Mutex<Processor>,
|
||||
carry: &std::sync::Mutex<Vec<i16>>,
|
||||
) {
|
||||
let read = ring.read(data);
|
||||
for s in &mut data[read..] {
|
||||
*s = 0;
|
||||
}
|
||||
tee_render_samples(data, apm, carry);
|
||||
}
|
||||
|
||||
fn fill_output_and_tee_f32(
|
||||
data: &mut [f32],
|
||||
ring: &AudioRing,
|
||||
apm: &Mutex<Processor>,
|
||||
carry: &std::sync::Mutex<Vec<i16>>,
|
||||
) {
|
||||
let mut tmp = vec![0i16; data.len()];
|
||||
let read = ring.read(&mut tmp);
|
||||
for s in &mut tmp[read..] {
|
||||
*s = 0;
|
||||
}
|
||||
for (d, &s) in data.iter_mut().zip(tmp.iter()) {
|
||||
*d = i16_to_f32(s);
|
||||
}
|
||||
tee_render_samples(&tmp, apm, carry);
|
||||
}
|
||||
|
||||
/// Push CPAL-bound samples into APM's render-side input for echo cancellation.
|
||||
/// Uses a carry buffer to batch into exact 960-sample (20 ms) frames.
|
||||
fn tee_render_samples(samples: &[i16], apm: &Mutex<Processor>, carry: &std::sync::Mutex<Vec<i16>>) {
|
||||
let mut lv = carry.lock().unwrap();
|
||||
lv.extend_from_slice(samples);
|
||||
while lv.len() >= FRAME_SAMPLES {
|
||||
let mut frame = [0i16; FRAME_SAMPLES];
|
||||
frame.copy_from_slice(&lv[..FRAME_SAMPLES]);
|
||||
push_render_frame_20ms(apm, &frame);
|
||||
lv.drain(..FRAME_SAMPLES);
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// CPAL format helpers (duplicated from audio_io.rs to keep the modules
|
||||
// independent — each backend file is a self-contained unit)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn supports_i16_input(device: &cpal::Device) -> Result<bool, anyhow::Error> {
|
||||
let supported = device
|
||||
.supported_input_configs()
|
||||
.context("failed to query input configs")?;
|
||||
for cfg in supported {
|
||||
if cfg.sample_format() == SampleFormat::I16
|
||||
&& cfg.min_sample_rate() <= SampleRate(48_000)
|
||||
&& cfg.max_sample_rate() >= SampleRate(48_000)
|
||||
&& cfg.channels() >= 1
|
||||
{
|
||||
return Ok(true);
|
||||
}
|
||||
}
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
fn supports_i16_output(device: &cpal::Device) -> Result<bool, anyhow::Error> {
|
||||
let supported = device
|
||||
.supported_output_configs()
|
||||
.context("failed to query output configs")?;
|
||||
for cfg in supported {
|
||||
if cfg.sample_format() == SampleFormat::I16
|
||||
&& cfg.min_sample_rate() <= SampleRate(48_000)
|
||||
&& cfg.max_sample_rate() >= SampleRate(48_000)
|
||||
&& cfg.channels() >= 1
|
||||
{
|
||||
return Ok(true);
|
||||
}
|
||||
}
|
||||
Ok(false)
|
||||
}
|
||||
122
crates/wzp-client/src/audio_ring.rs
Normal file
122
crates/wzp-client/src/audio_ring.rs
Normal file
@@ -0,0 +1,122 @@
|
||||
//! Lock-free SPSC ring buffer — "Reader-Detects-Lap" architecture.
|
||||
//!
|
||||
//! SPSC invariant: the producer ONLY writes `write_pos`, the consumer
|
||||
//! ONLY writes `read_pos`. Neither thread touches the other's cursor.
|
||||
//!
|
||||
//! On overflow (writer laps the reader), the writer simply overwrites
|
||||
//! old buffer data. The reader detects the lap via `available() >
|
||||
//! RING_CAPACITY` and snaps its own `read_pos` forward.
|
||||
//!
|
||||
//! Capacity is a power of 2 for bitmask indexing (no modulo).
|
||||
|
||||
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
|
||||
|
||||
/// Ring buffer capacity — power of 2 for bitmask indexing.
|
||||
/// 16384 samples = 341.3ms at 48kHz mono.
|
||||
const RING_CAPACITY: usize = 16384; // 2^14
|
||||
const RING_MASK: usize = RING_CAPACITY - 1;
|
||||
|
||||
/// Lock-free single-producer single-consumer ring buffer for i16 PCM samples.
|
||||
pub struct AudioRing {
|
||||
buf: Box<[i16]>,
|
||||
/// Monotonically increasing write cursor. ONLY written by producer.
|
||||
write_pos: AtomicUsize,
|
||||
/// Monotonically increasing read cursor. ONLY written by consumer.
|
||||
read_pos: AtomicUsize,
|
||||
/// Incremented by reader when it detects it was lapped (overflow).
|
||||
overflow_count: AtomicU64,
|
||||
/// Incremented by reader when ring is empty (underrun).
|
||||
underrun_count: AtomicU64,
|
||||
}
|
||||
|
||||
// SAFETY: AudioRing is SPSC — one thread writes (producer), one reads (consumer).
|
||||
// The producer only writes write_pos. The consumer only writes read_pos.
|
||||
// Neither thread writes the other's cursor. Buffer indices are derived from
|
||||
// the owning thread's cursor, ensuring no concurrent access to the same index.
|
||||
unsafe impl Send for AudioRing {}
|
||||
unsafe impl Sync for AudioRing {}
|
||||
|
||||
impl AudioRing {
|
||||
pub fn new() -> Self {
|
||||
debug_assert!(RING_CAPACITY.is_power_of_two());
|
||||
Self {
|
||||
buf: vec![0i16; RING_CAPACITY].into_boxed_slice(),
|
||||
write_pos: AtomicUsize::new(0),
|
||||
read_pos: AtomicUsize::new(0),
|
||||
overflow_count: AtomicU64::new(0),
|
||||
underrun_count: AtomicU64::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
/// Number of samples available to read (clamped to capacity).
|
||||
pub fn available(&self) -> usize {
|
||||
let w = self.write_pos.load(Ordering::Acquire);
|
||||
let r = self.read_pos.load(Ordering::Relaxed);
|
||||
w.wrapping_sub(r).min(RING_CAPACITY)
|
||||
}
|
||||
|
||||
/// Write samples into the ring. Returns number of samples written.
|
||||
///
|
||||
/// If the ring is full, old data is silently overwritten. The reader
|
||||
/// will detect the lap and self-correct. The writer NEVER touches
|
||||
/// `read_pos`.
|
||||
pub fn write(&self, samples: &[i16]) -> usize {
|
||||
let count = samples.len().min(RING_CAPACITY);
|
||||
let w = self.write_pos.load(Ordering::Relaxed);
|
||||
|
||||
for i in 0..count {
|
||||
unsafe {
|
||||
let ptr = self.buf.as_ptr() as *mut i16;
|
||||
*ptr.add((w + i) & RING_MASK) = samples[i];
|
||||
}
|
||||
}
|
||||
|
||||
self.write_pos
|
||||
.store(w.wrapping_add(count), Ordering::Release);
|
||||
count
|
||||
}
|
||||
|
||||
/// Read samples from the ring into `out`. Returns number of samples read.
|
||||
///
|
||||
/// If the writer has lapped the reader (overflow), `read_pos` is snapped
|
||||
/// forward to the oldest valid data.
|
||||
pub fn read(&self, out: &mut [i16]) -> usize {
|
||||
let w = self.write_pos.load(Ordering::Acquire);
|
||||
let mut r = self.read_pos.load(Ordering::Relaxed);
|
||||
|
||||
let mut avail = w.wrapping_sub(r);
|
||||
|
||||
// Lap detection: writer has overwritten our unread data.
|
||||
if avail > RING_CAPACITY {
|
||||
r = w.wrapping_sub(RING_CAPACITY);
|
||||
avail = RING_CAPACITY;
|
||||
self.overflow_count.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
let count = out.len().min(avail);
|
||||
if count == 0 {
|
||||
if w == r {
|
||||
self.underrun_count.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
for i in 0..count {
|
||||
out[i] = unsafe { *self.buf.as_ptr().add((r + i) & RING_MASK) };
|
||||
}
|
||||
|
||||
self.read_pos
|
||||
.store(r.wrapping_add(count), Ordering::Release);
|
||||
count
|
||||
}
|
||||
|
||||
/// Number of overflow events (reader was lapped by writer).
|
||||
pub fn overflow_count(&self) -> u64 {
|
||||
self.overflow_count.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Number of underrun events (reader found empty buffer).
|
||||
pub fn underrun_count(&self) -> u64 {
|
||||
self.underrun_count.load(Ordering::Relaxed)
|
||||
}
|
||||
}
|
||||
180
crates/wzp-client/src/audio_vpio.rs
Normal file
180
crates/wzp-client/src/audio_vpio.rs
Normal file
@@ -0,0 +1,180 @@
|
||||
//! macOS Voice Processing I/O — uses Apple's VoiceProcessingIO audio unit
|
||||
//! for hardware-accelerated echo cancellation, AGC, and noise suppression.
|
||||
//!
|
||||
//! VoiceProcessingIO is a combined input+output unit that knows what's going
|
||||
//! to the speaker, so it can cancel the echo from the mic signal internally.
|
||||
//! This is the same engine FaceTime and other Apple apps use.
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
|
||||
use anyhow::Context;
|
||||
use coreaudio::audio_unit::audio_format::LinearPcmFlags;
|
||||
use coreaudio::audio_unit::render_callback::{self, data};
|
||||
use coreaudio::audio_unit::{AudioUnit, Element, IOType, SampleFormat, Scope, StreamFormat};
|
||||
use coreaudio::sys;
|
||||
use tracing::info;
|
||||
|
||||
use crate::audio_ring::AudioRing;
|
||||
|
||||
/// Number of samples per 20 ms frame at 48 kHz mono.
|
||||
pub const FRAME_SAMPLES: usize = 960;
|
||||
|
||||
/// Combined capture + playback via macOS VoiceProcessingIO.
|
||||
///
|
||||
/// The OS handles AEC internally — no manual far-end feeding needed.
|
||||
pub struct VpioAudio {
|
||||
capture_ring: Arc<AudioRing>,
|
||||
playout_ring: Arc<AudioRing>,
|
||||
_audio_unit: AudioUnit,
|
||||
running: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl VpioAudio {
|
||||
/// Start VoiceProcessingIO with AEC enabled.
|
||||
pub fn start() -> Result<Self, anyhow::Error> {
|
||||
let capture_ring = Arc::new(AudioRing::new());
|
||||
let playout_ring = Arc::new(AudioRing::new());
|
||||
let running = Arc::new(AtomicBool::new(true));
|
||||
|
||||
let mut au = AudioUnit::new(IOType::VoiceProcessingIO)
|
||||
.context("failed to create VoiceProcessingIO audio unit")?;
|
||||
|
||||
// Must uninitialize before configuring properties.
|
||||
au.uninitialize()
|
||||
.context("failed to uninitialize VPIO for configuration")?;
|
||||
|
||||
// Enable input (mic) on Element::Input (bus 1).
|
||||
let enable: u32 = 1;
|
||||
au.set_property(
|
||||
sys::kAudioOutputUnitProperty_EnableIO,
|
||||
Scope::Input,
|
||||
Element::Input,
|
||||
Some(&enable),
|
||||
)
|
||||
.context("failed to enable VPIO input")?;
|
||||
|
||||
// Output (speaker) is enabled by default on VPIO, but be explicit.
|
||||
au.set_property(
|
||||
sys::kAudioOutputUnitProperty_EnableIO,
|
||||
Scope::Output,
|
||||
Element::Output,
|
||||
Some(&enable),
|
||||
)
|
||||
.context("failed to enable VPIO output")?;
|
||||
|
||||
// Configure stream format: 48kHz mono f32 non-interleaved
|
||||
let stream_format = StreamFormat {
|
||||
sample_rate: 48_000.0,
|
||||
sample_format: SampleFormat::F32,
|
||||
flags: LinearPcmFlags::IS_FLOAT
|
||||
| LinearPcmFlags::IS_PACKED
|
||||
| LinearPcmFlags::IS_NON_INTERLEAVED,
|
||||
channels: 1,
|
||||
};
|
||||
|
||||
let asbd = stream_format.to_asbd();
|
||||
|
||||
// Input: set format on Output scope of Input element
|
||||
// (= the format the AU delivers to us from the mic)
|
||||
au.set_property(
|
||||
sys::kAudioUnitProperty_StreamFormat,
|
||||
Scope::Output,
|
||||
Element::Input,
|
||||
Some(&asbd),
|
||||
)
|
||||
.context("failed to set input stream format")?;
|
||||
|
||||
// Output: set format on Input scope of Output element
|
||||
// (= the format we feed to the AU for the speaker)
|
||||
au.set_property(
|
||||
sys::kAudioUnitProperty_StreamFormat,
|
||||
Scope::Input,
|
||||
Element::Output,
|
||||
Some(&asbd),
|
||||
)
|
||||
.context("failed to set output stream format")?;
|
||||
|
||||
// Set up input callback (mic capture with AEC applied)
|
||||
let cap_ring = capture_ring.clone();
|
||||
let cap_running = running.clone();
|
||||
let logged = Arc::new(AtomicBool::new(false));
|
||||
au.set_input_callback(
|
||||
move |args: render_callback::Args<data::NonInterleaved<f32>>| {
|
||||
if !cap_running.load(Ordering::Relaxed) {
|
||||
return Ok(());
|
||||
}
|
||||
let mut buffers = args.data.channels();
|
||||
if let Some(ch) = buffers.next() {
|
||||
if !logged.swap(true, Ordering::Relaxed) {
|
||||
eprintln!("[vpio] capture callback: {} f32 samples", ch.len());
|
||||
}
|
||||
let mut tmp = [0i16; FRAME_SAMPLES];
|
||||
for chunk in ch.chunks(FRAME_SAMPLES) {
|
||||
let n = chunk.len();
|
||||
for i in 0..n {
|
||||
tmp[i] = (chunk[i].clamp(-1.0, 1.0) * i16::MAX as f32) as i16;
|
||||
}
|
||||
cap_ring.write(&tmp[..n]);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
},
|
||||
)
|
||||
.context("failed to set input callback")?;
|
||||
|
||||
// Set up output callback (speaker playback — AEC uses this as reference)
|
||||
let play_ring = playout_ring.clone();
|
||||
au.set_render_callback(
|
||||
move |mut args: render_callback::Args<data::NonInterleaved<f32>>| {
|
||||
let mut buffers = args.data.channels_mut();
|
||||
if let Some(ch) = buffers.next() {
|
||||
let mut tmp = [0i16; FRAME_SAMPLES];
|
||||
for chunk in ch.chunks_mut(FRAME_SAMPLES) {
|
||||
let n = chunk.len();
|
||||
let read = play_ring.read(&mut tmp[..n]);
|
||||
for i in 0..read {
|
||||
chunk[i] = tmp[i] as f32 / i16::MAX as f32;
|
||||
}
|
||||
for i in read..n {
|
||||
chunk[i] = 0.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
},
|
||||
)
|
||||
.context("failed to set render callback")?;
|
||||
|
||||
au.initialize()
|
||||
.context("failed to initialize VoiceProcessingIO")?;
|
||||
au.start().context("failed to start VoiceProcessingIO")?;
|
||||
|
||||
info!("VoiceProcessingIO started (OS-level AEC enabled)");
|
||||
|
||||
Ok(Self {
|
||||
capture_ring,
|
||||
playout_ring,
|
||||
_audio_unit: au,
|
||||
running,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn capture_ring(&self) -> &Arc<AudioRing> {
|
||||
&self.capture_ring
|
||||
}
|
||||
|
||||
pub fn playout_ring(&self) -> &Arc<AudioRing> {
|
||||
&self.playout_ring
|
||||
}
|
||||
|
||||
pub fn stop(&self) {
|
||||
self.running.store(false, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for VpioAudio {
|
||||
fn drop(&mut self) {
|
||||
self.stop();
|
||||
}
|
||||
}
|
||||
330
crates/wzp-client/src/audio_wasapi.rs
Normal file
330
crates/wzp-client/src/audio_wasapi.rs
Normal file
@@ -0,0 +1,330 @@
|
||||
//! Direct WASAPI microphone capture with Windows's OS-level AEC enabled.
|
||||
//!
|
||||
//! Bypasses CPAL and opens the default capture endpoint directly via
|
||||
//! `IMMDeviceEnumerator` + `IAudioClient2::SetClientProperties`, setting
|
||||
//! `AudioClientProperties.eCategory = AudioCategory_Communications`. That's
|
||||
//! the switch that tells Windows "this is a VoIP call" — the OS then
|
||||
//! enables its communications audio processing chain (AEC, noise
|
||||
//! suppression, automatic gain control) for the stream. AEC operates at
|
||||
//! the OS level using the currently-playing audio as the reference
|
||||
//! signal, so it cancels echo from our CPAL playback (and any other app's
|
||||
//! audio) without us having to plumb a reference signal ourselves.
|
||||
//!
|
||||
//! Platform: Windows only, compiled only when the `windows-aec` feature
|
||||
//! is enabled. Mirrors the public API of `audio_io::AudioCapture` so
|
||||
//! `wzp-client`'s lib.rs can transparently re-export either one as
|
||||
//! `AudioCapture`.
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
|
||||
use anyhow::{Context, anyhow};
|
||||
use tracing::{info, warn};
|
||||
use windows::Win32::Foundation::{BOOL, CloseHandle, WAIT_OBJECT_0};
|
||||
use windows::Win32::Media::Audio::{
|
||||
AUDCLNT_SHAREMODE_SHARED, AUDCLNT_STREAMFLAGS_AUTOCONVERTPCM,
|
||||
AUDCLNT_STREAMFLAGS_EVENTCALLBACK, AUDCLNT_STREAMFLAGS_SRC_DEFAULT_QUALITY,
|
||||
AudioCategory_Communications, AudioClientProperties, IAudioCaptureClient, IAudioClient,
|
||||
IAudioClient2, IMMDeviceEnumerator, MMDeviceEnumerator, WAVE_FORMAT_PCM, WAVEFORMATEX,
|
||||
eCapture, eCommunications,
|
||||
};
|
||||
use windows::Win32::System::Com::{
|
||||
CLSCTX_ALL, COINIT_MULTITHREADED, CoCreateInstance, CoInitializeEx, CoUninitialize,
|
||||
};
|
||||
use windows::Win32::System::Threading::{CreateEventW, INFINITE, WaitForSingleObject};
|
||||
use windows::core::{GUID, Interface};
|
||||
|
||||
use crate::audio_ring::AudioRing;
|
||||
|
||||
/// 20 ms at 48 kHz, mono. Matches the rest of the audio pipeline.
|
||||
pub const FRAME_SAMPLES: usize = 960;
|
||||
|
||||
/// Microphone capture via WASAPI with Windows's communications AEC enabled.
|
||||
///
|
||||
/// The WASAPI capture stream runs on a dedicated OS thread. This handle is
|
||||
/// `Send + Sync`. Dropping it stops the stream and joins the thread.
|
||||
pub struct WasapiAudioCapture {
|
||||
ring: Arc<AudioRing>,
|
||||
running: Arc<AtomicBool>,
|
||||
thread: Option<std::thread::JoinHandle<()>>,
|
||||
}
|
||||
|
||||
impl WasapiAudioCapture {
|
||||
/// Open the default communications microphone, enable OS AEC, and start
|
||||
/// streaming PCM into a lock-free ring buffer.
|
||||
///
|
||||
/// Returns only after the capture thread has successfully initialized
|
||||
/// the stream, or propagates the error back to the caller.
|
||||
pub fn start() -> Result<Self, anyhow::Error> {
|
||||
let ring = Arc::new(AudioRing::new());
|
||||
let running = Arc::new(AtomicBool::new(true));
|
||||
|
||||
let (init_tx, init_rx) = std::sync::mpsc::sync_channel::<Result<(), String>>(1);
|
||||
let ring_cb = ring.clone();
|
||||
let running_cb = running.clone();
|
||||
|
||||
let thread = std::thread::Builder::new()
|
||||
.name("wzp-audio-capture-wasapi".into())
|
||||
.spawn(move || {
|
||||
let result = unsafe { capture_thread_main(ring_cb, running_cb.clone(), &init_tx) };
|
||||
if let Err(e) = result {
|
||||
warn!("wasapi capture thread exited with error: {e}");
|
||||
// If we failed before signaling init, signal now so the
|
||||
// caller unblocks. Double-send is harmless (channel is
|
||||
// bounded to 1 and we only hit the second send path on
|
||||
// late errors).
|
||||
let _ = init_tx.send(Err(e.to_string()));
|
||||
}
|
||||
})
|
||||
.context("failed to spawn WASAPI capture thread")?;
|
||||
|
||||
init_rx
|
||||
.recv()
|
||||
.map_err(|_| anyhow!("WASAPI capture thread exited before signaling init"))?
|
||||
.map_err(|e| anyhow!("{e}"))?;
|
||||
|
||||
Ok(Self {
|
||||
ring,
|
||||
running,
|
||||
thread: Some(thread),
|
||||
})
|
||||
}
|
||||
|
||||
/// Get a reference to the capture ring buffer for direct polling.
|
||||
pub fn ring(&self) -> &Arc<AudioRing> {
|
||||
&self.ring
|
||||
}
|
||||
|
||||
/// Stop capturing.
|
||||
pub fn stop(&self) {
|
||||
self.running.store(false, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for WasapiAudioCapture {
|
||||
fn drop(&mut self) {
|
||||
self.stop();
|
||||
if let Some(handle) = self.thread.take() {
|
||||
// Join best-effort. The thread loop polls `running` every 200ms
|
||||
// via a short WaitForSingleObject timeout, so it should exit
|
||||
// within ~200ms of `stop()`.
|
||||
let _ = handle.join();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// WASAPI thread entry point — everything below this line runs on the
|
||||
// dedicated wzp-audio-capture-wasapi thread.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
unsafe fn capture_thread_main(
|
||||
ring: Arc<AudioRing>,
|
||||
running: Arc<AtomicBool>,
|
||||
init_tx: &std::sync::mpsc::SyncSender<Result<(), String>>,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
// COM init for the capture thread. MULTITHREADED because we're not
|
||||
// running a message pump. Must be balanced by CoUninitialize on exit.
|
||||
CoInitializeEx(None, COINIT_MULTITHREADED)
|
||||
.ok()
|
||||
.context("CoInitializeEx failed")?;
|
||||
|
||||
// Use a guard struct so CoUninitialize runs even on early returns.
|
||||
struct ComGuard;
|
||||
impl Drop for ComGuard {
|
||||
fn drop(&mut self) {
|
||||
unsafe { CoUninitialize() };
|
||||
}
|
||||
}
|
||||
let _com_guard = ComGuard;
|
||||
|
||||
let enumerator: IMMDeviceEnumerator = CoCreateInstance(&MMDeviceEnumerator, None, CLSCTX_ALL)
|
||||
.context("CoCreateInstance(MMDeviceEnumerator) failed")?;
|
||||
|
||||
// eCommunications role (not eConsole) — this picks the device the user
|
||||
// has designated for communications in Sound Settings. It's the one
|
||||
// Windows's AEC is actually tuned for and the one Teams/Zoom use.
|
||||
let device = enumerator
|
||||
.GetDefaultAudioEndpoint(eCapture, eCommunications)
|
||||
.context("GetDefaultAudioEndpoint(eCapture, eCommunications) failed")?;
|
||||
|
||||
if let Ok(name) = device_name(&device) {
|
||||
info!(device = %name, "opening WASAPI communications capture endpoint");
|
||||
}
|
||||
|
||||
let audio_client: IAudioClient = device
|
||||
.Activate(CLSCTX_ALL, None)
|
||||
.context("IMMDevice::Activate(IAudioClient) failed")?;
|
||||
|
||||
// IAudioClient2 exposes SetClientProperties, which is the ONLY way to
|
||||
// set AudioCategory_Communications pre-Initialize. Calling it on the
|
||||
// base IAudioClient would not compile, and setting it after Initialize
|
||||
// is a no-op.
|
||||
let audio_client2: IAudioClient2 = audio_client
|
||||
.cast()
|
||||
.context("QueryInterface IAudioClient2 failed")?;
|
||||
|
||||
let mut props = AudioClientProperties {
|
||||
cbSize: std::mem::size_of::<AudioClientProperties>() as u32,
|
||||
bIsOffload: BOOL(0),
|
||||
eCategory: AudioCategory_Communications,
|
||||
// 0 = AUDCLNT_STREAMOPTIONS_NONE. The `windows` crate doesn't
|
||||
// export the enum constant in all versions, so use 0 directly.
|
||||
Options: Default::default(),
|
||||
};
|
||||
audio_client2
|
||||
.SetClientProperties(&mut props as *mut _)
|
||||
.context("SetClientProperties(AudioCategory_Communications) failed")?;
|
||||
|
||||
// Request 48 kHz mono i16 directly. AUDCLNT_STREAMFLAGS_AUTOCONVERTPCM
|
||||
// tells Windows to do any needed format conversion inside the audio
|
||||
// engine rather than rejecting our format. SRC_DEFAULT_QUALITY picks
|
||||
// the standard Windows resampler quality (fine for voice).
|
||||
let wave_format = WAVEFORMATEX {
|
||||
wFormatTag: WAVE_FORMAT_PCM as u16,
|
||||
nChannels: 1,
|
||||
nSamplesPerSec: 48_000,
|
||||
nAvgBytesPerSec: 48_000 * 2, // 1 ch * 2 bytes/sample * 48000 Hz
|
||||
nBlockAlign: 2, // 1 ch * 2 bytes/sample
|
||||
wBitsPerSample: 16,
|
||||
cbSize: 0,
|
||||
};
|
||||
|
||||
// 1,000,000 hns = 100 ms buffer (hns = 100-nanosecond units). Windows
|
||||
// treats this as the minimum; the engine may give us a larger one.
|
||||
const BUFFER_DURATION_HNS: i64 = 1_000_000;
|
||||
|
||||
audio_client
|
||||
.Initialize(
|
||||
AUDCLNT_SHAREMODE_SHARED,
|
||||
AUDCLNT_STREAMFLAGS_EVENTCALLBACK
|
||||
| AUDCLNT_STREAMFLAGS_AUTOCONVERTPCM
|
||||
| AUDCLNT_STREAMFLAGS_SRC_DEFAULT_QUALITY,
|
||||
BUFFER_DURATION_HNS,
|
||||
0,
|
||||
&wave_format,
|
||||
Some(&GUID::zeroed()),
|
||||
)
|
||||
.context(
|
||||
"IAudioClient::Initialize failed — Windows rejected communications-mode 48k mono i16",
|
||||
)?;
|
||||
|
||||
// Event-driven capture: Windows signals this handle each time a new
|
||||
// audio packet is available. We wait on it from the loop below.
|
||||
let event = CreateEventW(None, false, false, None).context("CreateEventW failed")?;
|
||||
audio_client
|
||||
.SetEventHandle(event)
|
||||
.context("SetEventHandle failed")?;
|
||||
|
||||
let capture_client: IAudioCaptureClient = audio_client
|
||||
.GetService()
|
||||
.context("IAudioClient::GetService(IAudioCaptureClient) failed")?;
|
||||
|
||||
audio_client.Start().context("IAudioClient::Start failed")?;
|
||||
|
||||
// Signal to the parent thread that init succeeded before entering the
|
||||
// hot loop. From this point on, errors get logged but don't propagate
|
||||
// back to the caller (they'd just cause the ring buffer to stop
|
||||
// filling, which the main thread detects as underruns).
|
||||
let _ = init_tx.send(Ok(()));
|
||||
info!("WASAPI communications-mode capture started with OS AEC enabled");
|
||||
|
||||
let mut logged_first_packet = false;
|
||||
|
||||
// Main capture loop. Exit when `running` goes false (from Drop or an
|
||||
// explicit stop() call).
|
||||
while running.load(Ordering::Relaxed) {
|
||||
// 200 ms timeout so we check `running` regularly even if the audio
|
||||
// engine stops delivering packets (e.g. device unplugged).
|
||||
let wait = WaitForSingleObject(event, 200);
|
||||
if wait.0 != WAIT_OBJECT_0.0 {
|
||||
// Timeout or failure — just loop and re-check running.
|
||||
continue;
|
||||
}
|
||||
|
||||
// Drain all available packets. Windows may have queued more than
|
||||
// one since we were last scheduled.
|
||||
loop {
|
||||
let packet_length = match capture_client.GetNextPacketSize() {
|
||||
Ok(n) => n,
|
||||
Err(e) => {
|
||||
warn!("GetNextPacketSize failed: {e}");
|
||||
break;
|
||||
}
|
||||
};
|
||||
if packet_length == 0 {
|
||||
break;
|
||||
}
|
||||
|
||||
let mut buffer_ptr: *mut u8 = std::ptr::null_mut();
|
||||
let mut num_frames: u32 = 0;
|
||||
let mut flags: u32 = 0;
|
||||
let mut device_position: u64 = 0;
|
||||
let mut qpc_position: u64 = 0;
|
||||
|
||||
if let Err(e) = capture_client.GetBuffer(
|
||||
&mut buffer_ptr,
|
||||
&mut num_frames,
|
||||
&mut flags,
|
||||
Some(&mut device_position),
|
||||
Some(&mut qpc_position),
|
||||
) {
|
||||
warn!("GetBuffer failed: {e}");
|
||||
break;
|
||||
}
|
||||
|
||||
if num_frames > 0 && !buffer_ptr.is_null() {
|
||||
if !logged_first_packet {
|
||||
info!(
|
||||
frames = num_frames,
|
||||
flags, "WASAPI capture: first packet received"
|
||||
);
|
||||
logged_first_packet = true;
|
||||
}
|
||||
|
||||
// Because we asked for 48 kHz mono i16, each frame is
|
||||
// exactly one i16. Windows's AUTOCONVERTPCM handles the
|
||||
// conversion from whatever the engine mix format is.
|
||||
let samples =
|
||||
std::slice::from_raw_parts(buffer_ptr as *const i16, num_frames as usize);
|
||||
ring.write(samples);
|
||||
}
|
||||
|
||||
if let Err(e) = capture_client.ReleaseBuffer(num_frames) {
|
||||
warn!("ReleaseBuffer failed: {e}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info!("WASAPI capture thread stopping");
|
||||
let _ = audio_client.Stop();
|
||||
let _ = CloseHandle(event);
|
||||
// _com_guard drops here, calling CoUninitialize.
|
||||
|
||||
// Silence INFINITE unused-import warning — it's referenced by the
|
||||
// `windows` crate's WaitForSingleObject alternative but we use the
|
||||
// 200 ms timeout variant instead. Explicit suppression for clarity.
|
||||
let _ = INFINITE;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Best-effort device ID string for logging. Grabbing the friendly name via
|
||||
/// PKEY_Device_FriendlyName requires IPropertyStore + PROPVARIANT plumbing
|
||||
/// that's far more ceremony than a log line justifies; the ID is already
|
||||
/// sufficient to confirm we opened the right endpoint.
|
||||
///
|
||||
/// Rust 2024 edition's `unsafe_op_in_unsafe_fn` lint requires explicit
|
||||
/// `unsafe { ... }` blocks inside `unsafe fn` bodies for each unsafe call,
|
||||
/// even though the whole function is already marked unsafe.
|
||||
unsafe fn device_name(
|
||||
device: &windows::Win32::Media::Audio::IMMDevice,
|
||||
) -> Result<String, anyhow::Error> {
|
||||
let id = unsafe { device.GetId() }.context("IMMDevice::GetId failed")?;
|
||||
Ok(unsafe { id.to_string() }.unwrap_or_else(|_| "<non-utf16>".to_string()))
|
||||
}
|
||||
@@ -6,8 +6,8 @@ use std::time::{Duration, Instant};
|
||||
|
||||
use wzp_crypto::ChaChaSession;
|
||||
use wzp_fec::{RaptorQFecDecoder, RaptorQFecEncoder};
|
||||
use wzp_proto::traits::{CryptoSession, FecDecoder, FecEncoder};
|
||||
use wzp_proto::QualityProfile;
|
||||
use wzp_proto::traits::{CryptoSession, FecDecoder, FecEncoder};
|
||||
|
||||
use crate::call::{CallConfig, CallDecoder, CallEncoder};
|
||||
|
||||
@@ -170,7 +170,7 @@ pub fn bench_fec_recovery(loss_pct: f32) -> FecResult {
|
||||
|
||||
// Collect all symbols: source + repair
|
||||
struct Symbol {
|
||||
index: u8,
|
||||
index: u16,
|
||||
is_repair: bool,
|
||||
data: Vec<u8>,
|
||||
}
|
||||
@@ -180,7 +180,7 @@ pub fn bench_fec_recovery(loss_pct: f32) -> FecResult {
|
||||
// For add_symbol we need to provide the raw data; the decoder pads internally
|
||||
total_source_bytes += sym.len();
|
||||
all_symbols.push(Symbol {
|
||||
index: i as u8,
|
||||
index: i as u16,
|
||||
is_repair: false,
|
||||
data: sym.clone(),
|
||||
});
|
||||
@@ -201,9 +201,13 @@ pub fn bench_fec_recovery(loss_pct: f32) -> FecResult {
|
||||
// Deterministic shuffle for reproducibility using a simple seed
|
||||
// We use a basic Fisher-Yates with a fixed-per-block seed
|
||||
let mut indices: Vec<usize> = (0..all_symbols.len()).collect();
|
||||
let mut seed = (block_idx as u64).wrapping_mul(6364136223846793005).wrapping_add(1);
|
||||
let mut seed = (block_idx as u64)
|
||||
.wrapping_mul(6364136223846793005)
|
||||
.wrapping_add(1);
|
||||
for i in (1..indices.len()).rev() {
|
||||
seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
|
||||
seed = seed
|
||||
.wrapping_mul(6364136223846793005)
|
||||
.wrapping_add(1442695040888963407);
|
||||
let j = (seed >> 33) as usize % (i + 1);
|
||||
indices.swap(i, j);
|
||||
}
|
||||
@@ -259,17 +263,36 @@ pub fn bench_encrypt_decrypt() -> CryptoResult {
|
||||
})
|
||||
.collect();
|
||||
|
||||
let header = b"bench-header";
|
||||
// Build valid v2 MediaHeader bytes — encrypt/decrypt now derive nonces from
|
||||
// header.seq and require a parseable MediaHeader (WIRE_SIZE bytes minimum).
|
||||
use wzp_proto::packet::MediaHeader;
|
||||
use wzp_proto::{CodecId, MediaType};
|
||||
let mut total_bytes: usize = 0;
|
||||
|
||||
let start = Instant::now();
|
||||
for payload in &payloads {
|
||||
for (i, payload) in payloads.iter().enumerate() {
|
||||
let hdr = MediaHeader {
|
||||
version: 2,
|
||||
flags: 0,
|
||||
media_type: MediaType::Audio,
|
||||
codec_id: CodecId::Opus24k,
|
||||
stream_id: 0,
|
||||
fec_ratio: 0,
|
||||
seq: i as u32,
|
||||
timestamp: (i as u32).wrapping_mul(20),
|
||||
fec_block: 0,
|
||||
};
|
||||
let mut header_bytes = Vec::with_capacity(MediaHeader::WIRE_SIZE);
|
||||
hdr.write_to(&mut header_bytes);
|
||||
|
||||
let mut ciphertext = Vec::with_capacity(payload.len() + 16);
|
||||
encryptor.encrypt(header, payload, &mut ciphertext).unwrap();
|
||||
encryptor
|
||||
.encrypt(&header_bytes, payload, &mut ciphertext)
|
||||
.unwrap();
|
||||
|
||||
let mut plaintext = Vec::with_capacity(payload.len());
|
||||
decryptor
|
||||
.decrypt(header, &ciphertext, &mut plaintext)
|
||||
.decrypt(&header_bytes, &ciphertext, &mut plaintext)
|
||||
.unwrap();
|
||||
|
||||
total_bytes += payload.len();
|
||||
|
||||
@@ -24,8 +24,14 @@ fn run_codec() {
|
||||
print_header("Codec Roundtrip (Opus 24kbps)");
|
||||
let r = bench::bench_codec_roundtrip();
|
||||
print_row("Frames", &format!("{}", r.frames));
|
||||
print_row("Encode total", &format!("{:.2} ms", r.total_encode.as_secs_f64() * 1000.0));
|
||||
print_row("Decode total", &format!("{:.2} ms", r.total_decode.as_secs_f64() * 1000.0));
|
||||
print_row(
|
||||
"Encode total",
|
||||
&format!("{:.2} ms", r.total_encode.as_secs_f64() * 1000.0),
|
||||
);
|
||||
print_row(
|
||||
"Decode total",
|
||||
&format!("{:.2} ms", r.total_decode.as_secs_f64() * 1000.0),
|
||||
);
|
||||
print_row("Avg encode", &format!("{:.1} us", r.avg_encode_us));
|
||||
print_row("Avg decode", &format!("{:.1} us", r.avg_decode_us));
|
||||
print_row("Throughput", &format!("{:.0} frames/sec", r.frames_per_sec));
|
||||
@@ -41,7 +47,10 @@ fn run_fec(loss_pct: f32) {
|
||||
print_row("Recovery rate", &format!("{:.1}%", r.recovery_rate_pct));
|
||||
print_row("Source bytes", &format!("{}", r.total_source_bytes));
|
||||
print_row("Repair (overhead) bytes", &format!("{}", r.overhead_bytes));
|
||||
print_row("Total time", &format!("{:.2} ms", r.total_time.as_secs_f64() * 1000.0));
|
||||
print_row(
|
||||
"Total time",
|
||||
&format!("{:.2} ms", r.total_time.as_secs_f64() * 1000.0),
|
||||
);
|
||||
print_footer();
|
||||
}
|
||||
|
||||
@@ -49,7 +58,10 @@ fn run_crypto() {
|
||||
print_header("Crypto (ChaCha20-Poly1305)");
|
||||
let r = bench::bench_encrypt_decrypt();
|
||||
print_row("Packets", &format!("{}", r.packets));
|
||||
print_row("Total time", &format!("{:.2} ms", r.total_time.as_secs_f64() * 1000.0));
|
||||
print_row(
|
||||
"Total time",
|
||||
&format!("{:.2} ms", r.total_time.as_secs_f64() * 1000.0),
|
||||
);
|
||||
print_row("Throughput", &format!("{:.0} pkt/sec", r.packets_per_sec));
|
||||
print_row("Bandwidth", &format!("{:.2} MB/sec", r.megabytes_per_sec));
|
||||
print_row("Avg latency", &format!("{:.2} us", r.avg_latency_us));
|
||||
@@ -60,9 +72,18 @@ fn run_pipeline() {
|
||||
print_header("Full Pipeline (E2E)");
|
||||
let r = bench::bench_full_pipeline();
|
||||
print_row("Frames", &format!("{}", r.frames));
|
||||
print_row("Encode pipeline", &format!("{:.2} ms", r.total_encode_pipeline.as_secs_f64() * 1000.0));
|
||||
print_row("Decode pipeline", &format!("{:.2} ms", r.total_decode_pipeline.as_secs_f64() * 1000.0));
|
||||
print_row("Avg E2E latency", &format!("{:.1} us/frame", r.avg_e2e_latency_us));
|
||||
print_row(
|
||||
"Encode pipeline",
|
||||
&format!("{:.2} ms", r.total_encode_pipeline.as_secs_f64() * 1000.0),
|
||||
);
|
||||
print_row(
|
||||
"Decode pipeline",
|
||||
&format!("{:.2} ms", r.total_decode_pipeline.as_secs_f64() * 1000.0),
|
||||
);
|
||||
print_row(
|
||||
"Avg E2E latency",
|
||||
&format!("{:.1} us/frame", r.avg_e2e_latency_us),
|
||||
);
|
||||
print_row("PCM in", &format!("{} bytes", r.pcm_bytes_in));
|
||||
print_row("Wire out", &format!("{} bytes", r.wire_bytes_out));
|
||||
print_row("Overhead ratio", &format!("{:.3}x", r.overhead_ratio));
|
||||
|
||||
347
crates/wzp-client/src/birthday.rs
Normal file
347
crates/wzp-client/src/birthday.rs
Normal file
@@ -0,0 +1,347 @@
|
||||
//! Birthday attack for hard NAT traversal.
|
||||
//!
|
||||
//! When both peers are behind symmetric NATs with random port
|
||||
//! allocation, standard hole-punching fails because neither side
|
||||
//! can predict the other's external port. This module implements
|
||||
//! the birthday-paradox approach:
|
||||
//!
|
||||
//! 1. **Acceptor** opens N sockets, STUN-probes each to learn
|
||||
//! their external ports, reports them to the Dialer.
|
||||
//! 2. **Dialer** sprays QUIC connect attempts to the Acceptor's
|
||||
//! reported ports + random ports on the Acceptor's IP.
|
||||
//! 3. Birthday paradox: with N=64 ports and M=256 probes across
|
||||
//! 65536 ports, collision probability is high.
|
||||
//!
|
||||
//! In practice, the Acceptor's STUN-probed ports are known
|
||||
//! exactly (not random), so the Dialer targets them first —
|
||||
//! making this more like "spray-and-pray with a hit list" than
|
||||
//! a pure birthday attack.
|
||||
|
||||
use std::net::{Ipv4Addr, SocketAddr};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use crate::stun;
|
||||
|
||||
/// Configuration for the birthday attack.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BirthdayConfig {
|
||||
/// Number of sockets the Acceptor opens (default: 32).
|
||||
/// Each socket gets STUN-probed to learn its external port.
|
||||
/// More = higher chance of collision, but more resource usage.
|
||||
pub acceptor_ports: u16,
|
||||
/// Number of QUIC connect attempts the Dialer makes (default: 128).
|
||||
/// Spread across the Acceptor's known ports + random ports.
|
||||
pub dialer_probes: u16,
|
||||
/// Rate limit: ms between consecutive probes (default: 20ms = 50/s).
|
||||
pub probe_interval_ms: u16,
|
||||
/// Overall timeout for the birthday attack phase.
|
||||
pub timeout: Duration,
|
||||
/// STUN config for probing external ports.
|
||||
pub stun_config: stun::StunConfig,
|
||||
}
|
||||
|
||||
impl Default for BirthdayConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
acceptor_ports: 32,
|
||||
dialer_probes: 128,
|
||||
probe_interval_ms: 20,
|
||||
timeout: Duration::from_secs(8),
|
||||
stun_config: stun::StunConfig {
|
||||
servers: vec!["stun.l.google.com:19302".into()],
|
||||
timeout: Duration::from_secs(2),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of the Acceptor's port-opening phase.
|
||||
#[derive(Debug, Clone, serde::Serialize)]
|
||||
pub struct AcceptorPorts {
|
||||
/// External IP (from STUN).
|
||||
pub external_ip: Option<Ipv4Addr>,
|
||||
/// List of (local_port, external_port) for each opened socket.
|
||||
pub ports: Vec<PortMapping>,
|
||||
/// How many sockets we attempted to open.
|
||||
pub attempted: u16,
|
||||
/// How many STUN probes succeeded.
|
||||
pub succeeded: u16,
|
||||
}
|
||||
|
||||
/// A single socket's local↔external port mapping.
|
||||
#[derive(Debug, Clone, serde::Serialize)]
|
||||
pub struct PortMapping {
|
||||
pub local_port: u16,
|
||||
pub external_port: u16,
|
||||
}
|
||||
|
||||
/// Open N sockets and STUN-probe each to discover external ports.
|
||||
///
|
||||
/// Returns the set of known external ports that the Dialer should
|
||||
/// target. Each socket stays open (bound) so the NAT mapping
|
||||
/// remains active until the returned `PortGuard` is dropped.
|
||||
///
|
||||
/// The sockets are returned so the caller can keep them alive
|
||||
/// during the attack. Dropping them closes the NAT pinholes.
|
||||
pub async fn open_acceptor_ports(
|
||||
config: &BirthdayConfig,
|
||||
) -> (AcceptorPorts, Vec<tokio::net::UdpSocket>) {
|
||||
let mut sockets = Vec::new();
|
||||
let mut mappings = Vec::new();
|
||||
let mut external_ip: Option<Ipv4Addr> = None;
|
||||
let mut succeeded: u16 = 0;
|
||||
|
||||
let stun_server = match config.stun_config.servers.first() {
|
||||
Some(s) => match stun::resolve_stun_server(s).await {
|
||||
Ok(a) => Some(a),
|
||||
Err(_) => None,
|
||||
},
|
||||
None => None,
|
||||
};
|
||||
|
||||
for _ in 0..config.acceptor_ports {
|
||||
// Bind to random port
|
||||
let sock = match tokio::net::UdpSocket::bind("0.0.0.0:0").await {
|
||||
Ok(s) => s,
|
||||
Err(_) => continue,
|
||||
};
|
||||
let local_port = match sock.local_addr() {
|
||||
Ok(a) => a.port(),
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
// STUN probe to learn external port
|
||||
if let Some(stun_addr) = stun_server {
|
||||
match stun::stun_reflect(&sock, stun_addr, config.stun_config.timeout).await {
|
||||
Ok(ext_addr) => {
|
||||
if external_ip.is_none() {
|
||||
if let std::net::IpAddr::V4(ip) = ext_addr.ip() {
|
||||
external_ip = Some(ip);
|
||||
}
|
||||
}
|
||||
mappings.push(PortMapping {
|
||||
local_port,
|
||||
external_port: ext_addr.port(),
|
||||
});
|
||||
succeeded += 1;
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::debug!(local_port, error = %e, "birthday: STUN probe failed for socket");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sockets.push(sock);
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
attempted = config.acceptor_ports,
|
||||
succeeded,
|
||||
external_ip = ?external_ip,
|
||||
"birthday: acceptor ports opened"
|
||||
);
|
||||
|
||||
let result = AcceptorPorts {
|
||||
external_ip,
|
||||
ports: mappings,
|
||||
attempted: config.acceptor_ports,
|
||||
succeeded,
|
||||
};
|
||||
|
||||
(result, sockets)
|
||||
}
|
||||
|
||||
/// Generate the list of target addresses for the Dialer to spray.
|
||||
///
|
||||
/// Priority order:
|
||||
/// 1. Acceptor's known external ports (from STUN probes) — highest hit rate
|
||||
/// 2. Random ports on the Acceptor's IP — birthday paradox fill
|
||||
pub fn generate_dialer_targets(
|
||||
acceptor_ip: Ipv4Addr,
|
||||
known_ports: &[u16],
|
||||
total_probes: u16,
|
||||
) -> Vec<SocketAddr> {
|
||||
let mut targets = Vec::with_capacity(total_probes as usize);
|
||||
|
||||
// First: all known ports (guaranteed targets)
|
||||
for &port in known_ports {
|
||||
targets.push(SocketAddr::new(std::net::IpAddr::V4(acceptor_ip), port));
|
||||
}
|
||||
|
||||
// Fill remaining with random ports (birthday attack)
|
||||
let remaining = total_probes.saturating_sub(known_ports.len() as u16);
|
||||
if remaining > 0 {
|
||||
use rand::Rng;
|
||||
let mut rng = rand::thread_rng();
|
||||
for _ in 0..remaining {
|
||||
let port = rng.gen_range(1024..=65535u16);
|
||||
let addr = SocketAddr::new(std::net::IpAddr::V4(acceptor_ip), port);
|
||||
if !targets.contains(&addr) {
|
||||
targets.push(addr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
targets
|
||||
}
|
||||
|
||||
/// Run the Dialer side of the birthday attack.
|
||||
///
|
||||
/// Sprays QUIC connection attempts at the target addresses.
|
||||
/// Returns the first successful connection, or None on timeout.
|
||||
pub async fn spray_dialer(
|
||||
endpoint: &wzp_transport::Endpoint,
|
||||
targets: &[SocketAddr],
|
||||
call_sni: &str,
|
||||
probe_interval: Duration,
|
||||
timeout: Duration,
|
||||
) -> Option<wzp_transport::QuinnTransport> {
|
||||
let start = Instant::now();
|
||||
let mut set = tokio::task::JoinSet::new();
|
||||
|
||||
tracing::info!(
|
||||
target_count = targets.len(),
|
||||
interval_ms = probe_interval.as_millis(),
|
||||
timeout_s = timeout.as_secs(),
|
||||
"birthday: dialer starting spray"
|
||||
);
|
||||
|
||||
// Spray connects with rate limiting
|
||||
for (idx, &target) in targets.iter().enumerate() {
|
||||
if start.elapsed() >= timeout {
|
||||
break;
|
||||
}
|
||||
|
||||
let ep = endpoint.clone();
|
||||
let sni = call_sni.to_string();
|
||||
let client_cfg = wzp_transport::client_config();
|
||||
set.spawn(async move {
|
||||
let result = wzp_transport::connect(&ep, target, &sni, client_cfg).await;
|
||||
(idx, target, result)
|
||||
});
|
||||
|
||||
// Rate limit — don't blast the NAT
|
||||
if idx < targets.len() - 1 {
|
||||
tokio::time::sleep(probe_interval).await;
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
spawned = set.len(),
|
||||
elapsed_ms = start.elapsed().as_millis(),
|
||||
"birthday: all probes spawned, waiting for first success"
|
||||
);
|
||||
|
||||
// Wait for first success or all failures
|
||||
let deadline = start + timeout;
|
||||
while let Some(join_res) = tokio::select! {
|
||||
r = set.join_next() => r,
|
||||
_ = tokio::time::sleep_until(tokio::time::Instant::from_std(deadline)) => None,
|
||||
} {
|
||||
match join_res {
|
||||
Ok((idx, target, Ok(conn))) => {
|
||||
tracing::info!(
|
||||
idx,
|
||||
%target,
|
||||
remote = %conn.remote_address(),
|
||||
elapsed_ms = start.elapsed().as_millis(),
|
||||
"birthday: HIT! QUIC handshake succeeded"
|
||||
);
|
||||
set.abort_all();
|
||||
return Some(wzp_transport::QuinnTransport::new(conn));
|
||||
}
|
||||
Ok((idx, target, Err(e))) => {
|
||||
tracing::debug!(
|
||||
idx,
|
||||
%target,
|
||||
error = %e,
|
||||
"birthday: probe failed"
|
||||
);
|
||||
}
|
||||
Err(_) => {}
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
elapsed_ms = start.elapsed().as_millis(),
|
||||
"birthday: all probes failed or timed out"
|
||||
);
|
||||
None
|
||||
}
|
||||
|
||||
// ── Tests ──────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn generate_targets_known_ports_first() {
|
||||
let ip = Ipv4Addr::new(203, 0, 113, 5);
|
||||
let known = vec![10000, 10001, 10002];
|
||||
let targets = generate_dialer_targets(ip, &known, 10);
|
||||
|
||||
// Known ports should be first
|
||||
assert_eq!(targets[0].port(), 10000);
|
||||
assert_eq!(targets[1].port(), 10001);
|
||||
assert_eq!(targets[2].port(), 10002);
|
||||
// Rest are random
|
||||
assert!(targets.len() <= 10);
|
||||
// All target the right IP
|
||||
assert!(targets.iter().all(|a| a.ip() == std::net::IpAddr::V4(ip)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generate_targets_no_known_all_random() {
|
||||
let ip = Ipv4Addr::new(10, 0, 0, 1);
|
||||
let targets = generate_dialer_targets(ip, &[], 50);
|
||||
assert!(!targets.is_empty());
|
||||
assert!(targets.len() <= 50);
|
||||
// All ports in valid range
|
||||
assert!(targets.iter().all(|a| a.port() >= 1024));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generate_targets_more_known_than_total() {
|
||||
let ip = Ipv4Addr::new(10, 0, 0, 1);
|
||||
let known: Vec<u16> = (10000..10100).collect();
|
||||
let targets = generate_dialer_targets(ip, &known, 50);
|
||||
// All 100 known ports included even though total=50
|
||||
assert_eq!(targets.len(), 100);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generate_targets_dedup() {
|
||||
let ip = Ipv4Addr::new(10, 0, 0, 1);
|
||||
let targets = generate_dialer_targets(ip, &[], 100);
|
||||
// No duplicates
|
||||
let mut sorted = targets.clone();
|
||||
sorted.sort();
|
||||
sorted.dedup();
|
||||
assert_eq!(sorted.len(), targets.len());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_config() {
|
||||
let cfg = BirthdayConfig::default();
|
||||
assert_eq!(cfg.acceptor_ports, 32);
|
||||
assert_eq!(cfg.dialer_probes, 128);
|
||||
assert!(cfg.timeout.as_secs() > 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn acceptor_ports_serializes() {
|
||||
let result = AcceptorPorts {
|
||||
external_ip: Some(Ipv4Addr::new(203, 0, 113, 5)),
|
||||
ports: vec![PortMapping {
|
||||
local_port: 12345,
|
||||
external_port: 54321,
|
||||
}],
|
||||
attempted: 32,
|
||||
succeeded: 1,
|
||||
};
|
||||
let json = serde_json::to_string(&result).unwrap();
|
||||
assert!(json.contains("54321"));
|
||||
assert!(json.contains("203.0.113.5"));
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -17,7 +17,7 @@ use std::sync::Arc;
|
||||
use tracing::{error, info};
|
||||
|
||||
use wzp_client::call::{CallConfig, CallDecoder, CallEncoder};
|
||||
use wzp_proto::MediaTransport;
|
||||
use wzp_proto::{MediaTransport, default_signal_version};
|
||||
|
||||
const FRAME_SAMPLES: usize = 960; // 20ms @ 48kHz
|
||||
|
||||
@@ -47,6 +47,13 @@ struct CliArgs {
|
||||
room: Option<String>,
|
||||
token: Option<String>,
|
||||
_metrics_file: Option<String>,
|
||||
version_check: bool,
|
||||
/// Connect to relay for persistent signaling (direct calls).
|
||||
signal: bool,
|
||||
/// Place a direct call to a fingerprint (requires --signal).
|
||||
call_target: Option<String>,
|
||||
/// Run network diagnostic (STUN, port mapping, relay latencies).
|
||||
netcheck: bool,
|
||||
}
|
||||
|
||||
impl CliArgs {
|
||||
@@ -88,12 +95,25 @@ fn parse_args() -> CliArgs {
|
||||
let mut room = None;
|
||||
let mut token = None;
|
||||
let mut metrics_file = None;
|
||||
let mut version_check = false;
|
||||
let mut relay_str = None;
|
||||
let mut signal = false;
|
||||
let mut call_target = None;
|
||||
let mut netcheck = false;
|
||||
|
||||
let mut i = 1;
|
||||
while i < args.len() {
|
||||
match args[i].as_str() {
|
||||
"--live" => live = true,
|
||||
"--signal" => signal = true,
|
||||
"--call" => {
|
||||
i += 1;
|
||||
call_target = Some(
|
||||
args.get(i)
|
||||
.expect("--call requires a fingerprint")
|
||||
.to_string(),
|
||||
);
|
||||
}
|
||||
"--send-tone" => {
|
||||
i += 1;
|
||||
send_tone_secs = Some(
|
||||
@@ -169,6 +189,12 @@ fn parse_args() -> CliArgs {
|
||||
);
|
||||
}
|
||||
"--sweep" => sweep = true,
|
||||
"--netcheck" => {
|
||||
netcheck = true;
|
||||
}
|
||||
"--version-check" => {
|
||||
version_check = true;
|
||||
}
|
||||
"--help" | "-h" => {
|
||||
eprintln!("Usage: wzp-client [options] [relay-addr]");
|
||||
eprintln!();
|
||||
@@ -179,13 +205,19 @@ fn parse_args() -> CliArgs {
|
||||
eprintln!(" --record <file.raw> Record received audio to raw PCM file");
|
||||
eprintln!(" --echo-test <secs> Run automated echo quality test");
|
||||
eprintln!(" --drift-test <secs> Run automated clock-drift measurement");
|
||||
eprintln!(" --sweep Run jitter buffer parameter sweep (local, no network)");
|
||||
eprintln!(" --seed <hex> Identity seed (64 hex chars, featherChat compatible)");
|
||||
eprintln!(
|
||||
" --sweep Run jitter buffer parameter sweep (local, no network)"
|
||||
);
|
||||
eprintln!(
|
||||
" --seed <hex> Identity seed (64 hex chars, featherChat compatible)"
|
||||
);
|
||||
eprintln!(" --mnemonic <words...> Identity seed as BIP39 mnemonic (24 words)");
|
||||
eprintln!(" --room <name> Room name (hashed for privacy before sending)");
|
||||
eprintln!(" --token <token> featherChat bearer token for relay auth");
|
||||
eprintln!(" --metrics-file <path> Write JSONL telemetry to file (1 line/sec)");
|
||||
eprintln!(" (48kHz mono s16le, play with ffplay -f s16le -ar 48000 -ch_layout mono file.raw)");
|
||||
eprintln!(
|
||||
" (48kHz mono s16le, play with ffplay -f s16le -ar 48000 -ch_layout mono file.raw)"
|
||||
);
|
||||
eprintln!();
|
||||
eprintln!("Default relay: 127.0.0.1:4433");
|
||||
std::process::exit(0);
|
||||
@@ -221,6 +253,10 @@ fn parse_args() -> CliArgs {
|
||||
room,
|
||||
token,
|
||||
_metrics_file: metrics_file,
|
||||
version_check,
|
||||
signal,
|
||||
call_target,
|
||||
netcheck,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -239,6 +275,51 @@ async fn main() -> anyhow::Result<()> {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// --netcheck: run network diagnostic and exit
|
||||
if cli.netcheck {
|
||||
let config = wzp_client::netcheck::NetcheckConfig {
|
||||
stun_config: wzp_client::stun::StunConfig::default(),
|
||||
relays: vec![("relay".into(), cli.relay_addr)],
|
||||
timeout: std::time::Duration::from_secs(5),
|
||||
test_portmap: true,
|
||||
test_ipv6: true,
|
||||
local_port: 0,
|
||||
};
|
||||
let report = wzp_client::netcheck::run_netcheck(&config).await;
|
||||
print!("{}", wzp_client::netcheck::format_report(&report));
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// --version-check: query relay version over QUIC and exit
|
||||
if cli.version_check {
|
||||
let client_config = wzp_transport::client_config();
|
||||
let bind_addr: SocketAddr = "0.0.0.0:0".parse()?;
|
||||
let endpoint = wzp_transport::create_endpoint(bind_addr, None)?;
|
||||
let conn =
|
||||
wzp_transport::connect(&endpoint, cli.relay_addr, "version", client_config).await?;
|
||||
match conn.accept_uni().await {
|
||||
Ok(mut recv) => {
|
||||
let data = recv.read_to_end(256).await.unwrap_or_default();
|
||||
let version = String::from_utf8_lossy(&data);
|
||||
println!("{} {}", cli.relay_addr, version.trim());
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!(
|
||||
"relay {} does not support version query: {e}",
|
||||
cli.relay_addr
|
||||
);
|
||||
}
|
||||
}
|
||||
endpoint.close(0u32.into(), b"done");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// --signal mode: persistent signaling for direct calls
|
||||
if cli.signal {
|
||||
let seed = cli.resolve_seed();
|
||||
return run_signal_mode(cli.relay_addr, seed, cli.token, cli.call_target).await;
|
||||
}
|
||||
|
||||
let seed = cli.resolve_seed();
|
||||
|
||||
info!(
|
||||
@@ -250,12 +331,11 @@ async fn main() -> anyhow::Result<()> {
|
||||
"WarzonePhone client"
|
||||
);
|
||||
|
||||
// Hash room name for SNI privacy (or "default" if none specified)
|
||||
// Use raw room name as SNI (consistent with Android + Desktop clients for federation)
|
||||
let sni = match &cli.room {
|
||||
Some(name) => {
|
||||
let hashed = wzp_crypto::hash_room_name(name);
|
||||
info!(room = %name, hashed = %hashed, "room name hashed for SNI");
|
||||
hashed
|
||||
info!(room = %name, "using room name as SNI");
|
||||
name.clone()
|
||||
}
|
||||
None => "default".to_string(),
|
||||
};
|
||||
@@ -267,16 +347,40 @@ async fn main() -> anyhow::Result<()> {
|
||||
"0.0.0.0:0".parse()?
|
||||
};
|
||||
let endpoint = wzp_transport::create_endpoint(bind_addr, None)?;
|
||||
let connection =
|
||||
wzp_transport::connect(&endpoint, cli.relay_addr, &sni, client_config).await?;
|
||||
let connection = wzp_transport::connect(&endpoint, cli.relay_addr, &sni, client_config).await?;
|
||||
|
||||
info!("Connected to relay");
|
||||
|
||||
let transport = Arc::new(wzp_transport::QuinnTransport::new(connection));
|
||||
|
||||
// Register shutdown handler so SIGTERM/SIGINT always closes QUIC cleanly.
|
||||
// Without this, killed clients leave zombie connections on the relay for ~30s.
|
||||
{
|
||||
let shutdown_transport = transport.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut sigterm =
|
||||
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
|
||||
.expect("failed to register SIGTERM handler");
|
||||
let mut sigint =
|
||||
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt())
|
||||
.expect("failed to register SIGINT handler");
|
||||
tokio::select! {
|
||||
_ = sigterm.recv() => { info!("SIGTERM received, closing connection..."); }
|
||||
_ = sigint.recv() => { info!("SIGINT received, closing connection..."); }
|
||||
}
|
||||
// Close the QUIC connection immediately (APPLICATION_CLOSE frame).
|
||||
// Don't call process::exit — let the main task detect the closed
|
||||
// connection and perform clean shutdown (e.g., save recordings).
|
||||
shutdown_transport
|
||||
.connection()
|
||||
.close(0u32.into(), b"shutdown");
|
||||
});
|
||||
}
|
||||
|
||||
// Send auth token if provided (relay with --auth-url expects this first)
|
||||
if let Some(ref token) = cli.token {
|
||||
let auth = wzp_proto::SignalMessage::AuthToken {
|
||||
version: default_signal_version(),
|
||||
token: token.clone(),
|
||||
};
|
||||
transport.send_signal(&auth).await?;
|
||||
@@ -284,20 +388,29 @@ async fn main() -> anyhow::Result<()> {
|
||||
}
|
||||
|
||||
// Crypto handshake — establishes verified identity + session key
|
||||
let _crypto_session = wzp_client::handshake::perform_handshake(
|
||||
let session = wzp_client::handshake::perform_handshake(
|
||||
&*transport,
|
||||
&seed.0,
|
||||
).await?;
|
||||
None, // alias — desktop client doesn't set one yet
|
||||
)
|
||||
.await?;
|
||||
info!("crypto handshake complete");
|
||||
|
||||
// Wrap the transport so all media I/O goes through AEAD encryption.
|
||||
let enc_transport: Arc<dyn wzp_proto::MediaTransport> = Arc::new(
|
||||
wzp_client::encrypted_transport::EncryptingTransport::new(transport.clone(), session),
|
||||
);
|
||||
|
||||
if cli.live {
|
||||
#[cfg(feature = "audio")]
|
||||
{
|
||||
return run_live(transport).await;
|
||||
return run_live(enc_transport).await;
|
||||
}
|
||||
#[cfg(not(feature = "audio"))]
|
||||
{
|
||||
anyhow::bail!("--live requires the 'audio' feature (build with: cargo build --features audio)");
|
||||
anyhow::bail!(
|
||||
"--live requires the 'audio' feature (build with: cargo build --features audio)"
|
||||
);
|
||||
}
|
||||
} else if let Some(secs) = cli.echo_test_secs {
|
||||
let result = wzp_client::echo_test::run_echo_test(&*transport, secs, 5.0).await?;
|
||||
@@ -314,14 +427,20 @@ async fn main() -> anyhow::Result<()> {
|
||||
transport.close().await?;
|
||||
Ok(())
|
||||
} else if cli.send_tone_secs.is_some() || cli.send_file.is_some() || cli.record_file.is_some() {
|
||||
run_file_mode(transport, cli.send_tone_secs, cli.send_file, cli.record_file).await
|
||||
run_file_mode(
|
||||
enc_transport,
|
||||
cli.send_tone_secs,
|
||||
cli.send_file,
|
||||
cli.record_file,
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
run_silence(transport).await
|
||||
run_silence(enc_transport).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Send silence frames (connectivity test).
|
||||
async fn run_silence(transport: Arc<wzp_transport::QuinnTransport>) -> anyhow::Result<()> {
|
||||
async fn run_silence(transport: Arc<dyn wzp_proto::MediaTransport>) -> anyhow::Result<()> {
|
||||
let config = CallConfig::default();
|
||||
let mut encoder = CallEncoder::new(&config);
|
||||
|
||||
@@ -335,7 +454,7 @@ async fn run_silence(transport: Arc<wzp_transport::QuinnTransport>) -> anyhow::R
|
||||
for i in 0..250u32 {
|
||||
let packets = encoder.encode_frame(&pcm)?;
|
||||
for pkt in &packets {
|
||||
if pkt.header.is_repair {
|
||||
if pkt.header.is_repair() {
|
||||
total_repair += 1;
|
||||
} else {
|
||||
total_source += 1;
|
||||
@@ -360,7 +479,9 @@ async fn run_silence(transport: Arc<wzp_transport::QuinnTransport>) -> anyhow::R
|
||||
|
||||
info!(total_source, total_repair, total_bytes, "done — closing");
|
||||
let hangup = wzp_proto::SignalMessage::Hangup {
|
||||
version: default_signal_version(),
|
||||
reason: wzp_proto::HangupReason::Normal,
|
||||
call_id: None,
|
||||
};
|
||||
transport.send_signal(&hangup).await.ok();
|
||||
transport.close().await?;
|
||||
@@ -369,7 +490,7 @@ async fn run_silence(transport: Arc<wzp_transport::QuinnTransport>) -> anyhow::R
|
||||
|
||||
/// File/tone mode: send a test tone or audio file, and/or record received audio.
|
||||
async fn run_file_mode(
|
||||
transport: Arc<wzp_transport::QuinnTransport>,
|
||||
transport: Arc<dyn wzp_proto::MediaTransport>,
|
||||
send_tone_secs: Option<u32>,
|
||||
send_file: Option<String>,
|
||||
record_file: Option<String>,
|
||||
@@ -384,21 +505,28 @@ async fn run_file_mode(
|
||||
// Read raw PCM file (48kHz mono s16le)
|
||||
let bytes = match std::fs::read(path) {
|
||||
Ok(b) => b,
|
||||
Err(e) => { error!("read {path}: {e}"); return; }
|
||||
Err(e) => {
|
||||
error!("read {path}: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
let samples: Vec<i16> = bytes.chunks_exact(2)
|
||||
let samples: Vec<i16> = bytes
|
||||
.chunks_exact(2)
|
||||
.map(|c| i16::from_le_bytes([c[0], c[1]]))
|
||||
.collect();
|
||||
let duration = samples.len() as f64 / 48_000.0;
|
||||
info!(file = %path, duration = format!("{:.1}s", duration), "sending audio file");
|
||||
samples.chunks(FRAME_SAMPLES)
|
||||
samples
|
||||
.chunks(FRAME_SAMPLES)
|
||||
.filter(|c| c.len() == FRAME_SAMPLES)
|
||||
.map(|c| c.to_vec())
|
||||
.collect()
|
||||
} else if let Some(secs) = send_tone_secs {
|
||||
let total = (secs as u64) * 50;
|
||||
info!(seconds = secs, frames = total, "sending 440Hz tone");
|
||||
(0..total).map(|i| generate_sine_frame(440.0, 48_000, i)).collect()
|
||||
(0..total)
|
||||
.map(|i| generate_sine_frame(440.0, 48_000, i))
|
||||
.collect()
|
||||
} else {
|
||||
// No sending, just wait
|
||||
tokio::signal::ctrl_c().await.ok();
|
||||
@@ -422,7 +550,7 @@ async fn run_file_mode(
|
||||
}
|
||||
};
|
||||
for pkt in &packets {
|
||||
if pkt.header.is_repair {
|
||||
if pkt.header.is_repair() {
|
||||
total_repair += 1;
|
||||
} else {
|
||||
total_source += 1;
|
||||
@@ -470,7 +598,7 @@ async fn run_file_mode(
|
||||
result = recv_transport.recv_media() => {
|
||||
match result {
|
||||
Ok(Some(pkt)) => {
|
||||
let is_repair = pkt.header.is_repair;
|
||||
let is_repair = pkt.header.is_repair();
|
||||
decoder.ingest(pkt);
|
||||
if !is_repair {
|
||||
if let Some(n) = decoder.decode_next(&mut pcm_buf) {
|
||||
@@ -511,7 +639,9 @@ async fn run_file_mode(
|
||||
|
||||
// Send Hangup signal so the relay knows we're done
|
||||
let hangup = wzp_proto::SignalMessage::Hangup {
|
||||
version: default_signal_version(),
|
||||
reason: wzp_proto::HangupReason::Normal,
|
||||
call_id: None,
|
||||
};
|
||||
transport.send_signal(&hangup).await.ok();
|
||||
|
||||
@@ -549,7 +679,7 @@ async fn run_file_mode(
|
||||
|
||||
/// Live mode: capture from mic, encode, send; receive, decode, play.
|
||||
#[cfg(feature = "audio")]
|
||||
async fn run_live(transport: Arc<wzp_transport::QuinnTransport>) -> anyhow::Result<()> {
|
||||
async fn run_live(transport: Arc<dyn wzp_proto::MediaTransport>) -> anyhow::Result<()> {
|
||||
use wzp_client::audio_io::{AudioCapture, AudioPlayback};
|
||||
|
||||
let capture = AudioCapture::start()?;
|
||||
@@ -563,11 +693,21 @@ async fn run_live(transport: Arc<wzp_transport::QuinnTransport>) -> anyhow::Resu
|
||||
.spawn(move || {
|
||||
let config = CallConfig::default();
|
||||
let mut encoder = CallEncoder::new(&config);
|
||||
let mut frame = vec![0i16; FRAME_SAMPLES];
|
||||
loop {
|
||||
let frame = match capture.read_frame() {
|
||||
Some(f) => f,
|
||||
None => break,
|
||||
};
|
||||
// Pull a full 20 ms frame from the capture ring. The ring
|
||||
// may return a partial read when the CPAL callback hasn't
|
||||
// produced enough samples yet — keep reading until we
|
||||
// accumulate a whole frame, sleeping briefly on empty
|
||||
// returns so we don't hot-spin the CPU.
|
||||
let mut filled = 0usize;
|
||||
while filled < FRAME_SAMPLES {
|
||||
let n = capture.ring().read(&mut frame[filled..]);
|
||||
filled += n;
|
||||
if n == 0 {
|
||||
std::thread::sleep(std::time::Duration::from_millis(2));
|
||||
}
|
||||
}
|
||||
let packets = match encoder.encode_frame(&frame) {
|
||||
Ok(p) => p,
|
||||
Err(e) => {
|
||||
@@ -592,13 +732,19 @@ async fn run_live(transport: Arc<wzp_transport::QuinnTransport>) -> anyhow::Resu
|
||||
loop {
|
||||
match recv_transport.recv_media().await {
|
||||
Ok(Some(pkt)) => {
|
||||
let is_repair = pkt.header.is_repair;
|
||||
let is_repair = pkt.header.is_repair();
|
||||
decoder.ingest(pkt);
|
||||
// Only decode for source packets (1 source = 1 audio frame).
|
||||
// Repair packets feed the FEC decoder but don't produce audio.
|
||||
if !is_repair {
|
||||
if let Some(_n) = decoder.decode_next(&mut pcm_buf) {
|
||||
playback.write_frame(&pcm_buf);
|
||||
// Push the decoded frame into the playback
|
||||
// ring. The CPAL output callback drains from
|
||||
// here on its own clock; if the ring is full
|
||||
// (rare in CLI live mode) the write returns
|
||||
// a short count and the tail is dropped,
|
||||
// which is the correct real-time behavior.
|
||||
playback.ring().write(&pcm_buf);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -623,3 +769,260 @@ async fn run_live(transport: Arc<wzp_transport::QuinnTransport>) -> anyhow::Resu
|
||||
info!("done");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Persistent signaling mode for direct 1:1 calls.
|
||||
async fn run_signal_mode(
|
||||
relay_addr: SocketAddr,
|
||||
seed: wzp_crypto::Seed,
|
||||
token: Option<String>,
|
||||
call_target: Option<String>,
|
||||
) -> anyhow::Result<()> {
|
||||
use wzp_proto::{SignalMessage, default_signal_version};
|
||||
|
||||
let identity = seed.derive_identity();
|
||||
let pub_id = identity.public_identity();
|
||||
let fp = pub_id.fingerprint.to_string();
|
||||
let identity_pub = *pub_id.signing.as_bytes();
|
||||
info!(fingerprint = %fp, "signal mode");
|
||||
|
||||
// Connect to relay with SNI "_signal"
|
||||
let client_config = wzp_transport::client_config();
|
||||
let bind_addr: SocketAddr = if relay_addr.is_ipv6() {
|
||||
"[::]:0".parse()?
|
||||
} else {
|
||||
"0.0.0.0:0".parse()?
|
||||
};
|
||||
let endpoint = wzp_transport::create_endpoint(bind_addr, None)?;
|
||||
let conn = wzp_transport::connect(&endpoint, relay_addr, "_signal", client_config).await?;
|
||||
let transport = Arc::new(wzp_transport::QuinnTransport::new(conn));
|
||||
info!("connected to relay (signal channel)");
|
||||
|
||||
// Auth if token provided
|
||||
if let Some(ref tok) = token {
|
||||
transport
|
||||
.send_signal(&SignalMessage::AuthToken {
|
||||
version: default_signal_version(),
|
||||
token: tok.clone(),
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
|
||||
// Register presence (signature not verified in Phase 1)
|
||||
transport
|
||||
.send_signal(&SignalMessage::RegisterPresence {
|
||||
version: default_signal_version(),
|
||||
identity_pub,
|
||||
signature: vec![], // Phase 1: not verified
|
||||
alias: None,
|
||||
})
|
||||
.await?;
|
||||
|
||||
// Wait for ack
|
||||
match transport.recv_signal().await? {
|
||||
Some(SignalMessage::RegisterPresenceAck { success: true, .. }) => {
|
||||
info!(fingerprint = %fp, "registered on relay — waiting for calls");
|
||||
}
|
||||
Some(SignalMessage::RegisterPresenceAck {
|
||||
success: false,
|
||||
error,
|
||||
..
|
||||
}) => {
|
||||
anyhow::bail!("registration failed: {}", error.unwrap_or_default());
|
||||
}
|
||||
other => {
|
||||
anyhow::bail!("unexpected response: {other:?}");
|
||||
}
|
||||
}
|
||||
|
||||
// If --call specified, place the call
|
||||
if let Some(ref target) = call_target {
|
||||
info!(target = %target, "placing direct call...");
|
||||
let call_id = format!(
|
||||
"{:016x}",
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_nanos()
|
||||
);
|
||||
|
||||
transport
|
||||
.send_signal(&SignalMessage::DirectCallOffer {
|
||||
version: default_signal_version(),
|
||||
caller_fingerprint: fp.clone(),
|
||||
caller_alias: None,
|
||||
target_fingerprint: target.clone(),
|
||||
call_id: call_id.clone(),
|
||||
identity_pub,
|
||||
ephemeral_pub: [0u8; 32], // Phase 1: not used for key exchange
|
||||
signature: vec![],
|
||||
supported_profiles: vec![wzp_proto::QualityProfile::GOOD],
|
||||
// CLI client doesn't attempt hole-punching; always
|
||||
// relay-path.
|
||||
caller_reflexive_addr: None,
|
||||
caller_local_addrs: Vec::new(),
|
||||
caller_mapped_addr: None,
|
||||
caller_build_version: None,
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
|
||||
// Signal recv loop — handle incoming signals
|
||||
let signal_transport = transport.clone();
|
||||
let relay = relay_addr;
|
||||
let my_seed = seed.0;
|
||||
|
||||
loop {
|
||||
match signal_transport.recv_signal().await {
|
||||
Ok(Some(msg)) => match msg {
|
||||
SignalMessage::CallRinging { call_id, .. } => {
|
||||
info!(call_id = %call_id, "ringing...");
|
||||
}
|
||||
SignalMessage::DirectCallOffer {
|
||||
caller_fingerprint,
|
||||
caller_alias,
|
||||
call_id,
|
||||
..
|
||||
} => {
|
||||
info!(
|
||||
from = %caller_fingerprint,
|
||||
alias = ?caller_alias,
|
||||
call_id = %call_id,
|
||||
"incoming call — auto-accepting (generic)"
|
||||
);
|
||||
// Auto-accept for CLI testing
|
||||
let _ = signal_transport
|
||||
.send_signal(&SignalMessage::DirectCallAnswer {
|
||||
version: default_signal_version(),
|
||||
call_id,
|
||||
accept_mode: wzp_proto::CallAcceptMode::AcceptGeneric,
|
||||
identity_pub: Some(identity_pub),
|
||||
ephemeral_pub: None,
|
||||
signature: None,
|
||||
chosen_profile: Some(wzp_proto::QualityProfile::GOOD),
|
||||
// CLI auto-accept uses generic (privacy) mode,
|
||||
// so callee addr stays hidden from the caller.
|
||||
callee_reflexive_addr: None,
|
||||
callee_local_addrs: Vec::new(),
|
||||
callee_mapped_addr: None,
|
||||
callee_build_version: None,
|
||||
})
|
||||
.await;
|
||||
}
|
||||
SignalMessage::DirectCallAnswer {
|
||||
call_id,
|
||||
accept_mode,
|
||||
..
|
||||
} => {
|
||||
info!(call_id = %call_id, mode = ?accept_mode, "call answered");
|
||||
}
|
||||
SignalMessage::CallSetup {
|
||||
call_id,
|
||||
room,
|
||||
relay_addr: setup_relay,
|
||||
peer_direct_addr: _,
|
||||
peer_local_addrs: _,
|
||||
peer_mapped_addr: _,
|
||||
..
|
||||
} => {
|
||||
info!(call_id = %call_id, room = %room, relay = %setup_relay, "call setup — connecting to media room");
|
||||
|
||||
// Connect to the media room
|
||||
let media_relay: SocketAddr = setup_relay.parse().unwrap_or(relay);
|
||||
let media_cfg = wzp_transport::client_config();
|
||||
match wzp_transport::connect(&endpoint, media_relay, &room, media_cfg).await {
|
||||
Ok(media_conn) => {
|
||||
let media_transport =
|
||||
Arc::new(wzp_transport::QuinnTransport::new(media_conn));
|
||||
|
||||
// Crypto handshake
|
||||
match wzp_client::handshake::perform_handshake(
|
||||
&*media_transport,
|
||||
&my_seed,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(_session) => {
|
||||
info!(
|
||||
"media connected — sending tone (press Ctrl+C to hang up)"
|
||||
);
|
||||
|
||||
// Simple tone sender for testing
|
||||
let mt = media_transport.clone();
|
||||
let send_task = tokio::spawn(async move {
|
||||
let config = wzp_client::call::CallConfig::default();
|
||||
let mut encoder =
|
||||
wzp_client::call::CallEncoder::new(&config);
|
||||
let duration = tokio::time::Duration::from_millis(20);
|
||||
loop {
|
||||
let pcm: Vec<i16> = (0..FRAME_SAMPLES)
|
||||
.map(|_| 0i16) // silence — could be tone
|
||||
.collect();
|
||||
if let Ok(pkts) = encoder.encode_frame(&pcm) {
|
||||
for pkt in &pkts {
|
||||
if mt.send_media(pkt).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
tokio::time::sleep(duration).await;
|
||||
}
|
||||
});
|
||||
|
||||
// Wait for hangup or ctrl+c
|
||||
loop {
|
||||
tokio::select! {
|
||||
sig = signal_transport.recv_signal() => {
|
||||
match sig {
|
||||
Ok(Some(SignalMessage::Hangup { .. })) => {
|
||||
info!("remote hung up");
|
||||
break;
|
||||
}
|
||||
Ok(None) | Err(_) => break,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
_ = tokio::signal::ctrl_c() => {
|
||||
info!("hanging up...");
|
||||
let _ = signal_transport.send_signal(&SignalMessage::Hangup {
|
||||
version: default_signal_version(),
|
||||
reason: wzp_proto::HangupReason::Normal,
|
||||
call_id: None,
|
||||
}).await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
send_task.abort();
|
||||
media_transport.close().await.ok();
|
||||
info!("call ended");
|
||||
}
|
||||
Err(e) => error!("media handshake failed: {e}"),
|
||||
}
|
||||
}
|
||||
Err(e) => error!("media connect failed: {e}"),
|
||||
}
|
||||
}
|
||||
SignalMessage::Hangup { reason, .. } => {
|
||||
info!(reason = ?reason, "call ended by remote");
|
||||
}
|
||||
SignalMessage::Pong { .. } => {}
|
||||
other => {
|
||||
info!("signal: {:?}", std::mem::discriminant(&other));
|
||||
}
|
||||
},
|
||||
Ok(None) => {
|
||||
info!("signal connection closed");
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
error!("signal error: {e}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
transport.close().await.ok();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -144,7 +144,7 @@ pub async fn run_drift_test(
|
||||
}
|
||||
match tokio::time::timeout(Duration::from_millis(2), transport.recv_media()).await {
|
||||
Ok(Ok(Some(pkt))) => {
|
||||
let is_repair = pkt.header.is_repair;
|
||||
let is_repair = pkt.header.is_repair();
|
||||
decoder.ingest(pkt);
|
||||
if !is_repair {
|
||||
if let Some(_n) = decoder.decode_next(&mut pcm_buf) {
|
||||
@@ -180,7 +180,7 @@ pub async fn run_drift_test(
|
||||
while Instant::now() < drain_deadline {
|
||||
match tokio::time::timeout(Duration::from_millis(100), transport.recv_media()).await {
|
||||
Ok(Ok(Some(pkt))) => {
|
||||
let is_repair = pkt.header.is_repair;
|
||||
let is_repair = pkt.header.is_repair();
|
||||
decoder.ingest(pkt);
|
||||
if !is_repair {
|
||||
if let Some(_n) = decoder.decode_next(&mut pcm_buf) {
|
||||
@@ -234,7 +234,10 @@ pub fn print_drift_report(result: &DriftResult) {
|
||||
println!();
|
||||
println!("Expected duration: {} ms", result.expected_duration_ms);
|
||||
println!("Actual duration: {} ms", result.actual_duration_ms);
|
||||
println!("Drift: {} ms ({:+.4}%)", result.drift_ms, result.drift_pct);
|
||||
println!(
|
||||
"Drift: {} ms ({:+.4}%)",
|
||||
result.drift_ms, result.drift_pct
|
||||
);
|
||||
println!();
|
||||
|
||||
// Interpretation
|
||||
@@ -246,9 +249,15 @@ pub fn print_drift_report(result: &DriftResult) {
|
||||
} else if abs_drift < 20 {
|
||||
println!("Result: GOOD -- drift is within acceptable bounds (<20 ms).");
|
||||
} else if abs_drift < 100 {
|
||||
println!("Result: FAIR -- noticeable drift ({} ms). Clock sync may be needed.", abs_drift);
|
||||
println!(
|
||||
"Result: FAIR -- noticeable drift ({} ms). Clock sync may be needed.",
|
||||
abs_drift
|
||||
);
|
||||
} else {
|
||||
println!("Result: POOR -- significant drift ({} ms). Investigate clock sources.", abs_drift);
|
||||
println!(
|
||||
"Result: POOR -- significant drift ({} ms). Investigate clock sources.",
|
||||
abs_drift
|
||||
);
|
||||
}
|
||||
println!();
|
||||
}
|
||||
|
||||
976
crates/wzp-client/src/dual_path.rs
Normal file
976
crates/wzp-client/src/dual_path.rs
Normal file
@@ -0,0 +1,976 @@
|
||||
//! Phase 3.5 — dual-path QUIC connect race for P2P hole-punching.
|
||||
//!
|
||||
//! When both peers advertised reflex addrs in the
|
||||
//! DirectCallOffer/Answer flow, the relay cross-wires them into
|
||||
//! `CallSetup.peer_direct_addr`. This module races a direct QUIC
|
||||
//! handshake against the existing relay dial and returns whichever
|
||||
//! completes first — with automatic drop of the loser via
|
||||
//! `tokio::select!`.
|
||||
//!
|
||||
//! Role determination is deterministic and symmetric
|
||||
//! (`wzp_client::reflect::determine_role`): whichever peer has the
|
||||
//! lexicographically smaller reflex addr becomes the **Acceptor**
|
||||
//! (listens on a server-capable endpoint), the other becomes the
|
||||
//! **Dialer** (dials the peer's addr). Because the rule is
|
||||
//! identical on both sides, the Acceptor's inbound QUIC session
|
||||
//! and the Dialer's outbound are the SAME connection — no
|
||||
//! negotiation needed, no two-conns-per-call confusion.
|
||||
//!
|
||||
//! Timeout policy:
|
||||
//! - Direct path: 2s from the start of `race`. Cone-NAT hole-punch
|
||||
//! typically completes in < 500ms on a LAN; 2s gives us tolerance
|
||||
//! for a single QUIC Initial retry on unreliable networks.
|
||||
//! - Relay path: 10s (existing behavior elsewhere in the codebase).
|
||||
//! - Overall: `tokio::select!` returns as soon as either succeeds.
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::reflect::Role;
|
||||
use wzp_transport::QuinnTransport;
|
||||
|
||||
/// Which path won the race. Used by the `connect` command for
|
||||
/// logging + (in the future) metrics.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum WinningPath {
|
||||
Direct,
|
||||
Relay,
|
||||
}
|
||||
|
||||
/// Diagnostic info for a single candidate dial attempt.
|
||||
#[derive(Debug, Clone, serde::Serialize)]
|
||||
pub struct CandidateDiag {
|
||||
pub index: usize,
|
||||
pub addr: String,
|
||||
pub result: String, // "ok", "skipped:ipv6", "error:..."
|
||||
pub elapsed_ms: Option<u32>,
|
||||
}
|
||||
|
||||
/// Phase 6: the race now returns BOTH transports (when available)
|
||||
/// so the connect command can negotiate with the peer before
|
||||
/// committing. The negotiation decides which transport to use
|
||||
/// based on whether BOTH sides report `direct_ok = true`.
|
||||
pub struct RaceResult {
|
||||
/// The direct P2P transport, if the direct path completed.
|
||||
/// `None` if the direct dial/accept failed or timed out.
|
||||
pub direct_transport: Option<Arc<QuinnTransport>>,
|
||||
/// The relay transport, if the relay dial completed.
|
||||
/// `None` if the relay dial failed (shouldn't happen in
|
||||
/// practice since relay is always reachable).
|
||||
pub relay_transport: Option<Arc<QuinnTransport>>,
|
||||
/// Which future completed first in the local race.
|
||||
/// Informational — the actual path used is decided by the
|
||||
/// Phase 6 negotiation after both sides exchange reports.
|
||||
pub local_winner: WinningPath,
|
||||
/// Per-candidate diagnostic info for debugging.
|
||||
pub candidate_diags: Vec<CandidateDiag>,
|
||||
}
|
||||
|
||||
/// Attempt a direct QUIC connection to the peer in parallel with
|
||||
/// the relay dial and return the winning `QuinnTransport`.
|
||||
///
|
||||
/// `role` selects the direction of the direct attempt:
|
||||
/// - `Role::Acceptor` creates a server-capable endpoint and waits
|
||||
/// for the peer to dial in.
|
||||
/// - `Role::Dialer` creates a client-only endpoint and dials
|
||||
/// `peer_direct_addr`.
|
||||
///
|
||||
/// The relay path is always attempted in parallel as a fallback so
|
||||
/// the race ALWAYS produces a working transport unless both paths
|
||||
/// genuinely fail (network partition). Returns
|
||||
/// `Err(anyhow::anyhow!(...))` if both paths fail within the
|
||||
/// timeout.
|
||||
/// Phase 5.5 candidate bundle — full ICE-ish candidate list for
|
||||
/// the peer. The race tries them all in parallel alongside the
|
||||
/// relay path. At minimum this should contain the peer's
|
||||
/// server-reflexive address; `local_addrs` carries LAN host
|
||||
/// candidates gathered from their physical interfaces.
|
||||
///
|
||||
/// Empty is valid: the D-role has nothing to dial and the race
|
||||
/// reduces to "relay only" + (if A-role) accepting on the
|
||||
/// shared endpoint.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct PeerCandidates {
|
||||
/// Peer's server-reflexive address (Phase 3). `None` if the
|
||||
/// peer didn't advertise one.
|
||||
pub reflexive: Option<SocketAddr>,
|
||||
/// Peer's LAN host addresses (Phase 5.5). Tried first on
|
||||
/// same-LAN pairs — direct dials to these bypass the NAT
|
||||
/// entirely.
|
||||
pub local: Vec<SocketAddr>,
|
||||
/// Phase 8 (Tailscale-inspired): peer's port-mapped external
|
||||
/// address from NAT-PMP/PCP/UPnP. When the router supports
|
||||
/// port mapping, this gives a stable external address even
|
||||
/// behind symmetric NATs.
|
||||
pub mapped: Option<SocketAddr>,
|
||||
}
|
||||
|
||||
impl PeerCandidates {
|
||||
/// Flatten into the list of addrs the D-role should dial.
|
||||
/// Order: LAN host candidates first (fastest when they
|
||||
/// work), then port-mapped (stable even behind symmetric
|
||||
/// NATs), then reflexive (covers the non-LAN case).
|
||||
pub fn dial_order(&self) -> Vec<SocketAddr> {
|
||||
let mut out = Vec::with_capacity(self.local.len() + 2);
|
||||
out.extend(self.local.iter().copied());
|
||||
// Port-mapped address goes before reflexive — it's
|
||||
// more reliable on symmetric NATs where the reflexive
|
||||
// addr might not match what the peer actually sees.
|
||||
if let Some(a) = self.mapped {
|
||||
if !out.contains(&a) {
|
||||
out.push(a);
|
||||
}
|
||||
}
|
||||
if let Some(a) = self.reflexive {
|
||||
if !out.contains(&a) {
|
||||
out.push(a);
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Smart dial order: filters out candidates that can't possibly
|
||||
/// work given our own reflexive address.
|
||||
///
|
||||
/// - **LAN candidates**: only included if peer's public IP
|
||||
/// matches ours (same network). Private IPs are unreachable
|
||||
/// cross-network.
|
||||
/// - **IPv6 candidates**: stripped entirely (Phase 7 disabled).
|
||||
/// - **Reflexive + mapped**: always included.
|
||||
pub fn smart_dial_order(&self, own_reflexive: Option<&SocketAddr>) -> Vec<SocketAddr> {
|
||||
let own_public_ip = own_reflexive.map(|a| a.ip());
|
||||
let peer_public_ip = self.reflexive.map(|a| a.ip());
|
||||
let same_network = match (own_public_ip, peer_public_ip) {
|
||||
(Some(a), Some(b)) => a == b,
|
||||
_ => false,
|
||||
};
|
||||
|
||||
let mut out = Vec::with_capacity(self.local.len() + 2);
|
||||
|
||||
// LAN candidates only when on the same network.
|
||||
if same_network {
|
||||
for addr in &self.local {
|
||||
if !addr.is_ipv6() {
|
||||
out.push(*addr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Port-mapped (always useful — it's a public addr).
|
||||
if let Some(a) = self.mapped {
|
||||
if !a.is_ipv6() && !out.contains(&a) {
|
||||
out.push(a);
|
||||
}
|
||||
}
|
||||
|
||||
// Reflexive (always useful — it's the peer's public addr).
|
||||
if let Some(a) = self.reflexive {
|
||||
if !a.is_ipv6() && !out.contains(&a) {
|
||||
out.push(a);
|
||||
}
|
||||
}
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
/// Is there anything for the D-role to dial? If not, the
|
||||
/// race reduces to relay-only.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.reflexive.is_none() && self.local.is_empty() && self.mapped.is_none()
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn race(
|
||||
role: Role,
|
||||
peer_candidates: PeerCandidates,
|
||||
relay_addr: SocketAddr,
|
||||
room_sni: String,
|
||||
call_sni: String,
|
||||
// Our own reflexive address — used to filter LAN candidates
|
||||
// that can't work cross-network.
|
||||
own_reflexive: Option<SocketAddr>,
|
||||
// Phase 5: when `Some`, reuse this endpoint for BOTH the
|
||||
// direct-path branch AND the relay dial. Pass the signal
|
||||
// endpoint. The endpoint MUST be server-capable (created
|
||||
// with a server config) for the A-role accept branch to
|
||||
// work.
|
||||
//
|
||||
// When `None`, falls back to fresh endpoints per role.
|
||||
// Used by tests.
|
||||
shared_endpoint: Option<wzp_transport::Endpoint>,
|
||||
// Phase 7: dedicated IPv6 endpoint with IPV6_V6ONLY=1.
|
||||
// When `Some`, A-role accepts on both v4+v6, D-role routes
|
||||
// each candidate to its matching-AF endpoint. When `None`,
|
||||
// IPv6 candidates are skipped (IPv4-only, pre-Phase-7).
|
||||
ipv6_endpoint: Option<wzp_transport::Endpoint>,
|
||||
) -> anyhow::Result<RaceResult> {
|
||||
// Rustls provider must be installed before any quinn endpoint
|
||||
// is created. Install attempt is idempotent.
|
||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||
|
||||
// Shared diagnostic collector for per-candidate results.
|
||||
let diags_collector: Arc<std::sync::Mutex<Vec<CandidateDiag>>> =
|
||||
Arc::new(std::sync::Mutex::new(Vec::new()));
|
||||
|
||||
// Build the direct-path endpoint + future based on role.
|
||||
//
|
||||
// A-role: one accept future on the shared endpoint. The
|
||||
// first incoming QUIC connection wins — we don't care
|
||||
// which peer candidate the dialer used to reach us.
|
||||
//
|
||||
// D-role: N parallel dial futures, one per peer candidate
|
||||
// (all LAN host addrs + the reflex addr), consolidated
|
||||
// into a single direct_fut via FuturesUnordered-style
|
||||
// "first OK wins" semantics. The first successful dial
|
||||
// becomes the direct path; the losers are dropped (quinn
|
||||
// will abort the in-flight handshakes via the dropped
|
||||
// Connecting futures).
|
||||
//
|
||||
// Either way, direct_fut resolves to a single QuinnTransport
|
||||
// (or an error) and is raced against the relay_fut by the
|
||||
// outer tokio::select!.
|
||||
let direct_ep: wzp_transport::Endpoint;
|
||||
let direct_fut: std::pin::Pin<
|
||||
Box<dyn std::future::Future<Output = anyhow::Result<QuinnTransport>> + Send>,
|
||||
>;
|
||||
|
||||
match role {
|
||||
Role::Acceptor => {
|
||||
let ep = match shared_endpoint.clone() {
|
||||
Some(ep) => {
|
||||
tracing::info!(
|
||||
local_addr = ?ep.local_addr().ok(),
|
||||
"dual_path: A-role reusing shared endpoint for accept"
|
||||
);
|
||||
ep
|
||||
}
|
||||
None => {
|
||||
let (sc, _cert_der) = wzp_transport::server_config();
|
||||
// 0.0.0.0:0 = IPv4 socket. [::]:0 dual-stack was
|
||||
// tried but breaks on Android devices where
|
||||
// IPV6_V6ONLY=1 (default on some kernels) —
|
||||
// IPv4 candidates silently fail. IPv6 host
|
||||
// candidates are skipped for now; they need a
|
||||
// dedicated IPv6 socket alongside the v4 one
|
||||
// (like WebRTC's dual-socket approach).
|
||||
let bind: SocketAddr = "0.0.0.0:0".parse().unwrap();
|
||||
let fresh = wzp_transport::create_endpoint(bind, Some(sc))?;
|
||||
tracing::info!(
|
||||
local_addr = ?fresh.local_addr().ok(),
|
||||
"dual_path: A-role fresh endpoint up, awaiting peer dial"
|
||||
);
|
||||
fresh
|
||||
}
|
||||
};
|
||||
let ep_for_fut = ep.clone();
|
||||
// Phase 7: IPv6 accept temporarily disabled (same reason
|
||||
// as dial — IPv6 connections die on datagram send).
|
||||
// Accept on IPv4 shared endpoint only.
|
||||
let _v6_ep_unused = ipv6_endpoint.clone();
|
||||
// Collect peer addrs for NAT tickle (Acceptor-side).
|
||||
let tickle_addrs: Vec<SocketAddr> = peer_candidates
|
||||
.smart_dial_order(own_reflexive.as_ref())
|
||||
.into_iter()
|
||||
.filter(|a| !a.ip().is_loopback() && !a.ip().is_unspecified())
|
||||
.collect();
|
||||
direct_fut = Box::pin(async move {
|
||||
// NAT tickle: send a small UDP packet to each of the
|
||||
// Dialer's candidate addresses FROM our shared endpoint.
|
||||
// This opens our NAT's pinhole for return traffic from
|
||||
// those IPs — critical for address-restricted NATs that
|
||||
// only allow inbound from IPs they've seen outbound
|
||||
// traffic to. Without this, the Dialer's QUIC Initial
|
||||
// gets dropped by our NAT.
|
||||
if !tickle_addrs.is_empty() {
|
||||
if let Ok(local_addr) = ep_for_fut.local_addr() {
|
||||
// Send a tickle to each peer candidate address
|
||||
// to open our NAT for return traffic from that IP.
|
||||
//
|
||||
// We use a socket2 socket with SO_REUSEADDR +
|
||||
// SO_REUSEPORT on the SAME port as the quinn
|
||||
// endpoint. This is necessary because quinn
|
||||
// already holds the port — a plain bind() would
|
||||
// fail with EADDRINUSE.
|
||||
let tickle_result: Result<(), String> = (|| {
|
||||
use std::net::UdpSocket as StdUdpSocket;
|
||||
let sock = socket2::Socket::new(
|
||||
socket2::Domain::IPV4,
|
||||
socket2::Type::DGRAM,
|
||||
Some(socket2::Protocol::UDP),
|
||||
)
|
||||
.map_err(|e| format!("socket: {e}"))?;
|
||||
sock.set_reuse_address(true)
|
||||
.map_err(|e| format!("reuseaddr: {e}"))?;
|
||||
// macOS/BSD/Linux also need SO_REUSEPORT
|
||||
#[cfg(any(
|
||||
target_os = "macos",
|
||||
target_os = "linux",
|
||||
target_os = "android"
|
||||
))]
|
||||
{
|
||||
// socket2 exposes set_reuse_port on unix
|
||||
unsafe {
|
||||
let optval: libc::c_int = 1;
|
||||
libc::setsockopt(
|
||||
std::os::unix::io::AsRawFd::as_raw_fd(&sock),
|
||||
libc::SOL_SOCKET,
|
||||
libc::SO_REUSEPORT,
|
||||
&optval as *const _ as *const libc::c_void,
|
||||
std::mem::size_of::<libc::c_int>() as libc::socklen_t,
|
||||
);
|
||||
}
|
||||
}
|
||||
sock.set_nonblocking(true)
|
||||
.map_err(|e| format!("nonblock: {e}"))?;
|
||||
let bind_addr: SocketAddr = SocketAddr::new(
|
||||
std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED),
|
||||
local_addr.port(),
|
||||
);
|
||||
sock.bind(&bind_addr.into())
|
||||
.map_err(|e| format!("bind :{}: {e}", local_addr.port()))?;
|
||||
let std_sock: StdUdpSocket = sock.into();
|
||||
for addr in &tickle_addrs {
|
||||
let _ = std_sock.send_to(&[0u8; 1], addr);
|
||||
tracing::info!(
|
||||
%addr,
|
||||
local_port = local_addr.port(),
|
||||
"dual_path: A-role sent NAT tickle"
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
})();
|
||||
if let Err(e) = tickle_result {
|
||||
tracing::warn!(error = %e, "dual_path: A-role NAT tickle failed");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Accept loop: retry if we get a stale/closed
|
||||
// connection from a previous call. Max 3 retries
|
||||
// to avoid spinning until the race timeout.
|
||||
const MAX_STALE: usize = 3;
|
||||
let mut stale_count: usize = 0;
|
||||
loop {
|
||||
let conn = wzp_transport::accept(&ep_for_fut)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("direct accept: {e}"))?;
|
||||
|
||||
if let Some(reason) = conn.close_reason() {
|
||||
// Explicitly close so the peer gets a
|
||||
// close frame instead of idle timeout.
|
||||
conn.close(0u32.into(), b"stale");
|
||||
stale_count += 1;
|
||||
tracing::warn!(
|
||||
remote = %conn.remote_address(),
|
||||
stable_id = conn.stable_id(),
|
||||
stale_count,
|
||||
?reason,
|
||||
"dual_path: A-role skipping stale connection"
|
||||
);
|
||||
if stale_count >= MAX_STALE {
|
||||
return Err(anyhow::anyhow!(
|
||||
"A-role: {stale_count} stale connections, aborting"
|
||||
));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
let has_dgram = conn.max_datagram_size().is_some();
|
||||
tracing::info!(
|
||||
remote = %conn.remote_address(),
|
||||
stable_id = conn.stable_id(),
|
||||
has_dgram,
|
||||
"dual_path: A-role accepted direct connection"
|
||||
);
|
||||
|
||||
break Ok(QuinnTransport::new(conn));
|
||||
}
|
||||
});
|
||||
direct_ep = ep;
|
||||
}
|
||||
Role::Dialer => {
|
||||
let ep = match shared_endpoint.clone() {
|
||||
Some(ep) => {
|
||||
tracing::info!(
|
||||
local_addr = ?ep.local_addr().ok(),
|
||||
candidates = ?peer_candidates.dial_order(),
|
||||
"dual_path: D-role reusing shared endpoint to dial peer candidates"
|
||||
);
|
||||
ep
|
||||
}
|
||||
None => {
|
||||
// 0.0.0.0:0 = IPv4 socket. [::]:0 dual-stack was
|
||||
// tried but breaks on Android devices where
|
||||
// IPV6_V6ONLY=1 (default on some kernels) —
|
||||
// IPv4 candidates silently fail. IPv6 host
|
||||
// candidates are skipped for now; they need a
|
||||
// dedicated IPv6 socket alongside the v4 one
|
||||
// (like WebRTC's dual-socket approach).
|
||||
let bind: SocketAddr = "0.0.0.0:0".parse().unwrap();
|
||||
let fresh = wzp_transport::create_endpoint(bind, None)?;
|
||||
tracing::info!(
|
||||
local_addr = ?fresh.local_addr().ok(),
|
||||
candidates = ?peer_candidates.dial_order(),
|
||||
"dual_path: D-role fresh endpoint up, dialing peer candidates"
|
||||
);
|
||||
fresh
|
||||
}
|
||||
};
|
||||
let ep_for_fut = ep.clone();
|
||||
let _v6_ep_for_dial = ipv6_endpoint.clone();
|
||||
let dial_order = peer_candidates.smart_dial_order(own_reflexive.as_ref());
|
||||
let sni = call_sni.clone();
|
||||
let diags = diags_collector.clone();
|
||||
direct_fut = Box::pin(async move {
|
||||
if dial_order.is_empty() {
|
||||
// No candidates — the race reduces to
|
||||
// relay-only. Surface a stable error so the
|
||||
// outer select falls through to relay_fut
|
||||
// without a spurious "direct failed" warning.
|
||||
// Use a pending future that never resolves so
|
||||
// the select's "other side wins" branch is
|
||||
// the natural outcome.
|
||||
std::future::pending::<anyhow::Result<QuinnTransport>>().await
|
||||
} else {
|
||||
// Fan out N parallel dials via JoinSet. First
|
||||
// `Ok` wins; `Err` from a single candidate is
|
||||
// not fatal — we wait for the others. Only
|
||||
// when ALL have failed do we return Err.
|
||||
let mut set = tokio::task::JoinSet::new();
|
||||
for (idx, candidate) in dial_order.iter().enumerate() {
|
||||
// Phase 7: route each candidate to the
|
||||
// endpoint matching its address family.
|
||||
let candidate = *candidate;
|
||||
// Phase 7: IPv6 dials temporarily disabled.
|
||||
// IPv6 QUIC handshakes succeed but the
|
||||
// connection dies immediately on datagram
|
||||
// send ("connection lost"). Root cause is
|
||||
// likely router-level IPv6 UDP filtering.
|
||||
// Re-enable once IPv6 datagram delivery is
|
||||
// verified on target networks.
|
||||
if candidate.is_ipv6() {
|
||||
tracing::info!(
|
||||
%candidate,
|
||||
candidate_idx = idx,
|
||||
"dual_path: skipping IPv6 candidate (disabled)"
|
||||
);
|
||||
if let Ok(mut d) = diags.lock() {
|
||||
d.push(CandidateDiag {
|
||||
index: idx,
|
||||
addr: candidate.to_string(),
|
||||
result: "skipped:ipv6".into(),
|
||||
elapsed_ms: None,
|
||||
});
|
||||
}
|
||||
continue;
|
||||
}
|
||||
let ep = ep_for_fut.clone();
|
||||
let client_cfg = wzp_transport::client_config();
|
||||
let sni = sni.clone();
|
||||
let diags_inner = diags.clone();
|
||||
set.spawn(async move {
|
||||
let start = std::time::Instant::now();
|
||||
tracing::info!(
|
||||
%candidate,
|
||||
candidate_idx = idx,
|
||||
"dual_path: dialing candidate"
|
||||
);
|
||||
let result =
|
||||
wzp_transport::connect(&ep, candidate, &sni, client_cfg).await;
|
||||
let elapsed = start.elapsed().as_millis() as u32;
|
||||
let diag_result = match &result {
|
||||
Ok(_) => "ok".to_string(),
|
||||
Err(e) => format!("error:{e}"),
|
||||
};
|
||||
if let Ok(mut d) = diags_inner.lock() {
|
||||
d.push(CandidateDiag {
|
||||
index: idx,
|
||||
addr: candidate.to_string(),
|
||||
result: diag_result,
|
||||
elapsed_ms: Some(elapsed),
|
||||
});
|
||||
}
|
||||
(idx, candidate, result)
|
||||
});
|
||||
}
|
||||
let mut last_err: Option<String> = None;
|
||||
while let Some(join_res) = set.join_next().await {
|
||||
let (idx, candidate, dial_res) = match join_res {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
last_err = Some(format!("join {e}"));
|
||||
continue;
|
||||
}
|
||||
};
|
||||
match dial_res {
|
||||
Ok(conn) => {
|
||||
tracing::info!(
|
||||
%candidate,
|
||||
candidate_idx = idx,
|
||||
remote = %conn.remote_address(),
|
||||
stable_id = conn.stable_id(),
|
||||
"dual_path: direct dial succeeded on candidate"
|
||||
);
|
||||
// Abort the remaining in-flight
|
||||
// dials so they don't complete
|
||||
// and leak QUIC sessions.
|
||||
set.abort_all();
|
||||
return Ok(QuinnTransport::new(conn));
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::info!(
|
||||
%candidate,
|
||||
candidate_idx = idx,
|
||||
error = %e,
|
||||
"dual_path: direct dial failed, trying others"
|
||||
);
|
||||
last_err = Some(format!("candidate {candidate}: {e}"));
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(anyhow::anyhow!(
|
||||
"all {} direct candidates failed; last: {}",
|
||||
dial_order.len(),
|
||||
last_err.unwrap_or_else(|| "n/a".into())
|
||||
))
|
||||
}
|
||||
});
|
||||
direct_ep = ep;
|
||||
}
|
||||
}
|
||||
|
||||
// Relay path: classic dial to the relay's media room. Phase 5:
|
||||
// reuse the shared endpoint here too so MikroTik-style NATs
|
||||
// keep a stable external port across all flows from this
|
||||
// client. Falls back to a fresh endpoint when not shared.
|
||||
let relay_ep = match shared_endpoint.clone() {
|
||||
Some(ep) => ep,
|
||||
None => {
|
||||
let relay_bind: SocketAddr = "[::]:0".parse().unwrap();
|
||||
wzp_transport::create_endpoint(relay_bind, None)?
|
||||
}
|
||||
};
|
||||
let relay_ep_for_fut = relay_ep.clone();
|
||||
let relay_client_cfg = wzp_transport::client_config();
|
||||
let relay_sni = room_sni.clone();
|
||||
// Phase 5.5 direct-path head-start: hold the relay dial for
|
||||
// 500ms before attempting it. On same-LAN cone-NAT pairs the
|
||||
// direct dial finishes in ~30-100ms, so giving direct a 500ms
|
||||
// head start means direct reliably wins when it's going to
|
||||
// work at all. The worst case adds 500ms to the fall-back-
|
||||
// to-relay scenario, which is imperceptible for users on
|
||||
// setups where direct isn't available anyway.
|
||||
//
|
||||
// Prior behavior (immediate race) caused the relay to win
|
||||
// ~105ms races on a MikroTik LAN because:
|
||||
// - Acceptor role's direct_fut = accept() can only fire
|
||||
// when the peer has completed its outbound LAN dial
|
||||
// - Dialer role's parallel LAN dials need the peer's
|
||||
// CallSetup processed + the race started on the other
|
||||
// side before they can reach us
|
||||
// - Meanwhile relay_fut is a plain dial that completes in
|
||||
// whatever the client→relay RTT is (often <100ms)
|
||||
//
|
||||
// The 500ms head start is the minimum that empirically makes
|
||||
// same-LAN direct reliably beat relay, without penalizing
|
||||
// users who genuinely need the relay path.
|
||||
const DIRECT_HEAD_START: Duration = Duration::from_millis(500);
|
||||
let relay_fut = async move {
|
||||
tokio::time::sleep(DIRECT_HEAD_START).await;
|
||||
let conn =
|
||||
wzp_transport::connect(&relay_ep_for_fut, relay_addr, &relay_sni, relay_client_cfg)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("relay dial: {e}"))?;
|
||||
Ok::<_, anyhow::Error>(QuinnTransport::new(conn))
|
||||
};
|
||||
|
||||
// Phase 6: run both paths concurrently via tokio::spawn and
|
||||
// collect BOTH results. The old tokio::select! approach dropped
|
||||
// the loser, which meant the connect command couldn't negotiate
|
||||
// with the peer — it had to commit to whichever path won locally.
|
||||
//
|
||||
// Now we spawn both as tasks, wait for the first to complete
|
||||
// (that determines `local_winner`), then give the loser a short
|
||||
// grace period to also complete. The connect command gets a
|
||||
// RaceResult with both transports (when available) and uses the
|
||||
// Phase 6 MediaPathReport exchange to decide which one to
|
||||
// actually use for media.
|
||||
let smart_order = peer_candidates.smart_dial_order(own_reflexive.as_ref());
|
||||
tracing::info!(
|
||||
?role,
|
||||
raw_candidates = ?peer_candidates.dial_order(),
|
||||
filtered_candidates = ?smart_order,
|
||||
?own_reflexive,
|
||||
%relay_addr,
|
||||
"dual_path: racing direct vs relay"
|
||||
);
|
||||
|
||||
let mut direct_task = tokio::spawn(tokio::time::timeout(Duration::from_secs(4), direct_fut));
|
||||
let mut relay_task = tokio::spawn(async move {
|
||||
// Keep the 500ms head start so direct has a chance
|
||||
tokio::time::sleep(Duration::from_millis(500)).await;
|
||||
tokio::time::timeout(Duration::from_secs(5), relay_fut).await
|
||||
});
|
||||
|
||||
// Wait for the first one to complete. This tells us the
|
||||
// local_winner — but we DON'T commit to it yet. Phase 6
|
||||
// negotiation decides the actual path.
|
||||
let (mut direct_result, mut relay_result): (
|
||||
Option<anyhow::Result<QuinnTransport>>,
|
||||
Option<anyhow::Result<QuinnTransport>>,
|
||||
) = (None, None);
|
||||
|
||||
let local_winner;
|
||||
|
||||
tokio::select! {
|
||||
biased;
|
||||
d = &mut direct_task => {
|
||||
match d {
|
||||
Ok(Ok(Ok(t))) => {
|
||||
tracing::info!("dual_path: direct completed first");
|
||||
direct_result = Some(Ok(t));
|
||||
local_winner = WinningPath::Direct;
|
||||
}
|
||||
Ok(Ok(Err(e))) => {
|
||||
tracing::warn!(error = %e, "dual_path: direct failed");
|
||||
direct_result = Some(Err(anyhow::anyhow!("{e}")));
|
||||
local_winner = WinningPath::Relay; // direct failed → relay is our only hope
|
||||
}
|
||||
Ok(Err(_)) => {
|
||||
tracing::warn!("dual_path: direct timed out (4s)");
|
||||
direct_result = Some(Err(anyhow::anyhow!("direct timeout")));
|
||||
local_winner = WinningPath::Relay;
|
||||
// Record timeout diag for candidates that were
|
||||
// still in-flight when the timeout fired.
|
||||
if let Ok(mut d) = diags_collector.lock() {
|
||||
let recorded_indices: std::collections::HashSet<usize> =
|
||||
d.iter().map(|diag| diag.index).collect();
|
||||
for (idx, addr) in smart_order.iter().enumerate() {
|
||||
if !recorded_indices.contains(&idx) {
|
||||
d.push(CandidateDiag {
|
||||
index: idx,
|
||||
addr: addr.to_string(),
|
||||
result: "timeout:4s".into(),
|
||||
elapsed_ms: Some(4000),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "dual_path: direct task panicked");
|
||||
direct_result = Some(Err(anyhow::anyhow!("direct task panic")));
|
||||
local_winner = WinningPath::Relay;
|
||||
}
|
||||
}
|
||||
}
|
||||
r = &mut relay_task => {
|
||||
match r {
|
||||
Ok(Ok(Ok(t))) => {
|
||||
tracing::info!("dual_path: relay completed first");
|
||||
relay_result = Some(Ok(t));
|
||||
local_winner = WinningPath::Relay;
|
||||
}
|
||||
Ok(Ok(Err(e))) => {
|
||||
tracing::warn!(error = %e, "dual_path: relay failed");
|
||||
relay_result = Some(Err(anyhow::anyhow!("{e}")));
|
||||
local_winner = WinningPath::Direct;
|
||||
}
|
||||
Ok(Err(_)) => {
|
||||
tracing::warn!("dual_path: relay timed out");
|
||||
relay_result = Some(Err(anyhow::anyhow!("relay timeout")));
|
||||
local_winner = WinningPath::Direct;
|
||||
}
|
||||
Err(e) => {
|
||||
relay_result = Some(Err(anyhow::anyhow!("relay task panic: {e}")));
|
||||
local_winner = WinningPath::Direct;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Give the loser a short grace period (1s) to also complete.
|
||||
// If it does, we have both transports for Phase 6 negotiation.
|
||||
// If it doesn't, we still proceed with just the winner.
|
||||
if direct_result.is_none() {
|
||||
match tokio::time::timeout(Duration::from_secs(1), direct_task).await {
|
||||
Ok(Ok(Ok(Ok(t)))) => {
|
||||
direct_result = Some(Ok(t));
|
||||
}
|
||||
Ok(Ok(Ok(Err(e)))) => {
|
||||
direct_result = Some(Err(anyhow::anyhow!("{e}")));
|
||||
}
|
||||
_ => {
|
||||
direct_result = Some(Err(anyhow::anyhow!("direct: no result in grace period")));
|
||||
// Fill timeout diags for candidates that never reported.
|
||||
if let Ok(mut d) = diags_collector.lock() {
|
||||
let recorded: std::collections::HashSet<usize> =
|
||||
d.iter().map(|diag| diag.index).collect();
|
||||
for (idx, addr) in smart_order.iter().enumerate() {
|
||||
if !recorded.contains(&idx) {
|
||||
d.push(CandidateDiag {
|
||||
index: idx,
|
||||
addr: addr.to_string(),
|
||||
result: "timeout:grace".into(),
|
||||
elapsed_ms: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if relay_result.is_none() {
|
||||
match tokio::time::timeout(Duration::from_secs(1), relay_task).await {
|
||||
Ok(Ok(Ok(Ok(t)))) => {
|
||||
relay_result = Some(Ok(t));
|
||||
}
|
||||
Ok(Ok(Ok(Err(e)))) => {
|
||||
relay_result = Some(Err(anyhow::anyhow!("{e}")));
|
||||
}
|
||||
_ => {
|
||||
relay_result = Some(Err(anyhow::anyhow!("relay: no result in grace period")));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let direct_ok = direct_result.as_ref().map(|r| r.is_ok()).unwrap_or(false);
|
||||
let relay_ok = relay_result.as_ref().map(|r| r.is_ok()).unwrap_or(false);
|
||||
|
||||
tracing::info!(
|
||||
?local_winner,
|
||||
direct_ok,
|
||||
relay_ok,
|
||||
"dual_path: race finished, both results collected for Phase 6 negotiation"
|
||||
);
|
||||
|
||||
if !direct_ok && !relay_ok {
|
||||
return Err(anyhow::anyhow!(
|
||||
"both paths failed: no media transport available"
|
||||
));
|
||||
}
|
||||
|
||||
let _ = (direct_ep, relay_ep, ipv6_endpoint);
|
||||
|
||||
let candidate_diags = diags_collector
|
||||
.lock()
|
||||
.map(|d| d.clone())
|
||||
.unwrap_or_default();
|
||||
|
||||
Ok(RaceResult {
|
||||
direct_transport: direct_result.and_then(|r| r.ok()).map(|t| Arc::new(t)),
|
||||
relay_transport: relay_result.and_then(|r| r.ok()).map(|t| Arc::new(t)),
|
||||
local_winner,
|
||||
candidate_diags,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn peer_candidates_dial_order_all_types() {
|
||||
let candidates = PeerCandidates {
|
||||
reflexive: Some("203.0.113.5:4433".parse().unwrap()),
|
||||
local: vec![
|
||||
"192.168.1.10:4433".parse().unwrap(),
|
||||
"10.0.0.5:4433".parse().unwrap(),
|
||||
],
|
||||
mapped: Some("198.51.100.42:12345".parse().unwrap()),
|
||||
};
|
||||
|
||||
let order = candidates.dial_order();
|
||||
// Order: local first, then mapped, then reflexive
|
||||
assert_eq!(order.len(), 4);
|
||||
assert_eq!(order[0], "192.168.1.10:4433".parse::<SocketAddr>().unwrap());
|
||||
assert_eq!(order[1], "10.0.0.5:4433".parse::<SocketAddr>().unwrap());
|
||||
assert_eq!(
|
||||
order[2],
|
||||
"198.51.100.42:12345".parse::<SocketAddr>().unwrap()
|
||||
);
|
||||
assert_eq!(order[3], "203.0.113.5:4433".parse::<SocketAddr>().unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn peer_candidates_dial_order_no_mapped() {
|
||||
let candidates = PeerCandidates {
|
||||
reflexive: Some("203.0.113.5:4433".parse().unwrap()),
|
||||
local: vec!["192.168.1.10:4433".parse().unwrap()],
|
||||
mapped: None,
|
||||
};
|
||||
|
||||
let order = candidates.dial_order();
|
||||
assert_eq!(order.len(), 2);
|
||||
assert_eq!(order[0], "192.168.1.10:4433".parse::<SocketAddr>().unwrap());
|
||||
assert_eq!(order[1], "203.0.113.5:4433".parse::<SocketAddr>().unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn peer_candidates_dial_order_only_mapped() {
|
||||
let candidates = PeerCandidates {
|
||||
reflexive: None,
|
||||
local: vec![],
|
||||
mapped: Some("198.51.100.42:12345".parse().unwrap()),
|
||||
};
|
||||
|
||||
let order = candidates.dial_order();
|
||||
assert_eq!(order.len(), 1);
|
||||
assert_eq!(
|
||||
order[0],
|
||||
"198.51.100.42:12345".parse::<SocketAddr>().unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn peer_candidates_dial_order_dedup_mapped_equals_reflexive() {
|
||||
let addr: SocketAddr = "203.0.113.5:4433".parse().unwrap();
|
||||
let candidates = PeerCandidates {
|
||||
reflexive: Some(addr),
|
||||
local: vec![],
|
||||
mapped: Some(addr), // same as reflexive
|
||||
};
|
||||
|
||||
let order = candidates.dial_order();
|
||||
// Should be deduped to 1
|
||||
assert_eq!(order.len(), 1);
|
||||
assert_eq!(order[0], addr);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn peer_candidates_dial_order_dedup_mapped_in_local() {
|
||||
let addr: SocketAddr = "192.168.1.10:4433".parse().unwrap();
|
||||
let candidates = PeerCandidates {
|
||||
reflexive: None,
|
||||
local: vec![addr],
|
||||
mapped: Some(addr), // same as a local addr
|
||||
};
|
||||
|
||||
let order = candidates.dial_order();
|
||||
assert_eq!(order.len(), 1);
|
||||
assert_eq!(order[0], addr);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn peer_candidates_is_empty() {
|
||||
let empty = PeerCandidates::default();
|
||||
assert!(empty.is_empty());
|
||||
|
||||
let with_reflexive = PeerCandidates {
|
||||
reflexive: Some("1.2.3.4:5".parse().unwrap()),
|
||||
..Default::default()
|
||||
};
|
||||
assert!(!with_reflexive.is_empty());
|
||||
|
||||
let with_local = PeerCandidates {
|
||||
local: vec!["10.0.0.1:5".parse().unwrap()],
|
||||
..Default::default()
|
||||
};
|
||||
assert!(!with_local.is_empty());
|
||||
|
||||
let with_mapped = PeerCandidates {
|
||||
mapped: Some("1.2.3.4:5".parse().unwrap()),
|
||||
..Default::default()
|
||||
};
|
||||
assert!(!with_mapped.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn peer_candidates_empty_dial_order() {
|
||||
let empty = PeerCandidates::default();
|
||||
assert!(empty.dial_order().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn winning_path_debug() {
|
||||
// Just verify Debug impl doesn't panic
|
||||
let _ = format!("{:?}", WinningPath::Direct);
|
||||
let _ = format!("{:?}", WinningPath::Relay);
|
||||
}
|
||||
|
||||
// ── smart_dial_order tests ─────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn smart_dial_order_same_network_includes_lan() {
|
||||
let candidates = PeerCandidates {
|
||||
reflexive: Some("203.0.113.5:4433".parse().unwrap()),
|
||||
local: vec![
|
||||
"192.168.1.10:4433".parse().unwrap(),
|
||||
"10.0.0.5:4433".parse().unwrap(),
|
||||
],
|
||||
mapped: None,
|
||||
};
|
||||
let own: SocketAddr = "203.0.113.5:12345".parse().unwrap();
|
||||
let order = candidates.smart_dial_order(Some(&own));
|
||||
// Same public IP → LAN candidates included
|
||||
assert!(order.contains(&"192.168.1.10:4433".parse().unwrap()));
|
||||
assert!(order.contains(&"10.0.0.5:4433".parse().unwrap()));
|
||||
assert!(order.contains(&"203.0.113.5:4433".parse().unwrap()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn smart_dial_order_different_network_strips_lan() {
|
||||
let candidates = PeerCandidates {
|
||||
reflexive: Some("150.228.49.65:4433".parse().unwrap()),
|
||||
local: vec![
|
||||
"172.16.81.126:4433".parse().unwrap(),
|
||||
"10.0.0.5:4433".parse().unwrap(),
|
||||
],
|
||||
mapped: None,
|
||||
};
|
||||
// Different public IP → LAN candidates stripped
|
||||
let own: SocketAddr = "185.115.4.212:12345".parse().unwrap();
|
||||
let order = candidates.smart_dial_order(Some(&own));
|
||||
assert!(!order.contains(&"172.16.81.126:4433".parse().unwrap()));
|
||||
assert!(!order.contains(&"10.0.0.5:4433".parse().unwrap()));
|
||||
// Reflexive still included
|
||||
assert!(order.contains(&"150.228.49.65:4433".parse().unwrap()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn smart_dial_order_strips_ipv6() {
|
||||
let candidates = PeerCandidates {
|
||||
reflexive: Some("150.228.49.65:4433".parse().unwrap()),
|
||||
local: vec![
|
||||
"[2a0d:3344:692c::1]:4433".parse().unwrap(),
|
||||
"172.16.81.126:4433".parse().unwrap(),
|
||||
],
|
||||
mapped: None,
|
||||
};
|
||||
// Same network, but IPv6 should be stripped
|
||||
let own: SocketAddr = "150.228.49.65:5555".parse().unwrap();
|
||||
let order = candidates.smart_dial_order(Some(&own));
|
||||
assert!(!order.iter().any(|a| a.is_ipv6()));
|
||||
assert!(order.contains(&"172.16.81.126:4433".parse().unwrap()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn smart_dial_order_no_own_reflexive_strips_lan() {
|
||||
let candidates = PeerCandidates {
|
||||
reflexive: Some("150.228.49.65:4433".parse().unwrap()),
|
||||
local: vec!["172.16.81.126:4433".parse().unwrap()],
|
||||
mapped: Some("198.51.100.42:12345".parse().unwrap()),
|
||||
};
|
||||
// No own reflexive → can't determine same network → strip LAN
|
||||
let order = candidates.smart_dial_order(None);
|
||||
assert!(!order.contains(&"172.16.81.126:4433".parse().unwrap()));
|
||||
assert!(order.contains(&"198.51.100.42:12345".parse().unwrap()));
|
||||
assert!(order.contains(&"150.228.49.65:4433".parse().unwrap()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn smart_dial_order_mapped_always_included() {
|
||||
let candidates = PeerCandidates {
|
||||
reflexive: Some("150.228.49.65:4433".parse().unwrap()),
|
||||
local: vec![],
|
||||
mapped: Some("198.51.100.42:12345".parse().unwrap()),
|
||||
};
|
||||
let own: SocketAddr = "185.115.4.212:12345".parse().unwrap();
|
||||
let order = candidates.smart_dial_order(Some(&own));
|
||||
assert_eq!(order.len(), 2); // mapped + reflexive
|
||||
assert!(order.contains(&"198.51.100.42:12345".parse().unwrap()));
|
||||
assert!(order.contains(&"150.228.49.65:4433".parse().unwrap()));
|
||||
}
|
||||
}
|
||||
@@ -166,7 +166,7 @@ pub async fn run_echo_test(
|
||||
match tokio::time::timeout(Duration::from_millis(2), transport.recv_media()).await {
|
||||
Ok(Ok(Some(pkt))) => {
|
||||
total_packets_received += 1;
|
||||
let is_repair = pkt.header.is_repair;
|
||||
let is_repair = pkt.header.is_repair();
|
||||
decoder.ingest(pkt);
|
||||
if !is_repair {
|
||||
if let Some(n) = decoder.decode_next(&mut pcm_buf) {
|
||||
@@ -184,7 +184,8 @@ pub async fn run_echo_test(
|
||||
let time_offset = start.elapsed().as_secs_f64();
|
||||
|
||||
// Compare sent vs received for this window
|
||||
let sent_start = (window_idx as u64 * frames_per_window * FRAME_SAMPLES as u64) as usize;
|
||||
let sent_start =
|
||||
(window_idx as u64 * frames_per_window * FRAME_SAMPLES as u64) as usize;
|
||||
let sent_end = sent_start + (window_frames_sent as usize * FRAME_SAMPLES);
|
||||
let sent_window = if sent_end <= sent_pcm.len() {
|
||||
&sent_pcm[sent_start..sent_end]
|
||||
@@ -192,7 +193,9 @@ pub async fn run_echo_test(
|
||||
&sent_pcm[sent_start..]
|
||||
};
|
||||
|
||||
let recv_start = recv_pcm.len().saturating_sub(window_frames_received as usize * FRAME_SAMPLES);
|
||||
let recv_start = recv_pcm
|
||||
.len()
|
||||
.saturating_sub(window_frames_received as usize * FRAME_SAMPLES);
|
||||
let recv_window = &recv_pcm[recv_start..];
|
||||
|
||||
let peak = recv_window.iter().map(|s| s.abs()).max().unwrap_or(0);
|
||||
@@ -256,7 +259,7 @@ pub async fn run_echo_test(
|
||||
match tokio::time::timeout(Duration::from_millis(100), transport.recv_media()).await {
|
||||
Ok(Ok(Some(pkt))) => {
|
||||
total_packets_received += 1;
|
||||
let is_repair = pkt.header.is_repair;
|
||||
let is_repair = pkt.header.is_repair();
|
||||
decoder.ingest(pkt);
|
||||
if !is_repair {
|
||||
decoder.decode_next(&mut pcm_buf);
|
||||
@@ -310,8 +313,14 @@ pub fn print_report(result: &EchoTestResult) {
|
||||
let status = if w.is_silent { " !" } else { " " };
|
||||
println!(
|
||||
"│ {:>3}{} │ {:>5.1}s │ {:>4} │ {:>4} │ {:>5.1}% │ {:>5.1} │ {:.3} │",
|
||||
w.index, status, w.time_offset_secs, w.frames_sent, w.frames_received,
|
||||
w.loss_pct, w.snr_db, w.correlation
|
||||
w.index,
|
||||
status,
|
||||
w.time_offset_secs,
|
||||
w.frames_sent,
|
||||
w.frames_received,
|
||||
w.loss_pct,
|
||||
w.snr_db,
|
||||
w.correlation
|
||||
);
|
||||
}
|
||||
println!("└───────┴─────────┴──────┴──────┴─────────┴───────┴───────┘");
|
||||
@@ -321,18 +330,28 @@ pub fn print_report(result: &EchoTestResult) {
|
||||
let first_half: Vec<_> = result.windows[..result.windows.len() / 2].to_vec();
|
||||
let second_half: Vec<_> = result.windows[result.windows.len() / 2..].to_vec();
|
||||
|
||||
let avg_loss_first = first_half.iter().map(|w| w.loss_pct).sum::<f32>() / first_half.len() as f32;
|
||||
let avg_loss_second = second_half.iter().map(|w| w.loss_pct).sum::<f32>() / second_half.len() as f32;
|
||||
let avg_corr_first = first_half.iter().map(|w| w.correlation).sum::<f32>() / first_half.len() as f32;
|
||||
let avg_corr_second = second_half.iter().map(|w| w.correlation).sum::<f32>() / second_half.len() as f32;
|
||||
let avg_loss_first =
|
||||
first_half.iter().map(|w| w.loss_pct).sum::<f32>() / first_half.len() as f32;
|
||||
let avg_loss_second =
|
||||
second_half.iter().map(|w| w.loss_pct).sum::<f32>() / second_half.len() as f32;
|
||||
let avg_corr_first =
|
||||
first_half.iter().map(|w| w.correlation).sum::<f32>() / first_half.len() as f32;
|
||||
let avg_corr_second =
|
||||
second_half.iter().map(|w| w.correlation).sum::<f32>() / second_half.len() as f32;
|
||||
|
||||
println!();
|
||||
if avg_loss_second > avg_loss_first + 5.0 {
|
||||
println!("WARNING: Quality degradation detected!");
|
||||
println!(" Loss increased from {:.1}% to {:.1}% over time", avg_loss_first, avg_loss_second);
|
||||
println!(
|
||||
" Loss increased from {:.1}% to {:.1}% over time",
|
||||
avg_loss_first, avg_loss_second
|
||||
);
|
||||
}
|
||||
if avg_corr_second < avg_corr_first - 0.1 {
|
||||
println!("WARNING: Signal correlation dropped from {:.3} to {:.3}", avg_corr_first, avg_corr_second);
|
||||
println!(
|
||||
"WARNING: Signal correlation dropped from {:.3} to {:.3}",
|
||||
avg_corr_first, avg_corr_second
|
||||
);
|
||||
}
|
||||
if avg_loss_second <= avg_loss_first + 5.0 && avg_corr_second >= avg_corr_first - 0.1 {
|
||||
println!("Quality is STABLE over the test duration.");
|
||||
|
||||
213
crates/wzp-client/src/encrypted_transport.rs
Normal file
213
crates/wzp-client/src/encrypted_transport.rs
Normal file
@@ -0,0 +1,213 @@
|
||||
//! `EncryptingTransport` — wraps any `MediaTransport` with a `CryptoSession`.
|
||||
//!
|
||||
//! All outbound `send_media` calls encrypt the payload before handing off to
|
||||
//! the inner transport; all inbound `recv_media` calls decrypt after receiving.
|
||||
//! Signal, quality, and close are forwarded unchanged.
|
||||
//!
|
||||
//! The quality report travels in plaintext so the relay can make QoS decisions
|
||||
//! without being able to decrypt media content.
|
||||
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use bytes::Bytes;
|
||||
use wzp_proto::{
|
||||
CryptoSession, MediaHeader, MediaPacket, MediaTransport, PathQuality, SignalMessage,
|
||||
TransportError,
|
||||
};
|
||||
|
||||
/// Wraps a `MediaTransport` and applies AEAD encryption/decryption to media payloads.
|
||||
pub struct EncryptingTransport {
|
||||
inner: Arc<dyn MediaTransport>,
|
||||
session: Mutex<Box<dyn CryptoSession>>,
|
||||
}
|
||||
|
||||
impl EncryptingTransport {
|
||||
pub fn new(inner: Arc<dyn MediaTransport>, session: Box<dyn CryptoSession>) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
session: Mutex::new(session),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl MediaTransport for EncryptingTransport {
|
||||
async fn send_media(&self, packet: &MediaPacket) -> Result<(), TransportError> {
|
||||
let mut header_bytes = Vec::with_capacity(MediaHeader::WIRE_SIZE);
|
||||
packet.header.write_to(&mut header_bytes);
|
||||
|
||||
let mut ciphertext = Vec::new();
|
||||
self.session
|
||||
.lock()
|
||||
.unwrap()
|
||||
.encrypt(&header_bytes, &packet.payload, &mut ciphertext)
|
||||
.map_err(|e| TransportError::Internal(format!("encrypt: {e}")))?;
|
||||
|
||||
let encrypted = MediaPacket {
|
||||
header: packet.header,
|
||||
payload: Bytes::from(ciphertext),
|
||||
quality_report: packet.quality_report.clone(),
|
||||
};
|
||||
self.inner.send_media(&encrypted).await
|
||||
}
|
||||
|
||||
async fn recv_media(&self) -> Result<Option<MediaPacket>, TransportError> {
|
||||
let packet = match self.inner.recv_media().await? {
|
||||
Some(p) => p,
|
||||
None => return Ok(None),
|
||||
};
|
||||
|
||||
let mut header_bytes = Vec::with_capacity(MediaHeader::WIRE_SIZE);
|
||||
packet.header.write_to(&mut header_bytes);
|
||||
|
||||
let mut plaintext = Vec::new();
|
||||
self.session
|
||||
.lock()
|
||||
.unwrap()
|
||||
.decrypt(&header_bytes, &packet.payload, &mut plaintext)
|
||||
.map_err(|e| TransportError::Internal(format!("decrypt: {e}")))?;
|
||||
|
||||
Ok(Some(MediaPacket {
|
||||
header: packet.header,
|
||||
payload: Bytes::from(plaintext),
|
||||
quality_report: packet.quality_report,
|
||||
}))
|
||||
}
|
||||
|
||||
async fn send_signal(&self, msg: &SignalMessage) -> Result<(), TransportError> {
|
||||
self.inner.send_signal(msg).await
|
||||
}
|
||||
|
||||
async fn recv_signal(&self) -> Result<Option<SignalMessage>, TransportError> {
|
||||
self.inner.recv_signal().await
|
||||
}
|
||||
|
||||
fn path_quality(&self) -> PathQuality {
|
||||
self.inner.path_quality()
|
||||
}
|
||||
|
||||
async fn close(&self) -> Result<(), TransportError> {
|
||||
self.inner.close().await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::sync::Mutex as StdMutex;
|
||||
use wzp_crypto::ChaChaSession;
|
||||
use wzp_proto::{CodecId, MediaType};
|
||||
|
||||
struct LoopbackTransport {
|
||||
sent: StdMutex<Vec<MediaPacket>>,
|
||||
}
|
||||
|
||||
impl LoopbackTransport {
|
||||
fn new() -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
sent: StdMutex::new(Vec::new()),
|
||||
})
|
||||
}
|
||||
fn take_sent(&self) -> Vec<MediaPacket> {
|
||||
self.sent.lock().unwrap().drain(..).collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl MediaTransport for LoopbackTransport {
|
||||
async fn send_media(&self, packet: &MediaPacket) -> Result<(), TransportError> {
|
||||
self.sent.lock().unwrap().push(packet.clone());
|
||||
Ok(())
|
||||
}
|
||||
async fn recv_media(&self) -> Result<Option<MediaPacket>, TransportError> {
|
||||
Ok(None)
|
||||
}
|
||||
async fn send_signal(&self, _msg: &SignalMessage) -> Result<(), TransportError> {
|
||||
Ok(())
|
||||
}
|
||||
async fn recv_signal(&self) -> Result<Option<SignalMessage>, TransportError> {
|
||||
Ok(None)
|
||||
}
|
||||
fn path_quality(&self) -> PathQuality {
|
||||
PathQuality::default()
|
||||
}
|
||||
async fn close(&self) -> Result<(), TransportError> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn make_header(seq: u32) -> MediaHeader {
|
||||
MediaHeader {
|
||||
version: 2,
|
||||
flags: 0,
|
||||
media_type: MediaType::Audio,
|
||||
codec_id: CodecId::Opus24k,
|
||||
stream_id: 0,
|
||||
fec_ratio: 0,
|
||||
seq,
|
||||
timestamp: seq * 20,
|
||||
fec_block: 0,
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn payload_is_encrypted_on_wire() {
|
||||
let key = [0x42u8; 32];
|
||||
let session: Box<dyn CryptoSession> = Box::new(ChaChaSession::new(key));
|
||||
let loopback = LoopbackTransport::new();
|
||||
let enc = EncryptingTransport::new(loopback.clone(), session);
|
||||
|
||||
let header = make_header(1);
|
||||
let plaintext = b"secret audio frame";
|
||||
let pkt = MediaPacket {
|
||||
header,
|
||||
payload: Bytes::from_static(plaintext),
|
||||
quality_report: None,
|
||||
};
|
||||
|
||||
enc.send_media(&pkt).await.unwrap();
|
||||
|
||||
let sent = loopback.take_sent();
|
||||
assert_eq!(sent.len(), 1);
|
||||
assert_eq!(sent[0].header, header, "header must be preserved");
|
||||
assert_ne!(
|
||||
sent[0].payload.as_ref(),
|
||||
plaintext.as_ref(),
|
||||
"plaintext must not appear on wire"
|
||||
);
|
||||
// Ciphertext is longer by exactly the AEAD tag (16 bytes)
|
||||
assert_eq!(sent[0].payload.len(), plaintext.len() + 16);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn encrypt_then_decrypt_roundtrip() {
|
||||
let key = [0x42u8; 32];
|
||||
let send_session: Box<dyn CryptoSession> = Box::new(ChaChaSession::new(key));
|
||||
let mut recv_session = ChaChaSession::new(key);
|
||||
|
||||
let loopback = LoopbackTransport::new();
|
||||
let enc = EncryptingTransport::new(loopback.clone(), send_session);
|
||||
|
||||
let header = make_header(5);
|
||||
let plaintext = b"hello encrypted world";
|
||||
let pkt = MediaPacket {
|
||||
header,
|
||||
payload: Bytes::from_static(plaintext),
|
||||
quality_report: None,
|
||||
};
|
||||
|
||||
enc.send_media(&pkt).await.unwrap();
|
||||
|
||||
let sent = loopback.take_sent();
|
||||
let wire_pkt = &sent[0];
|
||||
|
||||
let mut header_bytes = Vec::new();
|
||||
header.write_to(&mut header_bytes);
|
||||
let mut decrypted = Vec::new();
|
||||
recv_session
|
||||
.decrypt(&header_bytes, &wire_pkt.payload, &mut decrypted)
|
||||
.expect("decrypt should succeed with matching key");
|
||||
assert_eq!(&decrypted[..], plaintext);
|
||||
}
|
||||
}
|
||||
@@ -96,20 +96,54 @@ pub fn signal_to_call_type(signal: &SignalMessage) -> CallSignalType {
|
||||
SignalMessage::Hangup { .. } => CallSignalType::Hangup,
|
||||
SignalMessage::Rekey { .. } => CallSignalType::Offer, // reuse
|
||||
SignalMessage::QualityUpdate { .. } => CallSignalType::Offer, // reuse
|
||||
SignalMessage::LossRecoveryUpdate { .. } => CallSignalType::Offer, // reuse (telemetry)
|
||||
SignalMessage::Ping { .. } | SignalMessage::Pong { .. } => CallSignalType::Offer,
|
||||
SignalMessage::AuthToken { .. } => CallSignalType::Offer,
|
||||
SignalMessage::Hold => CallSignalType::Hold,
|
||||
SignalMessage::Unhold => CallSignalType::Unhold,
|
||||
SignalMessage::Mute => CallSignalType::Mute,
|
||||
SignalMessage::Unmute => CallSignalType::Unmute,
|
||||
SignalMessage::Hold { .. } => CallSignalType::Hold,
|
||||
SignalMessage::Unhold { .. } => CallSignalType::Unhold,
|
||||
SignalMessage::Mute { .. } => CallSignalType::Mute,
|
||||
SignalMessage::Unmute { .. } => CallSignalType::Unmute,
|
||||
SignalMessage::Transfer { .. } => CallSignalType::Transfer,
|
||||
SignalMessage::TransferAck => CallSignalType::Offer, // reuse
|
||||
SignalMessage::TransferAck { .. } => CallSignalType::Offer, // reuse
|
||||
SignalMessage::PresenceUpdate { .. } => CallSignalType::Offer, // reuse
|
||||
SignalMessage::RouteQuery { .. } => CallSignalType::Offer, // reuse
|
||||
SignalMessage::TransportFeedback { .. } => CallSignalType::Offer, // reuse (BWE)
|
||||
SignalMessage::RouteResponse { .. } => CallSignalType::Offer, // reuse
|
||||
SignalMessage::SessionForward { .. } => CallSignalType::Offer, // reuse
|
||||
SignalMessage::SessionForwardAck { .. } => CallSignalType::Offer, // reuse
|
||||
SignalMessage::RoomUpdate { .. } => CallSignalType::Offer, // reuse
|
||||
SignalMessage::FederationHello { .. }
|
||||
| SignalMessage::GlobalRoomActive { .. }
|
||||
| SignalMessage::GlobalRoomInactive { .. } => CallSignalType::Offer, // relay-only
|
||||
SignalMessage::DirectCallOffer { .. } => CallSignalType::Offer,
|
||||
SignalMessage::DirectCallAnswer { .. } => CallSignalType::Answer,
|
||||
SignalMessage::CallSetup { .. } => CallSignalType::Offer, // relay-only
|
||||
SignalMessage::CallRinging { .. } => CallSignalType::Ringing,
|
||||
SignalMessage::RegisterPresence { .. } | SignalMessage::RegisterPresenceAck { .. } => {
|
||||
CallSignalType::Offer
|
||||
} // relay-only
|
||||
// NAT reflection is a client↔relay control exchange that
|
||||
// never crosses the featherChat bridge — if it ever reaches
|
||||
// this mapper something is wrong, but we still have to give
|
||||
// an answer. "Offer" is the generic catch-all.
|
||||
SignalMessage::Reflect | SignalMessage::ReflectResponse { .. } => CallSignalType::Offer, // control-plane
|
||||
// Phase 4 cross-relay forwarding envelope — strictly a
|
||||
// relay-to-relay message, never rides the featherChat
|
||||
// bridge. Catch-all mapping for completeness.
|
||||
SignalMessage::FederatedSignalForward { .. } => CallSignalType::Offer,
|
||||
SignalMessage::MediaPathReport { .. } => CallSignalType::Offer, // control-plane
|
||||
SignalMessage::CandidateUpdate { .. } => CallSignalType::IceCandidate, // mid-call re-gather
|
||||
SignalMessage::HardNatProbe { .. } => CallSignalType::IceCandidate, // hard NAT coordination
|
||||
SignalMessage::HardNatBirthdayStart { .. } => CallSignalType::IceCandidate, // birthday attack
|
||||
SignalMessage::UpgradeProposal { .. }
|
||||
| SignalMessage::UpgradeResponse { .. }
|
||||
| SignalMessage::UpgradeConfirm { .. }
|
||||
| SignalMessage::QualityCapability { .. } => CallSignalType::Offer, // quality negotiation
|
||||
SignalMessage::PresenceList { .. } => CallSignalType::Offer, // lobby presence
|
||||
SignalMessage::QualityDirective { .. } => CallSignalType::Offer, // relay-initiated
|
||||
SignalMessage::Nack { .. }
|
||||
| SignalMessage::PictureLossIndication { .. }
|
||||
| SignalMessage::SetPriorityMode { .. } => CallSignalType::Offer, // relay-initiated (video loss recovery)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -117,14 +151,19 @@ pub fn signal_to_call_type(signal: &SignalMessage) -> CallSignalType {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use wzp_proto::QualityProfile;
|
||||
use wzp_proto::default_signal_version;
|
||||
|
||||
#[test]
|
||||
fn payload_roundtrip() {
|
||||
let signal = SignalMessage::CallOffer {
|
||||
version: default_signal_version(),
|
||||
identity_pub: [1u8; 32],
|
||||
ephemeral_pub: [2u8; 32],
|
||||
signature: vec![3u8; 64],
|
||||
supported_profiles: vec![QualityProfile::GOOD],
|
||||
alias: None,
|
||||
protocol_version: 2,
|
||||
supported_versions: vec![2],
|
||||
};
|
||||
|
||||
let encoded = encode_call_payload(&signal, Some("relay.example.com:4433"), Some("myroom"));
|
||||
@@ -138,27 +177,52 @@ mod tests {
|
||||
#[test]
|
||||
fn signal_type_mapping() {
|
||||
let offer = SignalMessage::CallOffer {
|
||||
version: default_signal_version(),
|
||||
identity_pub: [0; 32],
|
||||
ephemeral_pub: [0; 32],
|
||||
signature: vec![],
|
||||
supported_profiles: vec![],
|
||||
alias: None,
|
||||
protocol_version: 2,
|
||||
supported_versions: vec![2],
|
||||
};
|
||||
assert!(matches!(signal_to_call_type(&offer), CallSignalType::Offer));
|
||||
|
||||
let hangup = SignalMessage::Hangup {
|
||||
version: default_signal_version(),
|
||||
reason: wzp_proto::HangupReason::Normal,
|
||||
call_id: None,
|
||||
};
|
||||
assert!(matches!(signal_to_call_type(&hangup), CallSignalType::Hangup));
|
||||
assert!(matches!(
|
||||
signal_to_call_type(&hangup),
|
||||
CallSignalType::Hangup
|
||||
));
|
||||
|
||||
assert!(matches!(signal_to_call_type(&SignalMessage::Hold), CallSignalType::Hold));
|
||||
assert!(matches!(signal_to_call_type(&SignalMessage::Unhold), CallSignalType::Unhold));
|
||||
assert!(matches!(signal_to_call_type(&SignalMessage::Mute), CallSignalType::Mute));
|
||||
assert!(matches!(signal_to_call_type(&SignalMessage::Unmute), CallSignalType::Unmute));
|
||||
assert!(matches!(
|
||||
signal_to_call_type(&SignalMessage::Hold { version: default_signal_version() }),
|
||||
CallSignalType::Hold
|
||||
));
|
||||
assert!(matches!(
|
||||
signal_to_call_type(&SignalMessage::Unhold { version: default_signal_version() }),
|
||||
CallSignalType::Unhold
|
||||
));
|
||||
assert!(matches!(
|
||||
signal_to_call_type(&SignalMessage::Mute { version: default_signal_version() }),
|
||||
CallSignalType::Mute
|
||||
));
|
||||
assert!(matches!(
|
||||
signal_to_call_type(&SignalMessage::Unmute { version: default_signal_version() }),
|
||||
CallSignalType::Unmute
|
||||
));
|
||||
|
||||
let transfer = SignalMessage::Transfer {
|
||||
version: default_signal_version(),
|
||||
target_fingerprint: "abc".to_string(),
|
||||
relay_addr: None,
|
||||
};
|
||||
assert!(matches!(signal_to_call_type(&transfer), CallSignalType::Transfer));
|
||||
assert!(matches!(
|
||||
signal_to_call_type(&transfer),
|
||||
CallSignalType::Transfer
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,53 @@
|
||||
//! send `CallOffer` → recv `CallAnswer` → derive shared `CryptoSession`.
|
||||
|
||||
use wzp_crypto::{CryptoSession, KeyExchange, WarzoneKeyExchange};
|
||||
use wzp_proto::{MediaTransport, QualityProfile, SignalMessage};
|
||||
use wzp_proto::{
|
||||
HangupReason, MediaTransport, QualityProfile, SignalMessage, default_signal_version,
|
||||
};
|
||||
|
||||
/// Errors that can occur during the client-side cryptographic handshake.
|
||||
#[derive(Debug)]
|
||||
pub enum HandshakeError {
|
||||
ConnectionClosed,
|
||||
ProtocolVersionMismatch { server_supported: Vec<u8> },
|
||||
UnexpectedSignal(&'static str),
|
||||
SignatureVerificationFailed,
|
||||
KeyDerivation(String),
|
||||
Transport(wzp_proto::TransportError),
|
||||
}
|
||||
|
||||
impl std::fmt::Display for HandshakeError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::ConnectionClosed => write!(f, "connection closed before receiving CallAnswer"),
|
||||
Self::ProtocolVersionMismatch { server_supported } => {
|
||||
write!(
|
||||
f,
|
||||
"protocol version mismatch: server supports {server_supported:?}"
|
||||
)
|
||||
}
|
||||
Self::UnexpectedSignal(expected) => write!(f, "expected CallAnswer, got {expected}"),
|
||||
Self::SignatureVerificationFailed => write!(f, "callee signature verification failed"),
|
||||
Self::KeyDerivation(msg) => write!(f, "key derivation failed: {msg}"),
|
||||
Self::Transport(e) => write!(f, "transport error: {e}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for HandshakeError {
|
||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||
match self {
|
||||
Self::Transport(e) => Some(e),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<wzp_proto::TransportError> for HandshakeError {
|
||||
fn from(e: wzp_proto::TransportError) -> Self {
|
||||
Self::Transport(e)
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform the client (caller) side of the cryptographic handshake.
|
||||
///
|
||||
@@ -17,7 +63,8 @@ use wzp_proto::{MediaTransport, QualityProfile, SignalMessage};
|
||||
pub async fn perform_handshake(
|
||||
transport: &dyn MediaTransport,
|
||||
seed: &[u8; 32],
|
||||
) -> Result<Box<dyn CryptoSession>, anyhow::Error> {
|
||||
alias: Option<&str>,
|
||||
) -> Result<Box<dyn CryptoSession>, HandshakeError> {
|
||||
// 1. Create key exchange from identity seed
|
||||
let mut kx = WarzoneKeyExchange::from_identity_seed(seed);
|
||||
let identity_pub = kx.identity_public_key();
|
||||
@@ -33,49 +80,69 @@ pub async fn perform_handshake(
|
||||
|
||||
// 4. Send CallOffer
|
||||
let offer = SignalMessage::CallOffer {
|
||||
version: default_signal_version(),
|
||||
identity_pub,
|
||||
ephemeral_pub,
|
||||
signature,
|
||||
supported_profiles: vec![
|
||||
QualityProfile::STUDIO_64K,
|
||||
QualityProfile::STUDIO_48K,
|
||||
QualityProfile::STUDIO_32K,
|
||||
QualityProfile::GOOD,
|
||||
QualityProfile::DEGRADED,
|
||||
QualityProfile::CATASTROPHIC,
|
||||
],
|
||||
alias: alias.map(|s| s.to_string()),
|
||||
protocol_version: 2,
|
||||
supported_versions: vec![2],
|
||||
};
|
||||
transport.send_signal(&offer).await?;
|
||||
transport
|
||||
.send_signal(&offer)
|
||||
.await
|
||||
.map_err(HandshakeError::Transport)?;
|
||||
|
||||
// 5. Wait for CallAnswer
|
||||
let answer = transport
|
||||
.recv_signal()
|
||||
.await?
|
||||
.ok_or_else(|| anyhow::anyhow!("connection closed before receiving CallAnswer"))?;
|
||||
// 5. Wait for CallAnswer — 10s timeout guards against relay not responding.
|
||||
let answer = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(10),
|
||||
transport.recv_signal(),
|
||||
)
|
||||
.await
|
||||
.map_err(|_| HandshakeError::Transport(wzp_proto::TransportError::Timeout { ms: 10_000 }))?
|
||||
.map_err(HandshakeError::Transport)?
|
||||
.ok_or(HandshakeError::ConnectionClosed)?;
|
||||
|
||||
let (callee_identity_pub, callee_ephemeral_pub, callee_signature, _chosen_profile) = match answer
|
||||
{
|
||||
SignalMessage::CallAnswer {
|
||||
identity_pub,
|
||||
ephemeral_pub,
|
||||
signature,
|
||||
chosen_profile,
|
||||
} => (identity_pub, ephemeral_pub, signature, chosen_profile),
|
||||
other => {
|
||||
return Err(anyhow::anyhow!(
|
||||
"expected CallAnswer, got {:?}",
|
||||
std::mem::discriminant(&other)
|
||||
))
|
||||
}
|
||||
};
|
||||
let (callee_identity_pub, callee_ephemeral_pub, callee_signature, _chosen_profile) =
|
||||
match answer {
|
||||
SignalMessage::CallAnswer {
|
||||
identity_pub,
|
||||
ephemeral_pub,
|
||||
signature,
|
||||
chosen_profile,
|
||||
..
|
||||
} => (identity_pub, ephemeral_pub, signature, chosen_profile),
|
||||
SignalMessage::Hangup {
|
||||
reason: HangupReason::ProtocolVersionMismatch { server_supported },
|
||||
..
|
||||
} => {
|
||||
return Err(HandshakeError::ProtocolVersionMismatch { server_supported });
|
||||
}
|
||||
_ => {
|
||||
return Err(HandshakeError::UnexpectedSignal("CallAnswer"));
|
||||
}
|
||||
};
|
||||
|
||||
// 6. Verify callee's signature over (ephemeral_pub || "call-answer")
|
||||
let mut verify_data = Vec::with_capacity(32 + 11);
|
||||
verify_data.extend_from_slice(&callee_ephemeral_pub);
|
||||
verify_data.extend_from_slice(b"call-answer");
|
||||
if !WarzoneKeyExchange::verify(&callee_identity_pub, &verify_data, &callee_signature) {
|
||||
return Err(anyhow::anyhow!("callee signature verification failed"));
|
||||
return Err(HandshakeError::SignatureVerificationFailed);
|
||||
}
|
||||
|
||||
// 7. Derive session
|
||||
let session = kx.derive_session(&callee_ephemeral_pub)?;
|
||||
let session = kx
|
||||
.derive_session(&callee_ephemeral_pub)
|
||||
.map_err(|e| HandshakeError::KeyDerivation(e.to_string()))?;
|
||||
|
||||
Ok(session)
|
||||
}
|
||||
|
||||
440
crates/wzp-client/src/ice_agent.rs
Normal file
440
crates/wzp-client/src/ice_agent.rs
Normal file
@@ -0,0 +1,440 @@
|
||||
//! Phase 8 (Tailscale-inspired): ICE agent for candidate lifecycle
|
||||
//! management and mid-call re-gathering.
|
||||
//!
|
||||
//! The `IceAgent` owns the state of all candidate discovery
|
||||
//! mechanisms (STUN, port mapping, host candidates) and provides:
|
||||
//!
|
||||
//! - `gather()`: initial candidate gathering during call setup
|
||||
//! - `re_gather()`: triggered on network change, produces a
|
||||
//! `CandidateUpdate` to send to the peer
|
||||
//! - `apply_peer_update()`: processes peer's candidate updates
|
||||
//!
|
||||
//! This is NOT a full ICE agent (RFC 8445). It's the Tailscale-style
|
||||
//! "gather all candidates, race them all in parallel, pick the
|
||||
//! winner" approach, adapted for QUIC transport.
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::atomic::{AtomicU32, Ordering};
|
||||
use std::time::Duration;
|
||||
|
||||
use wzp_proto::{SignalMessage, default_signal_version};
|
||||
|
||||
use crate::dual_path::PeerCandidates;
|
||||
use crate::portmap;
|
||||
use crate::reflect;
|
||||
use crate::stun;
|
||||
|
||||
/// All candidates gathered for the local side.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CandidateSet {
|
||||
/// STUN-discovered server-reflexive address.
|
||||
pub reflexive: Option<SocketAddr>,
|
||||
/// LAN host candidates from local interfaces.
|
||||
pub local: Vec<SocketAddr>,
|
||||
/// Port-mapped address from NAT-PMP/PCP/UPnP.
|
||||
pub mapped: Option<SocketAddr>,
|
||||
/// Generation counter (monotonically increasing per call).
|
||||
pub generation: u32,
|
||||
}
|
||||
|
||||
/// Configuration for the ICE agent.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct IceAgentConfig {
|
||||
/// STUN servers to use for reflexive discovery.
|
||||
pub stun_config: stun::StunConfig,
|
||||
/// Whether to attempt port mapping.
|
||||
pub enable_portmap: bool,
|
||||
/// Timeout for each discovery mechanism.
|
||||
pub gather_timeout: Duration,
|
||||
/// The QUIC endpoint's local port (for host candidate pairing).
|
||||
pub local_v4_port: u16,
|
||||
/// Optional IPv6 port.
|
||||
pub local_v6_port: Option<u16>,
|
||||
}
|
||||
|
||||
impl Default for IceAgentConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
stun_config: stun::StunConfig::default(),
|
||||
enable_portmap: true,
|
||||
gather_timeout: Duration::from_secs(3),
|
||||
local_v4_port: 0,
|
||||
local_v6_port: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// ICE agent managing candidate lifecycle.
|
||||
pub struct IceAgent {
|
||||
config: IceAgentConfig,
|
||||
generation: AtomicU32,
|
||||
call_id: String,
|
||||
/// Last-seen peer generation (to filter stale updates).
|
||||
peer_generation: AtomicU32,
|
||||
}
|
||||
|
||||
impl IceAgent {
|
||||
pub fn new(call_id: String, config: IceAgentConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
generation: AtomicU32::new(0),
|
||||
call_id,
|
||||
peer_generation: AtomicU32::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
/// Initial candidate gathering. Runs all discovery mechanisms
|
||||
/// in parallel and returns the full candidate set.
|
||||
pub async fn gather(&self) -> CandidateSet {
|
||||
let generation = self.generation.fetch_add(1, Ordering::Relaxed);
|
||||
|
||||
// Run STUN + port mapping + host candidates in parallel.
|
||||
let stun_fut = stun::discover_reflexive(&self.config.stun_config);
|
||||
let portmap_fut = async {
|
||||
if self.config.enable_portmap && self.config.local_v4_port > 0 {
|
||||
portmap::acquire_port_mapping(self.config.local_v4_port, None)
|
||||
.await
|
||||
.ok()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
let (stun_result, portmap_result) = tokio::join!(
|
||||
tokio::time::timeout(self.config.gather_timeout, stun_fut),
|
||||
tokio::time::timeout(self.config.gather_timeout, portmap_fut),
|
||||
);
|
||||
|
||||
let reflexive = stun_result.ok().and_then(|r| r.ok());
|
||||
let mapped = portmap_result.ok().flatten().map(|m| m.external_addr);
|
||||
let local =
|
||||
reflect::local_host_candidates(self.config.local_v4_port, self.config.local_v6_port);
|
||||
|
||||
tracing::info!(
|
||||
generation,
|
||||
reflexive = ?reflexive,
|
||||
mapped = ?mapped,
|
||||
local_count = local.len(),
|
||||
"ice_agent: gathered candidates"
|
||||
);
|
||||
|
||||
CandidateSet {
|
||||
reflexive,
|
||||
local,
|
||||
mapped,
|
||||
generation,
|
||||
}
|
||||
}
|
||||
|
||||
/// Re-gather candidates after a network change. Increments the
|
||||
/// generation counter and returns a `CandidateUpdate` signal
|
||||
/// message to send to the peer.
|
||||
pub async fn re_gather(&self) -> (CandidateSet, SignalMessage) {
|
||||
let candidates = self.gather().await;
|
||||
|
||||
let update = SignalMessage::CandidateUpdate {
|
||||
version: default_signal_version(),
|
||||
call_id: self.call_id.clone(),
|
||||
reflexive_addr: candidates.reflexive.map(|a| a.to_string()),
|
||||
local_addrs: candidates.local.iter().map(|a| a.to_string()).collect(),
|
||||
mapped_addr: candidates.mapped.map(|a| a.to_string()),
|
||||
generation: candidates.generation,
|
||||
};
|
||||
|
||||
(candidates, update)
|
||||
}
|
||||
|
||||
/// Process a peer's candidate update. Returns `Some(PeerCandidates)`
|
||||
/// if the update is newer than the last-seen generation, `None`
|
||||
/// if it's stale.
|
||||
pub fn apply_peer_update(&self, update: &SignalMessage) -> Option<PeerCandidates> {
|
||||
let (reflexive_addr, local_addrs, mapped_addr, generation) = match update {
|
||||
SignalMessage::CandidateUpdate {
|
||||
reflexive_addr,
|
||||
local_addrs,
|
||||
mapped_addr,
|
||||
generation,
|
||||
..
|
||||
} => (reflexive_addr, local_addrs, mapped_addr, *generation),
|
||||
_ => return None,
|
||||
};
|
||||
|
||||
// Only accept if newer than last-seen generation.
|
||||
let prev = self.peer_generation.fetch_max(generation, Ordering::AcqRel);
|
||||
if generation <= prev {
|
||||
tracing::debug!(
|
||||
generation,
|
||||
prev,
|
||||
"ice_agent: ignoring stale CandidateUpdate"
|
||||
);
|
||||
return None;
|
||||
}
|
||||
|
||||
let reflexive = reflexive_addr.as_deref().and_then(|s| s.parse().ok());
|
||||
let local: Vec<SocketAddr> = local_addrs.iter().filter_map(|s| s.parse().ok()).collect();
|
||||
let mapped = mapped_addr.as_deref().and_then(|s| s.parse().ok());
|
||||
|
||||
tracing::info!(
|
||||
generation,
|
||||
reflexive = ?reflexive,
|
||||
mapped = ?mapped,
|
||||
local_count = local.len(),
|
||||
"ice_agent: applied peer candidate update"
|
||||
);
|
||||
|
||||
Some(PeerCandidates {
|
||||
reflexive,
|
||||
local,
|
||||
mapped,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get the current generation counter.
|
||||
pub fn generation(&self) -> u32 {
|
||||
self.generation.load(Ordering::Relaxed)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Tests ──────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn apply_peer_update_rejects_stale() {
|
||||
let agent = IceAgent::new("test-call".into(), IceAgentConfig::default());
|
||||
|
||||
// First update (gen=1) should succeed.
|
||||
let update1 = SignalMessage::CandidateUpdate {
|
||||
version: default_signal_version(),
|
||||
call_id: "test-call".into(),
|
||||
reflexive_addr: Some("203.0.113.5:4433".into()),
|
||||
local_addrs: vec!["192.168.1.10:4433".into()],
|
||||
mapped_addr: None,
|
||||
generation: 1,
|
||||
};
|
||||
let result = agent.apply_peer_update(&update1);
|
||||
assert!(result.is_some());
|
||||
let candidates = result.unwrap();
|
||||
assert_eq!(
|
||||
candidates.reflexive,
|
||||
Some("203.0.113.5:4433".parse().unwrap())
|
||||
);
|
||||
assert_eq!(candidates.local.len(), 1);
|
||||
|
||||
// Same generation (gen=1) should be rejected.
|
||||
let update1b = SignalMessage::CandidateUpdate {
|
||||
version: default_signal_version(),
|
||||
call_id: "test-call".into(),
|
||||
reflexive_addr: Some("198.51.100.9:4433".into()),
|
||||
local_addrs: vec![],
|
||||
mapped_addr: None,
|
||||
generation: 1,
|
||||
};
|
||||
assert!(agent.apply_peer_update(&update1b).is_none());
|
||||
|
||||
// Older generation (gen=0) should be rejected.
|
||||
let update0 = SignalMessage::CandidateUpdate {
|
||||
version: default_signal_version(),
|
||||
call_id: "test-call".into(),
|
||||
reflexive_addr: Some("10.0.0.1:4433".into()),
|
||||
local_addrs: vec![],
|
||||
mapped_addr: None,
|
||||
generation: 0,
|
||||
};
|
||||
assert!(agent.apply_peer_update(&update0).is_none());
|
||||
|
||||
// Newer generation (gen=2) should succeed.
|
||||
let update2 = SignalMessage::CandidateUpdate {
|
||||
version: default_signal_version(),
|
||||
call_id: "test-call".into(),
|
||||
reflexive_addr: Some("198.51.100.9:5555".into()),
|
||||
local_addrs: vec![],
|
||||
mapped_addr: Some("203.0.113.5:12345".into()),
|
||||
generation: 2,
|
||||
};
|
||||
let result = agent.apply_peer_update(&update2);
|
||||
assert!(result.is_some());
|
||||
let candidates = result.unwrap();
|
||||
assert_eq!(
|
||||
candidates.reflexive,
|
||||
Some("198.51.100.9:5555".parse().unwrap())
|
||||
);
|
||||
assert_eq!(
|
||||
candidates.mapped,
|
||||
Some("203.0.113.5:12345".parse().unwrap())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn apply_wrong_signal_returns_none() {
|
||||
let agent = IceAgent::new("test-call".into(), IceAgentConfig::default());
|
||||
let wrong = SignalMessage::Reflect;
|
||||
assert!(agent.apply_peer_update(&wrong).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generation_increments() {
|
||||
let agent = IceAgent::new("test".into(), IceAgentConfig::default());
|
||||
assert_eq!(agent.generation(), 0);
|
||||
// Simulate what gather() does internally
|
||||
let g1 = agent.generation.fetch_add(1, Ordering::Relaxed);
|
||||
assert_eq!(g1, 0);
|
||||
assert_eq!(agent.generation(), 1);
|
||||
let g2 = agent.generation.fetch_add(1, Ordering::Relaxed);
|
||||
assert_eq!(g2, 1);
|
||||
assert_eq!(agent.generation(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn apply_peer_update_parses_all_fields() {
|
||||
let agent = IceAgent::new("test-call".into(), IceAgentConfig::default());
|
||||
|
||||
let update = SignalMessage::CandidateUpdate {
|
||||
version: default_signal_version(),
|
||||
call_id: "test-call".into(),
|
||||
reflexive_addr: Some("203.0.113.5:4433".into()),
|
||||
local_addrs: vec!["192.168.1.10:4433".into(), "10.0.0.5:4433".into()],
|
||||
mapped_addr: Some("198.51.100.42:12345".into()),
|
||||
generation: 1,
|
||||
};
|
||||
|
||||
let candidates = agent.apply_peer_update(&update).unwrap();
|
||||
assert_eq!(
|
||||
candidates.reflexive,
|
||||
Some("203.0.113.5:4433".parse().unwrap())
|
||||
);
|
||||
assert_eq!(candidates.local.len(), 2);
|
||||
assert_eq!(
|
||||
candidates.local[0],
|
||||
"192.168.1.10:4433".parse::<SocketAddr>().unwrap()
|
||||
);
|
||||
assert_eq!(
|
||||
candidates.mapped,
|
||||
Some("198.51.100.42:12345".parse().unwrap())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn apply_peer_update_handles_empty_fields() {
|
||||
let agent = IceAgent::new("test".into(), IceAgentConfig::default());
|
||||
|
||||
let update = SignalMessage::CandidateUpdate {
|
||||
version: default_signal_version(),
|
||||
call_id: "test".into(),
|
||||
reflexive_addr: None,
|
||||
local_addrs: vec![],
|
||||
mapped_addr: None,
|
||||
generation: 1,
|
||||
};
|
||||
|
||||
let candidates = agent.apply_peer_update(&update).unwrap();
|
||||
assert!(candidates.reflexive.is_none());
|
||||
assert!(candidates.local.is_empty());
|
||||
assert!(candidates.mapped.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn apply_peer_update_skips_unparseable_addrs() {
|
||||
let agent = IceAgent::new("test".into(), IceAgentConfig::default());
|
||||
|
||||
let update = SignalMessage::CandidateUpdate {
|
||||
version: default_signal_version(),
|
||||
call_id: "test".into(),
|
||||
reflexive_addr: Some("not-an-addr".into()),
|
||||
local_addrs: vec![
|
||||
"192.168.1.10:4433".into(),
|
||||
"garbage".into(),
|
||||
"10.0.0.5:4433".into(),
|
||||
],
|
||||
mapped_addr: Some("also-bad".into()),
|
||||
generation: 1,
|
||||
};
|
||||
|
||||
let candidates = agent.apply_peer_update(&update).unwrap();
|
||||
assert!(candidates.reflexive.is_none()); // unparseable
|
||||
assert_eq!(candidates.local.len(), 2); // garbage filtered
|
||||
assert!(candidates.mapped.is_none()); // unparseable
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_config_values() {
|
||||
let cfg = IceAgentConfig::default();
|
||||
assert!(cfg.enable_portmap);
|
||||
assert!(cfg.gather_timeout.as_secs() > 0);
|
||||
assert!(!cfg.stun_config.servers.is_empty());
|
||||
assert_eq!(cfg.local_v4_port, 0);
|
||||
assert!(cfg.local_v6_port.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn gather_returns_candidates_even_with_no_stun() {
|
||||
// With default config (port 0 = no portmap, STUN will timeout
|
||||
// quickly on loopback), gather should still return host candidates.
|
||||
let agent = IceAgent::new(
|
||||
"test".into(),
|
||||
IceAgentConfig {
|
||||
stun_config: stun::StunConfig {
|
||||
servers: vec![], // no servers = quick failure
|
||||
timeout: Duration::from_millis(100),
|
||||
},
|
||||
enable_portmap: false,
|
||||
gather_timeout: Duration::from_millis(200),
|
||||
local_v4_port: 12345,
|
||||
local_v6_port: None,
|
||||
},
|
||||
);
|
||||
|
||||
let candidates = agent.gather().await;
|
||||
assert_eq!(candidates.generation, 0);
|
||||
// Reflexive should be None (no STUN servers)
|
||||
assert!(candidates.reflexive.is_none());
|
||||
// Mapped should be None (portmap disabled)
|
||||
assert!(candidates.mapped.is_none());
|
||||
// Local candidates depend on the machine's interfaces
|
||||
// but gather() should not panic.
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn re_gather_produces_signal_message() {
|
||||
let agent = IceAgent::new(
|
||||
"call-42".into(),
|
||||
IceAgentConfig {
|
||||
stun_config: stun::StunConfig {
|
||||
servers: vec![],
|
||||
timeout: Duration::from_millis(50),
|
||||
},
|
||||
enable_portmap: false,
|
||||
gather_timeout: Duration::from_millis(100),
|
||||
local_v4_port: 4433,
|
||||
local_v6_port: None,
|
||||
},
|
||||
);
|
||||
|
||||
let (candidates, signal) = agent.re_gather().await;
|
||||
assert_eq!(candidates.generation, 0);
|
||||
|
||||
match signal {
|
||||
SignalMessage::CandidateUpdate {
|
||||
call_id,
|
||||
generation,
|
||||
..
|
||||
} => {
|
||||
assert_eq!(call_id, "call-42");
|
||||
assert_eq!(generation, 0);
|
||||
}
|
||||
_ => panic!("expected CandidateUpdate"),
|
||||
}
|
||||
|
||||
// Second re_gather increments generation
|
||||
let (candidates2, signal2) = agent.re_gather().await;
|
||||
assert_eq!(candidates2.generation, 1);
|
||||
match signal2 {
|
||||
SignalMessage::CandidateUpdate { generation, .. } => {
|
||||
assert_eq!(generation, 1);
|
||||
}
|
||||
_ => panic!("expected CandidateUpdate"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -8,16 +8,84 @@
|
||||
|
||||
#[cfg(feature = "audio")]
|
||||
pub mod audio_io;
|
||||
#[cfg(feature = "audio")]
|
||||
pub mod audio_ring;
|
||||
// VoiceProcessingIO is an Apple Core Audio API — only compile the module
|
||||
// when the `vpio` feature is on AND we're targeting macOS. Enabling the
|
||||
// feature on Windows/Linux was previously silently broken.
|
||||
#[cfg(all(feature = "vpio", target_os = "macos"))]
|
||||
pub mod audio_vpio;
|
||||
// WASAPI-direct capture with Windows's OS-level AEC (AudioCategory_Communications).
|
||||
// Only compiled when `windows-aec` feature is on AND target is Windows. The
|
||||
// `windows` dependency is itself gated to Windows in Cargo.toml, so enabling
|
||||
// this feature on non-Windows targets is a no-op.
|
||||
#[cfg(all(feature = "windows-aec", target_os = "windows"))]
|
||||
pub mod audio_wasapi;
|
||||
// WebRTC AEC3 (Audio Processing Module) wrapper around CPAL capture + playback
|
||||
// on Linux. Only compiled when `linux-aec` feature is on AND target is Linux.
|
||||
// The webrtc-audio-processing dep is itself gated to Linux in Cargo.toml.
|
||||
#[cfg(all(feature = "linux-aec", target_os = "linux"))]
|
||||
pub mod audio_linux_aec;
|
||||
pub mod bench;
|
||||
pub mod birthday;
|
||||
pub mod call;
|
||||
pub mod encrypted_transport;
|
||||
pub mod drift_test;
|
||||
pub mod dual_path;
|
||||
pub mod echo_test;
|
||||
pub mod featherchat;
|
||||
pub mod handshake;
|
||||
pub mod ice_agent;
|
||||
pub mod metrics;
|
||||
pub mod netcheck;
|
||||
pub mod portmap;
|
||||
pub mod reflect;
|
||||
pub mod relay_map;
|
||||
pub mod stun;
|
||||
pub mod sweep;
|
||||
|
||||
#[cfg(feature = "audio")]
|
||||
pub use audio_io::{AudioCapture, AudioPlayback};
|
||||
// AudioPlayback: three possible backends depending on feature flags.
|
||||
// 1. Default CPAL (`audio_io::AudioPlayback`) — baseline on every platform.
|
||||
// 2. Linux AEC (`audio_linux_aec::LinuxAecPlayback`) — CPAL + WebRTC APM
|
||||
// render-side tee, so echo from speakers gets cancelled from the mic.
|
||||
//
|
||||
// On macOS and Windows we always use the default CPAL playback because:
|
||||
// - macOS: VoiceProcessingIO handles AEC at the capture side (Apple's
|
||||
// native hardware AEC uses its own reference signal handling).
|
||||
// - Windows: WASAPI AudioCategory_Communications AEC uses the system
|
||||
// render mix as reference — no per-process plumbing needed.
|
||||
//
|
||||
// Linux is the only platform where the in-app approach is necessary, so
|
||||
// the AEC playback path is gated to target_os = "linux".
|
||||
|
||||
#[cfg(all(
|
||||
feature = "audio",
|
||||
any(not(feature = "linux-aec"), not(target_os = "linux"))
|
||||
))]
|
||||
pub use audio_io::AudioPlayback;
|
||||
|
||||
#[cfg(all(feature = "linux-aec", target_os = "linux"))]
|
||||
pub use audio_linux_aec::LinuxAecPlayback as AudioPlayback;
|
||||
|
||||
// AudioCapture: three possible backends depending on feature flags.
|
||||
// 1. Default CPAL (`audio_io::AudioCapture`) — baseline on every platform.
|
||||
// 2. Windows AEC (`audio_wasapi::WasapiAudioCapture`) — direct WASAPI
|
||||
// with AudioCategory_Communications, OS APO chain does AEC.
|
||||
// 3. Linux AEC (`audio_linux_aec::LinuxAecCapture`) — CPAL + WebRTC APM
|
||||
// capture-side echo cancellation using the playback tee as reference.
|
||||
// All three expose the same public API (`start`, `ring`, `stop`, `Drop`).
|
||||
|
||||
#[cfg(all(
|
||||
feature = "audio",
|
||||
any(not(feature = "windows-aec"), not(target_os = "windows")),
|
||||
any(not(feature = "linux-aec"), not(target_os = "linux"))
|
||||
))]
|
||||
pub use audio_io::AudioCapture;
|
||||
|
||||
#[cfg(all(feature = "windows-aec", target_os = "windows"))]
|
||||
pub use audio_wasapi::WasapiAudioCapture as AudioCapture;
|
||||
|
||||
#[cfg(all(feature = "linux-aec", target_os = "linux"))]
|
||||
pub use audio_linux_aec::LinuxAecCapture as AudioCapture;
|
||||
pub use call::{CallConfig, CallDecoder, CallEncoder};
|
||||
pub use handshake::perform_handshake;
|
||||
|
||||
@@ -178,7 +178,10 @@ mod tests {
|
||||
|
||||
// Immediate second write should be skipped (60s interval).
|
||||
let second = writer.maybe_write(&snap).unwrap();
|
||||
assert!(!second, "second write should be skipped — interval not elapsed");
|
||||
assert!(
|
||||
!second,
|
||||
"second write should be skipped — interval not elapsed"
|
||||
);
|
||||
|
||||
// Clean up.
|
||||
let _ = std::fs::remove_file(&path);
|
||||
|
||||
537
crates/wzp-client/src/netcheck.rs
Normal file
537
crates/wzp-client/src/netcheck.rs
Normal file
@@ -0,0 +1,537 @@
|
||||
//! Phase 8 (Tailscale-inspired): Comprehensive network diagnostic.
|
||||
//!
|
||||
//! Probes STUN servers, relay infrastructure, port mapping
|
||||
//! capabilities, IPv6 reachability, and NAT hairpinning in parallel
|
||||
//! to produce a `NetcheckReport` that captures the client's network
|
||||
//! environment at a point in time.
|
||||
//!
|
||||
//! Used for:
|
||||
//! - Troubleshooting connectivity issues
|
||||
//! - Automatic relay selection (Phase 5)
|
||||
//! - Pre-call NAT assessment
|
||||
//! - Quality prediction
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use serde::Serialize;
|
||||
|
||||
use crate::portmap::{self, PortMapProtocol};
|
||||
use crate::reflect::{self, NatType};
|
||||
use crate::stun::{self, StunConfig};
|
||||
|
||||
/// Complete network diagnostic report.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct NetcheckReport {
|
||||
/// NAT type classification (from combined STUN + relay probes).
|
||||
pub nat_type: NatType,
|
||||
/// Server-reflexive address (consensus from probes).
|
||||
pub reflexive_addr: Option<String>,
|
||||
/// Whether IPv4 connectivity is available.
|
||||
pub ipv4_reachable: bool,
|
||||
/// Whether IPv6 connectivity is available.
|
||||
pub ipv6_reachable: bool,
|
||||
/// Whether the NAT supports hairpinning (loopback to own
|
||||
/// reflexive address).
|
||||
pub hairpin_works: Option<bool>,
|
||||
/// Which port mapping protocol is available (if any).
|
||||
pub port_mapping: Option<PortMapProtocol>,
|
||||
/// Per-relay latency measurements.
|
||||
pub relay_latencies: Vec<RelayLatency>,
|
||||
/// Preferred relay (lowest latency).
|
||||
pub preferred_relay: Option<String>,
|
||||
/// STUN latency to first responding server (ms).
|
||||
pub stun_latency_ms: Option<u32>,
|
||||
/// Whether UPnP is available on the gateway.
|
||||
pub upnp_available: bool,
|
||||
/// Whether PCP is available on the gateway.
|
||||
pub pcp_available: bool,
|
||||
/// Whether NAT-PMP is available on the gateway.
|
||||
pub nat_pmp_available: bool,
|
||||
/// Default gateway address.
|
||||
pub gateway: Option<String>,
|
||||
/// Total time taken for the diagnostic (ms).
|
||||
pub duration_ms: u32,
|
||||
/// Individual STUN probe results.
|
||||
pub stun_probes: Vec<reflect::NatProbeResult>,
|
||||
/// NAT port allocation pattern (sequential vs random).
|
||||
pub port_allocation: Option<stun::PortAllocation>,
|
||||
}
|
||||
|
||||
/// Latency to a specific relay.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct RelayLatency {
|
||||
pub name: String,
|
||||
pub addr: String,
|
||||
pub rtt_ms: Option<u32>,
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
/// Configuration for the netcheck run.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct NetcheckConfig {
|
||||
/// STUN servers to probe.
|
||||
pub stun_config: StunConfig,
|
||||
/// Relay servers to probe (name, address pairs).
|
||||
pub relays: Vec<(String, SocketAddr)>,
|
||||
/// Per-probe timeout.
|
||||
pub timeout: Duration,
|
||||
/// Whether to test port mapping.
|
||||
pub test_portmap: bool,
|
||||
/// Whether to test IPv6.
|
||||
pub test_ipv6: bool,
|
||||
/// Local port for port mapping test (0 = skip).
|
||||
pub local_port: u16,
|
||||
}
|
||||
|
||||
impl Default for NetcheckConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
stun_config: StunConfig::default(),
|
||||
relays: Vec::new(),
|
||||
timeout: Duration::from_secs(5),
|
||||
test_portmap: true,
|
||||
test_ipv6: true,
|
||||
local_port: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Run a comprehensive network diagnostic.
|
||||
///
|
||||
/// Probes run in parallel for speed — the total time is bounded
|
||||
/// by the slowest individual probe, not the sum.
|
||||
pub async fn run_netcheck(config: &NetcheckConfig) -> NetcheckReport {
|
||||
let start = Instant::now();
|
||||
|
||||
// Run all probes in parallel.
|
||||
let stun_fut = stun::probe_stun_servers(&config.stun_config);
|
||||
let relay_fut = probe_relays(&config.relays, config.timeout);
|
||||
let portmap_fut = probe_portmap(config.test_portmap, config.local_port);
|
||||
let gateway_fut = portmap::default_gateway();
|
||||
let ipv6_fut = test_ipv6(config.test_ipv6, config.timeout);
|
||||
let port_alloc_fut = stun::detect_port_allocation(&config.stun_config);
|
||||
|
||||
let (
|
||||
stun_probes,
|
||||
relay_latencies,
|
||||
portmap_result,
|
||||
gateway_result,
|
||||
ipv6_reachable,
|
||||
port_alloc_result,
|
||||
) = tokio::join!(
|
||||
stun_fut,
|
||||
relay_fut,
|
||||
portmap_fut,
|
||||
gateway_result_fut(gateway_fut),
|
||||
ipv6_fut,
|
||||
port_alloc_fut
|
||||
);
|
||||
|
||||
// Classify NAT from STUN probes.
|
||||
let (nat_type, consensus_addr) = reflect::classify_nat(&stun_probes);
|
||||
|
||||
// Determine STUN latency (first successful probe).
|
||||
let stun_latency_ms = stun_probes.iter().filter_map(|p| p.latency_ms).min();
|
||||
|
||||
// IPv4 reachable if any STUN probe succeeded.
|
||||
let ipv4_reachable = stun_probes.iter().any(|p| p.observed_addr.is_some());
|
||||
|
||||
// Preferred relay = lowest RTT.
|
||||
let preferred_relay = relay_latencies
|
||||
.iter()
|
||||
.filter_map(|r| r.rtt_ms.map(|rtt| (r.name.clone(), rtt)))
|
||||
.min_by_key(|(_, rtt)| *rtt)
|
||||
.map(|(name, _)| name);
|
||||
|
||||
// Port mapping availability.
|
||||
let (port_mapping, nat_pmp_available, pcp_available, upnp_available) = match portmap_result {
|
||||
Some(mapping) => {
|
||||
let proto = mapping.protocol;
|
||||
(
|
||||
Some(proto),
|
||||
proto == PortMapProtocol::NatPmp,
|
||||
proto == PortMapProtocol::Pcp,
|
||||
proto == PortMapProtocol::UPnP,
|
||||
)
|
||||
}
|
||||
None => (None, false, false, false),
|
||||
};
|
||||
|
||||
let gateway = match gateway_result {
|
||||
Ok(gw) => Some(gw.to_string()),
|
||||
Err(_) => None,
|
||||
};
|
||||
|
||||
NetcheckReport {
|
||||
nat_type,
|
||||
reflexive_addr: consensus_addr,
|
||||
ipv4_reachable,
|
||||
ipv6_reachable,
|
||||
hairpin_works: None, // TODO: implement hairpin test
|
||||
port_mapping,
|
||||
relay_latencies,
|
||||
preferred_relay,
|
||||
stun_latency_ms,
|
||||
upnp_available,
|
||||
pcp_available,
|
||||
nat_pmp_available,
|
||||
gateway,
|
||||
duration_ms: start.elapsed().as_millis() as u32,
|
||||
stun_probes,
|
||||
port_allocation: Some(port_alloc_result.allocation),
|
||||
}
|
||||
}
|
||||
|
||||
/// Probe relay latencies via reflect.
|
||||
async fn probe_relays(relays: &[(String, SocketAddr)], timeout: Duration) -> Vec<RelayLatency> {
|
||||
if relays.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let timeout_ms = timeout.as_millis() as u64;
|
||||
let mut set = tokio::task::JoinSet::new();
|
||||
|
||||
for (name, addr) in relays {
|
||||
let name = name.clone();
|
||||
let addr = *addr;
|
||||
set.spawn(async move {
|
||||
let start = Instant::now();
|
||||
match reflect::probe_reflect_addr(addr, timeout_ms, None).await {
|
||||
Ok((_observed, _latency)) => RelayLatency {
|
||||
name,
|
||||
addr: addr.to_string(),
|
||||
rtt_ms: Some(start.elapsed().as_millis() as u32),
|
||||
error: None,
|
||||
},
|
||||
Err(e) => RelayLatency {
|
||||
name,
|
||||
addr: addr.to_string(),
|
||||
rtt_ms: None,
|
||||
error: Some(e),
|
||||
},
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let mut results = Vec::with_capacity(relays.len());
|
||||
while let Some(join_result) = set.join_next().await {
|
||||
match join_result {
|
||||
Ok(r) => results.push(r),
|
||||
Err(_) => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by RTT (lowest first).
|
||||
results.sort_by_key(|r| r.rtt_ms.unwrap_or(u32::MAX));
|
||||
results
|
||||
}
|
||||
|
||||
/// Attempt port mapping and return the mapping if successful.
|
||||
async fn probe_portmap(enabled: bool, local_port: u16) -> Option<portmap::PortMapping> {
|
||||
if !enabled || local_port == 0 {
|
||||
return None;
|
||||
}
|
||||
portmap::acquire_port_mapping(local_port, None).await.ok()
|
||||
}
|
||||
|
||||
/// Wrap the gateway future to handle the Result.
|
||||
async fn gateway_result_fut(
|
||||
fut: impl std::future::Future<Output = Result<std::net::Ipv4Addr, portmap::PortMapError>>,
|
||||
) -> Result<std::net::Ipv4Addr, portmap::PortMapError> {
|
||||
fut.await
|
||||
}
|
||||
|
||||
/// Test IPv6 connectivity by attempting to bind and send on an IPv6 socket.
|
||||
async fn test_ipv6(enabled: bool, timeout: Duration) -> bool {
|
||||
if !enabled {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Try to resolve and connect to an IPv6 STUN server.
|
||||
let result = tokio::time::timeout(timeout, async {
|
||||
let sock = tokio::net::UdpSocket::bind("[::]:0").await.ok()?;
|
||||
// Try Google's IPv6 STUN — if DNS resolves to an AAAA record
|
||||
// and we can send a packet, IPv6 is working.
|
||||
let addr = stun::resolve_stun_server("stun.l.google.com:19302")
|
||||
.await
|
||||
.ok()?;
|
||||
if addr.is_ipv6() {
|
||||
sock.send_to(&[0u8; 1], addr).await.ok()?;
|
||||
Some(true)
|
||||
} else {
|
||||
// Server resolved to IPv4 — try binding to [::] at least
|
||||
Some(false)
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(Some(true)) => true,
|
||||
_ => {
|
||||
// Fallback: can we at least bind an IPv6 socket?
|
||||
tokio::net::UdpSocket::bind("[::]:0").await.is_ok()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Format a netcheck report as a human-readable string.
|
||||
pub fn format_report(report: &NetcheckReport) -> String {
|
||||
let mut out = String::new();
|
||||
|
||||
out.push_str(&format!("=== WarzonePhone Netcheck ===\n\n"));
|
||||
out.push_str(&format!("NAT Type: {:?}\n", report.nat_type));
|
||||
out.push_str(&format!(
|
||||
"Reflexive Addr: {}\n",
|
||||
report.reflexive_addr.as_deref().unwrap_or("(unknown)")
|
||||
));
|
||||
out.push_str(&format!(
|
||||
"IPv4: {}\n",
|
||||
if report.ipv4_reachable { "yes" } else { "no" }
|
||||
));
|
||||
out.push_str(&format!(
|
||||
"IPv6: {}\n",
|
||||
if report.ipv6_reachable { "yes" } else { "no" }
|
||||
));
|
||||
out.push_str(&format!(
|
||||
"Gateway: {}\n",
|
||||
report.gateway.as_deref().unwrap_or("(unknown)")
|
||||
));
|
||||
|
||||
if let Some(ref alloc) = report.port_allocation {
|
||||
out.push_str(&format!("Port Alloc: {alloc}\n"));
|
||||
}
|
||||
|
||||
out.push_str(&format!("\n--- Port Mapping ---\n"));
|
||||
out.push_str(&format!(
|
||||
"NAT-PMP: {} PCP: {} UPnP: {}\n",
|
||||
if report.nat_pmp_available {
|
||||
"yes"
|
||||
} else {
|
||||
"no"
|
||||
},
|
||||
if report.pcp_available { "yes" } else { "no" },
|
||||
if report.upnp_available { "yes" } else { "no" },
|
||||
));
|
||||
if let Some(proto) = &report.port_mapping {
|
||||
out.push_str(&format!("Active mapping: {:?}\n", proto));
|
||||
}
|
||||
|
||||
if !report.stun_probes.is_empty() {
|
||||
out.push_str(&format!("\n--- STUN Probes ---\n"));
|
||||
for p in &report.stun_probes {
|
||||
out.push_str(&format!(
|
||||
" {} → {} ({}ms){}\n",
|
||||
p.relay_name,
|
||||
p.observed_addr.as_deref().unwrap_or("failed"),
|
||||
p.latency_ms
|
||||
.map(|ms| ms.to_string())
|
||||
.unwrap_or_else(|| "-".into()),
|
||||
p.error
|
||||
.as_ref()
|
||||
.map(|e| format!(" [{e}]"))
|
||||
.unwrap_or_default(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
if !report.relay_latencies.is_empty() {
|
||||
out.push_str(&format!("\n--- Relay Latencies ---\n"));
|
||||
for r in &report.relay_latencies {
|
||||
out.push_str(&format!(
|
||||
" {} ({}) → {}ms{}\n",
|
||||
r.name,
|
||||
r.addr,
|
||||
r.rtt_ms
|
||||
.map(|ms| ms.to_string())
|
||||
.unwrap_or_else(|| "-".into()),
|
||||
r.error
|
||||
.as_ref()
|
||||
.map(|e| format!(" [{e}]"))
|
||||
.unwrap_or_default(),
|
||||
));
|
||||
}
|
||||
if let Some(ref pref) = report.preferred_relay {
|
||||
out.push_str(&format!(" Preferred: {pref}\n"));
|
||||
}
|
||||
}
|
||||
|
||||
out.push_str(&format!("\nCompleted in {}ms\n", report.duration_ms));
|
||||
out
|
||||
}
|
||||
|
||||
// ── Tests ──────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn default_config_has_stun_servers() {
|
||||
let config = NetcheckConfig::default();
|
||||
assert!(!config.stun_config.servers.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn format_report_produces_output() {
|
||||
let report = NetcheckReport {
|
||||
nat_type: NatType::Cone,
|
||||
reflexive_addr: Some("203.0.113.5:4433".into()),
|
||||
ipv4_reachable: true,
|
||||
ipv6_reachable: false,
|
||||
hairpin_works: None,
|
||||
port_mapping: None,
|
||||
relay_latencies: vec![RelayLatency {
|
||||
name: "relay-1".into(),
|
||||
addr: "10.0.0.1:4433".into(),
|
||||
rtt_ms: Some(25),
|
||||
error: None,
|
||||
}],
|
||||
preferred_relay: Some("relay-1".into()),
|
||||
stun_latency_ms: Some(15),
|
||||
upnp_available: false,
|
||||
pcp_available: false,
|
||||
nat_pmp_available: false,
|
||||
gateway: Some("192.168.1.1".into()),
|
||||
duration_ms: 1500,
|
||||
stun_probes: vec![],
|
||||
port_allocation: None,
|
||||
};
|
||||
|
||||
let text = format_report(&report);
|
||||
assert!(text.contains("Cone"));
|
||||
assert!(text.contains("203.0.113.5:4433"));
|
||||
assert!(text.contains("relay-1"));
|
||||
assert!(text.contains("1500ms"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn report_serializes_to_json() {
|
||||
let report = NetcheckReport {
|
||||
nat_type: NatType::Cone,
|
||||
reflexive_addr: Some("203.0.113.5:4433".into()),
|
||||
ipv4_reachable: true,
|
||||
ipv6_reachable: false,
|
||||
hairpin_works: None,
|
||||
port_mapping: Some(PortMapProtocol::NatPmp),
|
||||
relay_latencies: vec![],
|
||||
preferred_relay: None,
|
||||
stun_latency_ms: Some(25),
|
||||
upnp_available: false,
|
||||
pcp_available: false,
|
||||
nat_pmp_available: true,
|
||||
gateway: Some("192.168.1.1".into()),
|
||||
duration_ms: 500,
|
||||
stun_probes: vec![],
|
||||
port_allocation: Some(stun::PortAllocation::Sequential { delta: 1 }),
|
||||
};
|
||||
let json = serde_json::to_string(&report).unwrap();
|
||||
assert!(json.contains("Cone"));
|
||||
assert!(json.contains("203.0.113.5:4433"));
|
||||
assert!(json.contains("NatPmp"));
|
||||
|
||||
// Roundtrip
|
||||
let decoded: serde_json::Value = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(decoded["ipv4_reachable"], true);
|
||||
assert_eq!(decoded["ipv6_reachable"], false);
|
||||
assert_eq!(decoded["stun_latency_ms"], 25);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn relay_latency_serializes() {
|
||||
let lat = RelayLatency {
|
||||
name: "eu-west".into(),
|
||||
addr: "10.0.0.1:4433".into(),
|
||||
rtt_ms: Some(42),
|
||||
error: None,
|
||||
};
|
||||
let json = serde_json::to_string(&lat).unwrap();
|
||||
assert!(json.contains("eu-west"));
|
||||
assert!(json.contains("42"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn format_report_empty_relays() {
|
||||
let report = NetcheckReport {
|
||||
nat_type: NatType::Unknown,
|
||||
reflexive_addr: None,
|
||||
ipv4_reachable: false,
|
||||
ipv6_reachable: false,
|
||||
hairpin_works: None,
|
||||
port_mapping: None,
|
||||
relay_latencies: vec![],
|
||||
preferred_relay: None,
|
||||
stun_latency_ms: None,
|
||||
upnp_available: false,
|
||||
pcp_available: false,
|
||||
nat_pmp_available: false,
|
||||
gateway: None,
|
||||
duration_ms: 100,
|
||||
stun_probes: vec![],
|
||||
port_allocation: None,
|
||||
};
|
||||
let text = format_report(&report);
|
||||
assert!(text.contains("Unknown"));
|
||||
assert!(text.contains("(unknown)")); // reflexive addr
|
||||
assert!(text.contains("100ms"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn format_report_with_stun_probes() {
|
||||
let report = NetcheckReport {
|
||||
nat_type: NatType::SymmetricPort,
|
||||
reflexive_addr: None,
|
||||
ipv4_reachable: true,
|
||||
ipv6_reachable: true,
|
||||
hairpin_works: Some(false),
|
||||
port_mapping: Some(PortMapProtocol::UPnP),
|
||||
relay_latencies: vec![
|
||||
RelayLatency {
|
||||
name: "us-east".into(),
|
||||
addr: "10.0.0.1:4433".into(),
|
||||
rtt_ms: Some(15),
|
||||
error: None,
|
||||
},
|
||||
RelayLatency {
|
||||
name: "eu-west".into(),
|
||||
addr: "10.0.0.2:4433".into(),
|
||||
rtt_ms: None,
|
||||
error: Some("timeout".into()),
|
||||
},
|
||||
],
|
||||
preferred_relay: Some("us-east".into()),
|
||||
stun_latency_ms: Some(20),
|
||||
upnp_available: true,
|
||||
pcp_available: false,
|
||||
nat_pmp_available: false,
|
||||
gateway: Some("192.168.0.1".into()),
|
||||
duration_ms: 3000,
|
||||
stun_probes: vec![reflect::NatProbeResult {
|
||||
relay_name: "stun:google".into(),
|
||||
relay_addr: "74.125.250.129:19302".into(),
|
||||
observed_addr: Some("203.0.113.5:12345".into()),
|
||||
latency_ms: Some(20),
|
||||
error: None,
|
||||
}],
|
||||
port_allocation: Some(stun::PortAllocation::Random),
|
||||
};
|
||||
let text = format_report(&report);
|
||||
assert!(text.contains("SymmetricPort"));
|
||||
assert!(text.contains("us-east"));
|
||||
assert!(text.contains("eu-west"));
|
||||
assert!(text.contains("Preferred: us-east"));
|
||||
assert!(text.contains("UPnP: yes"));
|
||||
assert!(text.contains("stun:google"));
|
||||
assert!(text.contains("3000ms"));
|
||||
}
|
||||
|
||||
/// Integration test: run actual netcheck (requires network).
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn integration_netcheck() {
|
||||
let config = NetcheckConfig::default();
|
||||
let report = run_netcheck(&config).await;
|
||||
println!("{}", format_report(&report));
|
||||
assert!(report.duration_ms > 0);
|
||||
}
|
||||
}
|
||||
1164
crates/wzp-client/src/portmap.rs
Normal file
1164
crates/wzp-client/src/portmap.rs
Normal file
File diff suppressed because it is too large
Load Diff
704
crates/wzp-client/src/reflect.rs
Normal file
704
crates/wzp-client/src/reflect.rs
Normal file
@@ -0,0 +1,704 @@
|
||||
//! Multi-relay NAT reflection ("STUN for QUIC" — Phase 2).
|
||||
//!
|
||||
//! Phase 1 (`SignalMessage::Reflect` / `ReflectResponse`) lets a
|
||||
//! client ask a single relay "what source address do you see for
|
||||
//! me?". Phase 2 queries N relays in parallel and classifies the
|
||||
//! results into a NAT type so the future P2P hole-punching path
|
||||
//! can decide whether a direct QUIC handshake is viable:
|
||||
//!
|
||||
//! - All relays return the same `(ip, port)` → **Cone NAT**.
|
||||
//! Endpoint-independent mapping, P2P hole-punching viable,
|
||||
//! `consensus_addr` is the one address to advertise.
|
||||
//! - Same ip, different ports → **Symmetric port-dependent NAT**.
|
||||
//! The mapping changes per destination, so the advertised addr
|
||||
//! wouldn't match what a peer actually sees; fall back to
|
||||
//! relay-mediated path.
|
||||
//! - Different ips → multi-homed / anycast / broken DNS, treat as
|
||||
//! `Multiple` and do not attempt P2P.
|
||||
//! - 0 or 1 successful probes → `Unknown`, not enough data.
|
||||
//!
|
||||
//! A probe is a throwaway QUIC signal connection: open endpoint,
|
||||
//! connect, RegisterPresence (with a zero identity — the relay
|
||||
//! accepts this exactly like the main signaling path does), send
|
||||
//! Reflect, read ReflectResponse, close. Each probe gets its own
|
||||
//! ephemeral quinn::Endpoint so the OS assigns a fresh source port
|
||||
//! per relay — if we shared one endpoint across probes, a
|
||||
//! symmetric NAT in front of the client would map every probe to
|
||||
//! the same port and we couldn't detect it.
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use serde::Serialize;
|
||||
use wzp_proto::{MediaTransport, SignalMessage, default_signal_version};
|
||||
use wzp_transport::{QuinnTransport, client_config, create_endpoint};
|
||||
|
||||
/// Result of one probe against one relay. Always returned so the
|
||||
/// UI can render per-relay status even when some fail.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct NatProbeResult {
|
||||
pub relay_name: String,
|
||||
pub relay_addr: String,
|
||||
/// `Some` on successful probe, `None` on failure.
|
||||
pub observed_addr: Option<String>,
|
||||
/// End-to-end wall-clock from connect start to ReflectResponse
|
||||
/// received, in milliseconds. `Some` only on success.
|
||||
pub latency_ms: Option<u32>,
|
||||
/// Human-readable error on failure.
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
/// Aggregated classification over N `NatProbeResult`s.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct NatDetection {
|
||||
pub probes: Vec<NatProbeResult>,
|
||||
pub nat_type: NatType,
|
||||
/// When `nat_type == Cone`, the one address all probes agreed
|
||||
/// on. `None` for every other case.
|
||||
pub consensus_addr: Option<String>,
|
||||
}
|
||||
|
||||
/// NAT classification. See module doc for semantics.
|
||||
#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq)]
|
||||
pub enum NatType {
|
||||
Cone,
|
||||
SymmetricPort,
|
||||
Multiple,
|
||||
Unknown,
|
||||
}
|
||||
|
||||
/// Probe a single relay with a QUIC connection.
|
||||
///
|
||||
/// # Endpoint reuse (Phase 5 — Nebula-style architecture)
|
||||
///
|
||||
/// If `existing_endpoint` is `Some`, the probe uses that socket
|
||||
/// instead of creating a fresh one. This is the desired mode in
|
||||
/// production: a port-preserving NAT (MikroTik masquerade, most
|
||||
/// consumer routers) gives a **stable** external port for the
|
||||
/// one socket, so the reflex addr observed by ANY relay is the
|
||||
/// SAME addr and matches what a peer would see on a direct dial.
|
||||
/// Pass the signal endpoint here.
|
||||
///
|
||||
/// If `None`, creates a fresh one-shot endpoint. Kept for:
|
||||
/// - tests that spin up isolated probes
|
||||
/// - the "I'm not registered yet" case where there's no signal
|
||||
/// endpoint to reuse
|
||||
///
|
||||
/// NOTE on NAT-type detection: the pre-Phase-5 behavior of
|
||||
/// forcing a fresh endpoint per probe was wrong — it made every
|
||||
/// port-preserving NAT look symmetric because the classifier saw
|
||||
/// a different external port for each fresh source port. With
|
||||
/// one shared socket, the classifier reflects the REAL NAT
|
||||
/// behavior.
|
||||
pub async fn probe_reflect_addr(
|
||||
relay: SocketAddr,
|
||||
timeout_ms: u64,
|
||||
existing_endpoint: Option<wzp_transport::Endpoint>,
|
||||
) -> Result<(SocketAddr, u32), String> {
|
||||
// Install rustls provider idempotently — a second install on the
|
||||
// same thread is a no-op.
|
||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||
|
||||
let endpoint = match existing_endpoint {
|
||||
Some(ep) => ep,
|
||||
None => {
|
||||
let bind: SocketAddr = "0.0.0.0:0".parse().unwrap();
|
||||
create_endpoint(bind, None).map_err(|e| format!("endpoint: {e}"))?
|
||||
}
|
||||
};
|
||||
|
||||
let start = Instant::now();
|
||||
let probe = async {
|
||||
// Open the signal connection.
|
||||
let conn = wzp_transport::connect(&endpoint, relay, "_signal", client_config())
|
||||
.await
|
||||
.map_err(|e| format!("connect: {e}"))?;
|
||||
let transport = QuinnTransport::new(conn);
|
||||
|
||||
// The relay signal handler waits for a RegisterPresence
|
||||
// before entering its main dispatch loop (see
|
||||
// wzp-relay/src/main.rs). So a transient probe has to
|
||||
// register with a zero identity first — the relay accepts
|
||||
// the empty-signature form exactly as the main signaling
|
||||
// path does in desktop/src-tauri/src/lib.rs register_signal.
|
||||
transport
|
||||
.send_signal(&SignalMessage::RegisterPresence {
|
||||
version: default_signal_version(),
|
||||
identity_pub: [0u8; 32],
|
||||
signature: vec![],
|
||||
alias: None,
|
||||
})
|
||||
.await
|
||||
.map_err(|e| format!("send RegisterPresence: {e}"))?;
|
||||
// Drain the RegisterPresenceAck so the response to our
|
||||
// Reflect doesn't land on an unexpected stream order.
|
||||
match transport.recv_signal().await {
|
||||
Ok(Some(SignalMessage::RegisterPresenceAck { success: true, .. })) => {}
|
||||
Ok(Some(other)) => {
|
||||
return Err(format!(
|
||||
"unexpected pre-reflect signal: {:?}",
|
||||
std::mem::discriminant(&other)
|
||||
));
|
||||
}
|
||||
Ok(None) => return Err("connection closed before RegisterPresenceAck".into()),
|
||||
Err(e) => return Err(format!("recv RegisterPresenceAck: {e}")),
|
||||
}
|
||||
|
||||
// Send Reflect and await response.
|
||||
transport
|
||||
.send_signal(&SignalMessage::Reflect)
|
||||
.await
|
||||
.map_err(|e| format!("send Reflect: {e}"))?;
|
||||
|
||||
match transport.recv_signal().await {
|
||||
Ok(Some(SignalMessage::ReflectResponse { observed_addr, .. })) => {
|
||||
let parsed: SocketAddr = observed_addr
|
||||
.parse()
|
||||
.map_err(|e| format!("parse observed_addr {observed_addr:?}: {e}"))?;
|
||||
let latency_ms = start.elapsed().as_millis() as u32;
|
||||
|
||||
// Clean close so the relay's per-connection cleanup
|
||||
// runs promptly and we don't leak file descriptors.
|
||||
let _ = transport.close().await;
|
||||
|
||||
Ok((parsed, latency_ms))
|
||||
}
|
||||
Ok(Some(other)) => Err(format!(
|
||||
"expected ReflectResponse, got {:?}",
|
||||
std::mem::discriminant(&other)
|
||||
)),
|
||||
Ok(None) => Err("connection closed before ReflectResponse".into()),
|
||||
Err(e) => Err(format!("recv ReflectResponse: {e}")),
|
||||
}
|
||||
};
|
||||
|
||||
let out = tokio::time::timeout(Duration::from_millis(timeout_ms), probe)
|
||||
.await
|
||||
.map_err(|_| format!("probe timeout ({timeout_ms}ms)"))??;
|
||||
|
||||
// `endpoint` is a quinn::Endpoint clone — an Arc under the
|
||||
// hood. Letting it drop at end-of-scope is correct whether it
|
||||
// was fresh (last ref → socket closes) or shared (ref count
|
||||
// decrements, socket stays alive for the signal loop).
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
/// Detect the client's NAT type by probing N relays in parallel and
|
||||
/// classifying the returned addresses. Never errors — failing
|
||||
/// probes surface via `NatProbeResult.error`; aggregate is always
|
||||
/// returned.
|
||||
///
|
||||
/// # Endpoint reuse (Phase 5)
|
||||
///
|
||||
/// If `shared_endpoint` is `Some`, every probe reuses it. This is
|
||||
/// the PRODUCTION behavior: all probes source from the same UDP
|
||||
/// port, so port-preserving NATs map them to the same external
|
||||
/// port, and the classifier reflects the real NAT type. Pass the
|
||||
/// signal endpoint.
|
||||
///
|
||||
/// If `None`, each probe creates its own fresh endpoint — useful
|
||||
/// in tests that don't have a signal endpoint, but produces
|
||||
/// spurious `SymmetricPort` classifications against NATs that
|
||||
/// would otherwise look cone-like.
|
||||
pub async fn detect_nat_type(
|
||||
relays: Vec<(String, SocketAddr)>,
|
||||
timeout_ms: u64,
|
||||
shared_endpoint: Option<wzp_transport::Endpoint>,
|
||||
) -> NatDetection {
|
||||
// Parallel probes via tokio::task::JoinSet so the wall-clock is
|
||||
// bounded by the slowest probe, not the sum. JoinSet keeps the
|
||||
// dep surface at just tokio — we already depend on it.
|
||||
let mut set = tokio::task::JoinSet::new();
|
||||
for (name, addr) in relays {
|
||||
let ep = shared_endpoint.clone();
|
||||
set.spawn(async move {
|
||||
let result = probe_reflect_addr(addr, timeout_ms, ep).await;
|
||||
(name, addr, result)
|
||||
});
|
||||
}
|
||||
|
||||
let mut probes = Vec::new();
|
||||
while let Some(join_result) = set.join_next().await {
|
||||
let (name, addr, result) = match join_result {
|
||||
Ok(tuple) => tuple,
|
||||
// Task panicked — surface as a synthetic failed probe so
|
||||
// the aggregate still returns a reasonable shape. This
|
||||
// shouldn't happen but we don't want one bad probe to
|
||||
// poison the whole detection.
|
||||
Err(join_err) => {
|
||||
probes.push(NatProbeResult {
|
||||
relay_name: "<panicked>".into(),
|
||||
relay_addr: "unknown".into(),
|
||||
observed_addr: None,
|
||||
latency_ms: None,
|
||||
error: Some(format!("probe task panicked: {join_err}")),
|
||||
});
|
||||
continue;
|
||||
}
|
||||
};
|
||||
probes.push(match result {
|
||||
Ok((observed, latency_ms)) => NatProbeResult {
|
||||
relay_name: name,
|
||||
relay_addr: addr.to_string(),
|
||||
observed_addr: Some(observed.to_string()),
|
||||
latency_ms: Some(latency_ms),
|
||||
error: None,
|
||||
},
|
||||
Err(e) => NatProbeResult {
|
||||
relay_name: name,
|
||||
relay_addr: addr.to_string(),
|
||||
observed_addr: None,
|
||||
latency_ms: None,
|
||||
error: Some(e),
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
let (nat_type, consensus_addr) = classify_nat(&probes);
|
||||
NatDetection {
|
||||
probes,
|
||||
nat_type,
|
||||
consensus_addr,
|
||||
}
|
||||
}
|
||||
|
||||
/// Enumerate LAN-local host candidates this client is reachable
|
||||
/// on, paired with the given port (typically the signal
|
||||
/// endpoint's bound port so that incoming dials land on the same
|
||||
/// socket the advertised reflex addr points to).
|
||||
///
|
||||
/// Gathers BOTH IPv4 and IPv6 candidates:
|
||||
///
|
||||
/// - **IPv4**: RFC1918 private ranges (10/8, 172.16/12, 192.168/16)
|
||||
/// and CGNAT shared-transition (100.64/10). Public IPv4 is
|
||||
/// skipped because the reflex-addr path already covers it.
|
||||
/// Loopback and link-local (169.254/16) are skipped.
|
||||
///
|
||||
/// - **IPv6**: ALL global-unicast addresses (2000::/3 — the real
|
||||
/// routable IPv6 space) AND unique-local (fc00::/7). These
|
||||
/// are directly dialable from a peer on the same LAN, and on
|
||||
/// true dual-stack LANs (which most consumer ISPs now provide,
|
||||
/// including Starlink) IPv6 often gives a direct path even
|
||||
/// when IPv4 can't hairpin. Loopback (::1), unspecified (::),
|
||||
/// and link-local (fe80::/10) are skipped — link-local would
|
||||
/// require a scope ID to be useful and is basically never
|
||||
/// reachable across interface boundaries.
|
||||
///
|
||||
/// The port must come from the caller — typically
|
||||
/// `signal_endpoint.local_addr()?.port()`, so that the peer's
|
||||
/// dials to these addresses land on the same socket that's
|
||||
/// already listening (Phase 5 shared-endpoint architecture).
|
||||
///
|
||||
/// Safe to call from any thread; no I/O, no async. The `if-addrs`
|
||||
/// crate reads the kernel's interface table via a single
|
||||
/// getifaddrs(3) syscall.
|
||||
pub fn local_host_candidates(v4_port: u16, v6_port: Option<u16>) -> Vec<SocketAddr> {
|
||||
let Ok(ifaces) = if_addrs::get_if_addrs() else {
|
||||
return Vec::new();
|
||||
};
|
||||
let mut out = Vec::new();
|
||||
for iface in ifaces {
|
||||
if iface.is_loopback() {
|
||||
continue;
|
||||
}
|
||||
match iface.ip() {
|
||||
std::net::IpAddr::V4(v4) => {
|
||||
if v4.is_link_local() {
|
||||
continue;
|
||||
}
|
||||
// Keep RFC1918 private ranges and CGNAT — those
|
||||
// are the LAN-dialable addrs we actually want.
|
||||
// Skip public v4 because the reflex addr already
|
||||
// covers that path.
|
||||
if v4.is_private() {
|
||||
out.push(SocketAddr::new(std::net::IpAddr::V4(v4), v4_port));
|
||||
} else if v4.octets()[0] == 100 && (v4.octets()[1] & 0xc0) == 0x40 {
|
||||
// 100.64/10 CGNAT — rare but valid if two
|
||||
// phones are on the same CGNAT-hairpinned
|
||||
// carrier LAN (some hotspot setups).
|
||||
out.push(SocketAddr::new(std::net::IpAddr::V4(v4), v4_port));
|
||||
}
|
||||
}
|
||||
std::net::IpAddr::V6(v6) => {
|
||||
// Phase 7: IPv6 host candidates via dedicated
|
||||
// IPv6 socket. When v6_port is None, no IPv6
|
||||
// endpoint exists — skip silently.
|
||||
let Some(port) = v6_port else { continue };
|
||||
if v6.is_loopback() || v6.is_unspecified() {
|
||||
continue;
|
||||
}
|
||||
// fe80::/10 link-local — needs scope ID, not
|
||||
// routable across interfaces.
|
||||
if (v6.segments()[0] & 0xffc0) == 0xfe80 {
|
||||
continue;
|
||||
}
|
||||
// Accept global unicast (2000::/3) and
|
||||
// unique-local (fc00::/7).
|
||||
let first_seg = v6.segments()[0];
|
||||
let is_global = (first_seg & 0xe000) == 0x2000;
|
||||
let is_ula = (first_seg & 0xfe00) == 0xfc00;
|
||||
if is_global || is_ula {
|
||||
out.push(SocketAddr::new(std::net::IpAddr::V6(v6), port));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Role assignment for the Phase 3.5 dual-path QUIC race.
|
||||
///
|
||||
/// Both peers already know two strings at CallSetup time: their
|
||||
/// own server-reflexive address (queried via Phase 1 Reflect) and
|
||||
/// the peer's (carried in `CallSetup.peer_direct_addr`). To avoid
|
||||
/// a negotiation round-trip, both sides compare the two strings
|
||||
/// lexicographically and agree on a deterministic role:
|
||||
///
|
||||
/// - **Acceptor** — lexicographically smaller addr. Listens for
|
||||
/// an incoming direct connection from the peer. Does NOT dial.
|
||||
/// - **Dialer** — lexicographically larger addr. Dials the
|
||||
/// peer's direct addr. Does NOT listen.
|
||||
///
|
||||
/// Both roles ALSO dial the relay in parallel as a fallback.
|
||||
/// Whichever future (direct or relay) completes first is used as
|
||||
/// the media transport. Because the role is deterministic and
|
||||
/// symmetric, both peers end up holding the same underlying QUIC
|
||||
/// session on the direct path — A's accepted conn and D's dialed
|
||||
/// conn are literally the same connection.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum Role {
|
||||
/// This peer listens for the direct incoming connection.
|
||||
Acceptor,
|
||||
/// This peer dials the peer's direct address.
|
||||
Dialer,
|
||||
}
|
||||
|
||||
/// Compute the deterministic role for this peer in the dual-path
|
||||
/// race. Returns `None` when no direct attempt is possible —
|
||||
/// either peer didn't advertise a reflex addr, or the two addrs
|
||||
/// are identical (same host on loopback / mis-advertised).
|
||||
///
|
||||
/// The caller should treat `None` as "skip direct, relay-only".
|
||||
pub fn determine_role(
|
||||
own_reflex_addr: Option<&str>,
|
||||
peer_reflex_addr: Option<&str>,
|
||||
) -> Option<Role> {
|
||||
let (own, peer) = match (own_reflex_addr, peer_reflex_addr) {
|
||||
(Some(o), Some(p)) => (o, p),
|
||||
_ => return None,
|
||||
};
|
||||
match own.cmp(peer) {
|
||||
std::cmp::Ordering::Less => Some(Role::Acceptor),
|
||||
std::cmp::Ordering::Greater => Some(Role::Dialer),
|
||||
// Equal addrs should never happen in production (both
|
||||
// peers behind the same NAT mapping + same port would be
|
||||
// a degenerate case). Guard against it so we don't infinite-
|
||||
// loop waiting for a connection to ourselves.
|
||||
std::cmp::Ordering::Equal => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `true` if the address is in an RFC1918 / link-local /
|
||||
/// loopback range and therefore cannot possibly be a post-NAT
|
||||
/// reflex address from the public internet's point of view.
|
||||
///
|
||||
/// A probe against a relay ON THE SAME LAN as the client will
|
||||
/// naturally report the client's LAN IP back (because there's no
|
||||
/// NAT between them) — that observation is real but says nothing
|
||||
/// about the client's public-internet-facing NAT state. Mixing
|
||||
/// LAN reflex addrs with public-internet reflex addrs in
|
||||
/// `classify_nat` would always report `Multiple` (different IPs)
|
||||
/// and falsely warn about symmetric NAT. Filter them out before
|
||||
/// classifying.
|
||||
fn is_private_or_loopback(addr: &SocketAddr) -> bool {
|
||||
match addr.ip() {
|
||||
std::net::IpAddr::V4(v4) => {
|
||||
let o = v4.octets();
|
||||
v4.is_loopback()
|
||||
|| v4.is_private() // 10/8, 172.16/12, 192.168/16
|
||||
|| v4.is_link_local() // 169.254/16
|
||||
|| (o[0] == 100 && (o[1] & 0xc0) == 0x40) // 100.64/10 CGNAT shared
|
||||
}
|
||||
std::net::IpAddr::V6(v6) => {
|
||||
v6.is_loopback() || v6.is_unspecified() || (v6.segments()[0] & 0xffc0) == 0xfe80 // fe80::/10 link-local
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Pure-function NAT classifier — split out for unit testing
|
||||
/// without touching the network.
|
||||
///
|
||||
/// Only considers probes whose reflex addr is a **public-internet**
|
||||
/// address. LAN / private / loopback reflex addrs are dropped
|
||||
/// because they reflect the same-network path rather than the
|
||||
/// real NAT state. CGNAT (100.64/10) is also treated as private
|
||||
/// because the post-CGNAT address would be what we actually want
|
||||
/// to classify on — but CGNAT is unreachable from outside the
|
||||
/// carrier, so a relay seeing the CGNAT addr is on the same
|
||||
/// carrier network and again not useful for classification.
|
||||
pub fn classify_nat(probes: &[NatProbeResult]) -> (NatType, Option<String>) {
|
||||
// First: parse every successful probe's observed addr.
|
||||
let parsed: Vec<SocketAddr> = probes
|
||||
.iter()
|
||||
.filter_map(|p| p.observed_addr.as_deref().and_then(|s| s.parse().ok()))
|
||||
.collect();
|
||||
|
||||
// Then: drop LAN / private / loopback reflex addrs. Those are
|
||||
// legitimate observations by same-network relays, but they
|
||||
// don't contribute to NAT-type classification because the
|
||||
// client's real public-facing NAT mapping is not involved on
|
||||
// that path. A relay on the same LAN always sees the client's
|
||||
// LAN IP, regardless of whether the NAT beyond it is cone or
|
||||
// symmetric.
|
||||
let successes: Vec<SocketAddr> = parsed
|
||||
.into_iter()
|
||||
.filter(|a| !is_private_or_loopback(a))
|
||||
.collect();
|
||||
|
||||
if successes.len() < 2 {
|
||||
return (NatType::Unknown, None);
|
||||
}
|
||||
|
||||
let first = successes[0];
|
||||
let same_ip = successes.iter().all(|a| a.ip() == first.ip());
|
||||
if !same_ip {
|
||||
return (NatType::Multiple, None);
|
||||
}
|
||||
|
||||
let same_port = successes.iter().all(|a| a.port() == first.port());
|
||||
if same_port {
|
||||
(NatType::Cone, Some(first.to_string()))
|
||||
} else {
|
||||
(NatType::SymmetricPort, None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Enhanced NAT detection that combines relay-based reflection with
|
||||
/// public STUN server probes for more robust classification.
|
||||
///
|
||||
/// Runs both probe sets concurrently:
|
||||
/// 1. Relay probes via `detect_nat_type` (existing behavior)
|
||||
/// 2. Public STUN probes via `probe_stun_servers`
|
||||
///
|
||||
/// Merges all results and classifies. More probes = higher confidence
|
||||
/// in the NAT type classification. Falls back gracefully: if STUN
|
||||
/// servers are unreachable, relay probes still work (and vice versa).
|
||||
pub async fn detect_nat_type_with_stun(
|
||||
relays: Vec<(String, SocketAddr)>,
|
||||
timeout_ms: u64,
|
||||
shared_endpoint: Option<wzp_transport::Endpoint>,
|
||||
stun_config: &crate::stun::StunConfig,
|
||||
) -> NatDetection {
|
||||
// Run relay probes and STUN probes concurrently.
|
||||
let relay_fut = detect_nat_type(relays, timeout_ms, shared_endpoint);
|
||||
let stun_fut = crate::stun::probe_stun_servers(stun_config);
|
||||
|
||||
let (relay_detection, stun_probes) = tokio::join!(relay_fut, stun_fut);
|
||||
|
||||
// Merge all probes and re-classify.
|
||||
let mut all_probes = relay_detection.probes;
|
||||
all_probes.extend(stun_probes);
|
||||
|
||||
let (nat_type, consensus_addr) = classify_nat(&all_probes);
|
||||
NatDetection {
|
||||
probes: all_probes,
|
||||
nat_type,
|
||||
consensus_addr,
|
||||
}
|
||||
}
|
||||
|
||||
// ── Unit tests for the pure classifier ───────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn mk(addr: Option<&str>) -> NatProbeResult {
|
||||
NatProbeResult {
|
||||
relay_name: "test".into(),
|
||||
relay_addr: "0.0.0.0:0".into(),
|
||||
observed_addr: addr.map(|s| s.to_string()),
|
||||
latency_ms: addr.map(|_| 10),
|
||||
error: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_empty_is_unknown() {
|
||||
let (nt, addr) = classify_nat(&[]);
|
||||
assert_eq!(nt, NatType::Unknown);
|
||||
assert!(addr.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_single_success_is_unknown() {
|
||||
let probes = vec![mk(Some("192.0.2.1:4433"))];
|
||||
let (nt, addr) = classify_nat(&probes);
|
||||
assert_eq!(nt, NatType::Unknown);
|
||||
assert!(addr.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_two_identical_is_cone() {
|
||||
let probes = vec![mk(Some("192.0.2.1:4433")), mk(Some("192.0.2.1:4433"))];
|
||||
let (nt, addr) = classify_nat(&probes);
|
||||
assert_eq!(nt, NatType::Cone);
|
||||
assert_eq!(addr.as_deref(), Some("192.0.2.1:4433"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_same_ip_different_ports_is_symmetric() {
|
||||
let probes = vec![mk(Some("192.0.2.1:4433")), mk(Some("192.0.2.1:51234"))];
|
||||
let (nt, addr) = classify_nat(&probes);
|
||||
assert_eq!(nt, NatType::SymmetricPort);
|
||||
assert!(addr.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_different_ips_is_multiple() {
|
||||
let probes = vec![mk(Some("192.0.2.1:4433")), mk(Some("198.51.100.9:4433"))];
|
||||
let (nt, addr) = classify_nat(&probes);
|
||||
assert_eq!(nt, NatType::Multiple);
|
||||
assert!(addr.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_drops_private_ip_probes() {
|
||||
// One LAN probe + one public probe should behave like a
|
||||
// single public probe — i.e. Unknown (not enough data to
|
||||
// classify). This is the common real-world case: the user
|
||||
// has a LAN relay + an internet relay configured, the LAN
|
||||
// relay sees the LAN IP, the internet relay sees the WAN
|
||||
// IP, and the old classifier would flag "Multiple" and
|
||||
// falsely warn about symmetric NAT.
|
||||
let probes = vec![
|
||||
mk(Some("192.168.1.100:4433")), // LAN — must be dropped
|
||||
mk(Some("203.0.113.5:4433")), // public (TEST-NET-3)
|
||||
];
|
||||
let (nt, _) = classify_nat(&probes);
|
||||
assert_eq!(nt, NatType::Unknown);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_drops_loopback_probes() {
|
||||
let probes = vec![
|
||||
mk(Some("127.0.0.1:4433")), // loopback — must be dropped
|
||||
mk(Some("203.0.113.5:4433")), // public
|
||||
mk(Some("203.0.113.5:4433")), // public, same addr
|
||||
];
|
||||
let (nt, addr) = classify_nat(&probes);
|
||||
// Two public probes with identical addrs → Cone.
|
||||
assert_eq!(nt, NatType::Cone);
|
||||
assert_eq!(addr.as_deref(), Some("203.0.113.5:4433"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_drops_cgnat_probes() {
|
||||
// 100.64.0.0/10 is the CGNAT shared-transition range.
|
||||
// Filter treats it like RFC1918 — a relay that sees the
|
||||
// client with a 100.64/10 addr is on the same CGNAT
|
||||
// network and can't contribute to public NAT classification.
|
||||
let probes = vec![
|
||||
mk(Some("100.64.0.42:4433")), // CGNAT — dropped
|
||||
mk(Some("203.0.113.5:4433")), // public
|
||||
mk(Some("203.0.113.5:12345")), // public, different port
|
||||
];
|
||||
let (nt, _) = classify_nat(&probes);
|
||||
// Two public probes same IP different port → SymmetricPort.
|
||||
assert_eq!(nt, NatType::SymmetricPort);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_two_lan_probes_is_unknown_not_cone() {
|
||||
// Even if both probes come back from LAN relays, we can't
|
||||
// say anything useful about the public NAT state. Unknown,
|
||||
// not Cone.
|
||||
let probes = vec![
|
||||
mk(Some("192.168.1.100:4433")),
|
||||
mk(Some("192.168.1.100:4433")),
|
||||
];
|
||||
let (nt, addr) = classify_nat(&probes);
|
||||
assert_eq!(nt, NatType::Unknown);
|
||||
assert!(addr.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_mix_of_success_and_failure() {
|
||||
let probes = vec![
|
||||
mk(Some("192.0.2.1:4433")),
|
||||
mk(None), // failed probe
|
||||
mk(Some("192.0.2.1:4433")),
|
||||
];
|
||||
let (nt, addr) = classify_nat(&probes);
|
||||
// Two successes both agree → Cone, ignore the failure row.
|
||||
assert_eq!(nt, NatType::Cone);
|
||||
assert_eq!(addr.as_deref(), Some("192.0.2.1:4433"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn determine_role_smaller_is_acceptor() {
|
||||
// Lexicographic: "192.0.2.1:4433" < "198.51.100.9:4433"
|
||||
assert_eq!(
|
||||
determine_role(Some("192.0.2.1:4433"), Some("198.51.100.9:4433")),
|
||||
Some(Role::Acceptor)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn determine_role_larger_is_dialer() {
|
||||
assert_eq!(
|
||||
determine_role(Some("198.51.100.9:4433"), Some("192.0.2.1:4433")),
|
||||
Some(Role::Dialer)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn determine_role_port_difference_matters() {
|
||||
// Same ip, different ports — string compare still works
|
||||
// because "4433" < "54321".
|
||||
assert_eq!(
|
||||
determine_role(Some("127.0.0.1:4433"), Some("127.0.0.1:54321")),
|
||||
Some(Role::Acceptor)
|
||||
);
|
||||
assert_eq!(
|
||||
determine_role(Some("127.0.0.1:54321"), Some("127.0.0.1:4433")),
|
||||
Some(Role::Dialer)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn determine_role_equal_addrs_is_none() {
|
||||
assert_eq!(
|
||||
determine_role(Some("192.0.2.1:4433"), Some("192.0.2.1:4433")),
|
||||
None
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn determine_role_missing_side_is_none() {
|
||||
assert_eq!(determine_role(None, Some("192.0.2.1:4433")), None);
|
||||
assert_eq!(determine_role(Some("192.0.2.1:4433"), None), None);
|
||||
assert_eq!(determine_role(None, None), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn determine_role_is_symmetric_across_peers() {
|
||||
// Both peers compute roles independently; they must end
|
||||
// up with opposite assignments (one Acceptor, one Dialer)
|
||||
// so that each side ends up talking to the other.
|
||||
let a = "192.0.2.1:4433";
|
||||
let b = "198.51.100.9:4433";
|
||||
let alice_role = determine_role(Some(a), Some(b));
|
||||
let bob_role = determine_role(Some(b), Some(a));
|
||||
assert_eq!(alice_role, Some(Role::Acceptor));
|
||||
assert_eq!(bob_role, Some(Role::Dialer));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_one_success_one_failure_is_unknown() {
|
||||
let probes = vec![mk(Some("192.0.2.1:4433")), mk(None)];
|
||||
let (nt, addr) = classify_nat(&probes);
|
||||
assert_eq!(nt, NatType::Unknown);
|
||||
assert!(addr.is_none());
|
||||
}
|
||||
}
|
||||
337
crates/wzp-client/src/relay_map.rs
Normal file
337
crates/wzp-client/src/relay_map.rs
Normal file
@@ -0,0 +1,337 @@
|
||||
//! Phase 8 (Tailscale-inspired): Relay map for automatic relay
|
||||
//! selection based on latency.
|
||||
//!
|
||||
//! Maintains a sorted list of known relays with their measured
|
||||
//! latencies. Used during call setup to pick the lowest-latency
|
||||
//! relay, and by netcheck to report relay health.
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use serde::Serialize;
|
||||
|
||||
/// A known relay endpoint with measured latency.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct RelayEntry {
|
||||
/// Human-readable name (e.g., "us-east", "eu-west").
|
||||
pub name: String,
|
||||
/// Relay address.
|
||||
pub addr: SocketAddr,
|
||||
/// Geographic region (from RegisterPresenceAck).
|
||||
pub region: Option<String>,
|
||||
/// Last measured RTT (ms).
|
||||
pub rtt_ms: Option<u32>,
|
||||
/// When the RTT was last measured.
|
||||
#[serde(skip)]
|
||||
pub last_probed: Option<Instant>,
|
||||
/// Whether this relay is currently reachable.
|
||||
pub reachable: bool,
|
||||
}
|
||||
|
||||
/// Sorted relay map. Entries are ordered by RTT (lowest first).
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct RelayMap {
|
||||
entries: Vec<RelayEntry>,
|
||||
}
|
||||
|
||||
impl RelayMap {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
entries: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add or update a relay entry.
|
||||
pub fn upsert(&mut self, name: &str, addr: SocketAddr, region: Option<String>) {
|
||||
if let Some(entry) = self.entries.iter_mut().find(|e| e.addr == addr) {
|
||||
entry.name = name.to_string();
|
||||
if region.is_some() {
|
||||
entry.region = region;
|
||||
}
|
||||
} else {
|
||||
self.entries.push(RelayEntry {
|
||||
name: name.to_string(),
|
||||
addr,
|
||||
region,
|
||||
rtt_ms: None,
|
||||
last_probed: None,
|
||||
reachable: false,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// Update RTT measurement for a relay.
|
||||
pub fn update_rtt(&mut self, addr: SocketAddr, rtt_ms: u32) {
|
||||
if let Some(entry) = self.entries.iter_mut().find(|e| e.addr == addr) {
|
||||
entry.rtt_ms = Some(rtt_ms);
|
||||
entry.last_probed = Some(Instant::now());
|
||||
entry.reachable = true;
|
||||
}
|
||||
self.sort();
|
||||
}
|
||||
|
||||
/// Mark a relay as unreachable.
|
||||
pub fn mark_unreachable(&mut self, addr: SocketAddr) {
|
||||
if let Some(entry) = self.entries.iter_mut().find(|e| e.addr == addr) {
|
||||
entry.reachable = false;
|
||||
entry.last_probed = Some(Instant::now());
|
||||
}
|
||||
self.sort();
|
||||
}
|
||||
|
||||
/// Get the preferred (lowest-latency, reachable) relay.
|
||||
pub fn preferred(&self) -> Option<&RelayEntry> {
|
||||
self.entries
|
||||
.iter()
|
||||
.find(|e| e.reachable && e.rtt_ms.is_some())
|
||||
}
|
||||
|
||||
/// Get all entries, sorted by RTT.
|
||||
pub fn entries(&self) -> &[RelayEntry] {
|
||||
&self.entries
|
||||
}
|
||||
|
||||
/// Populate from a `RegisterPresenceAck.available_relays` list.
|
||||
/// Each entry is "name|addr" format.
|
||||
pub fn populate_from_ack(&mut self, relays: &[String], relay_region: Option<&str>) {
|
||||
for entry_str in relays {
|
||||
if let Some((name, addr_str)) = entry_str.split_once('|') {
|
||||
if let Ok(addr) = addr_str.parse::<SocketAddr>() {
|
||||
self.upsert(name, addr, None);
|
||||
}
|
||||
}
|
||||
}
|
||||
// If the ack included a region for the current relay, we
|
||||
// could tag it — but we'd need to know which relay we're
|
||||
// connected to. Left for the caller to handle.
|
||||
let _ = relay_region;
|
||||
}
|
||||
|
||||
/// Check if any entry has a stale probe (older than `max_age`).
|
||||
pub fn needs_reprobe(&self, max_age: Duration) -> bool {
|
||||
self.entries.iter().any(|e| match e.last_probed {
|
||||
None => true,
|
||||
Some(t) => t.elapsed() > max_age,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get entries that need reprobing.
|
||||
pub fn stale_entries(&self, max_age: Duration) -> Vec<(String, SocketAddr)> {
|
||||
self.entries
|
||||
.iter()
|
||||
.filter(|e| match e.last_probed {
|
||||
None => true,
|
||||
Some(t) => t.elapsed() > max_age,
|
||||
})
|
||||
.map(|e| (e.name.clone(), e.addr))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn sort(&mut self) {
|
||||
self.entries.sort_by_key(|e| {
|
||||
if e.reachable {
|
||||
e.rtt_ms.unwrap_or(u32::MAX)
|
||||
} else {
|
||||
u32::MAX
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// ── Tests ──────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn preferred_returns_lowest_rtt() {
|
||||
let mut map = RelayMap::new();
|
||||
let a1: SocketAddr = "10.0.0.1:4433".parse().unwrap();
|
||||
let a2: SocketAddr = "10.0.0.2:4433".parse().unwrap();
|
||||
let a3: SocketAddr = "10.0.0.3:4433".parse().unwrap();
|
||||
|
||||
map.upsert("slow", a1, None);
|
||||
map.upsert("fast", a2, None);
|
||||
map.upsert("mid", a3, None);
|
||||
|
||||
map.update_rtt(a1, 200);
|
||||
map.update_rtt(a2, 15);
|
||||
map.update_rtt(a3, 80);
|
||||
|
||||
let pref = map.preferred().unwrap();
|
||||
assert_eq!(pref.addr, a2);
|
||||
assert_eq!(pref.rtt_ms, Some(15));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unreachable_not_preferred() {
|
||||
let mut map = RelayMap::new();
|
||||
let a1: SocketAddr = "10.0.0.1:4433".parse().unwrap();
|
||||
let a2: SocketAddr = "10.0.0.2:4433".parse().unwrap();
|
||||
|
||||
map.upsert("fast-dead", a1, None);
|
||||
map.upsert("slow-alive", a2, None);
|
||||
|
||||
map.update_rtt(a1, 5);
|
||||
map.update_rtt(a2, 200);
|
||||
map.mark_unreachable(a1);
|
||||
|
||||
let pref = map.preferred().unwrap();
|
||||
assert_eq!(pref.addr, a2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn populate_from_ack() {
|
||||
let mut map = RelayMap::new();
|
||||
map.populate_from_ack(
|
||||
&[
|
||||
"us-east|203.0.113.5:4433".into(),
|
||||
"eu-west|198.51.100.9:4433".into(),
|
||||
],
|
||||
Some("us-east"),
|
||||
);
|
||||
assert_eq!(map.entries().len(), 2);
|
||||
assert_eq!(map.entries()[0].name, "us-east");
|
||||
assert_eq!(map.entries()[1].name, "eu-west");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn upsert_updates_existing() {
|
||||
let mut map = RelayMap::new();
|
||||
let addr: SocketAddr = "10.0.0.1:4433".parse().unwrap();
|
||||
map.upsert("old-name", addr, None);
|
||||
map.upsert("new-name", addr, Some("us-west".into()));
|
||||
assert_eq!(map.entries().len(), 1);
|
||||
assert_eq!(map.entries()[0].name, "new-name");
|
||||
assert_eq!(map.entries()[0].region, Some("us-west".into()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn upsert_preserves_region_when_none() {
|
||||
let mut map = RelayMap::new();
|
||||
let addr: SocketAddr = "10.0.0.1:4433".parse().unwrap();
|
||||
map.upsert("relay", addr, Some("eu-west".into()));
|
||||
map.upsert("relay", addr, None); // region is None
|
||||
// Should keep the original region
|
||||
assert_eq!(map.entries()[0].region, Some("eu-west".into()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn preferred_returns_none_on_empty() {
|
||||
let map = RelayMap::new();
|
||||
assert!(map.preferred().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn preferred_returns_none_when_all_unreachable() {
|
||||
let mut map = RelayMap::new();
|
||||
let addr: SocketAddr = "10.0.0.1:4433".parse().unwrap();
|
||||
map.upsert("relay", addr, None);
|
||||
// Not update_rtt'd, so reachable=false
|
||||
assert!(map.preferred().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn needs_reprobe_empty_is_false() {
|
||||
let map = RelayMap::new();
|
||||
// No entries → nothing to reprobe
|
||||
assert!(!map.needs_reprobe(Duration::from_secs(60)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn needs_reprobe_never_probed() {
|
||||
let mut map = RelayMap::new();
|
||||
map.upsert("relay", "10.0.0.1:4433".parse().unwrap(), None);
|
||||
assert!(map.needs_reprobe(Duration::from_secs(60)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn needs_reprobe_fresh_is_false() {
|
||||
let mut map = RelayMap::new();
|
||||
let addr: SocketAddr = "10.0.0.1:4433".parse().unwrap();
|
||||
map.upsert("relay", addr, None);
|
||||
map.update_rtt(addr, 50);
|
||||
// Just probed, so 60s max_age should not trigger
|
||||
assert!(!map.needs_reprobe(Duration::from_secs(60)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stale_entries_returns_unprobed() {
|
||||
let mut map = RelayMap::new();
|
||||
let a1: SocketAddr = "10.0.0.1:4433".parse().unwrap();
|
||||
let a2: SocketAddr = "10.0.0.2:4433".parse().unwrap();
|
||||
map.upsert("probed", a1, None);
|
||||
map.upsert("stale", a2, None);
|
||||
map.update_rtt(a1, 50);
|
||||
|
||||
let stale = map.stale_entries(Duration::from_secs(60));
|
||||
assert_eq!(stale.len(), 1);
|
||||
assert_eq!(stale[0].1, a2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sort_stability_with_equal_rtt() {
|
||||
let mut map = RelayMap::new();
|
||||
let a1: SocketAddr = "10.0.0.1:4433".parse().unwrap();
|
||||
let a2: SocketAddr = "10.0.0.2:4433".parse().unwrap();
|
||||
map.upsert("first", a1, None);
|
||||
map.upsert("second", a2, None);
|
||||
map.update_rtt(a1, 50);
|
||||
map.update_rtt(a2, 50);
|
||||
|
||||
// Both have same RTT — sort should be stable (insertion order)
|
||||
assert_eq!(map.entries().len(), 2);
|
||||
// Both are valid preferred relays
|
||||
assert!(map.preferred().is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn populate_from_ack_skips_malformed() {
|
||||
let mut map = RelayMap::new();
|
||||
map.populate_from_ack(
|
||||
&[
|
||||
"good|10.0.0.1:4433".into(),
|
||||
"no-pipe-separator".into(),
|
||||
"bad-addr|not-a-socket-addr".into(),
|
||||
"also-good|10.0.0.2:4433".into(),
|
||||
],
|
||||
None,
|
||||
);
|
||||
assert_eq!(map.entries().len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mark_unreachable_sorts_to_end() {
|
||||
let mut map = RelayMap::new();
|
||||
let a1: SocketAddr = "10.0.0.1:4433".parse().unwrap();
|
||||
let a2: SocketAddr = "10.0.0.2:4433".parse().unwrap();
|
||||
map.upsert("fast", a1, None);
|
||||
map.upsert("slow", a2, None);
|
||||
map.update_rtt(a1, 10);
|
||||
map.update_rtt(a2, 200);
|
||||
|
||||
assert_eq!(map.preferred().unwrap().addr, a1);
|
||||
|
||||
map.mark_unreachable(a1);
|
||||
assert_eq!(map.preferred().unwrap().addr, a2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn relay_entry_serializes() {
|
||||
let entry = RelayEntry {
|
||||
name: "test".into(),
|
||||
addr: "10.0.0.1:4433".parse().unwrap(),
|
||||
region: Some("us-east".into()),
|
||||
rtt_ms: Some(42),
|
||||
last_probed: Some(Instant::now()),
|
||||
reachable: true,
|
||||
};
|
||||
let json = serde_json::to_string(&entry).unwrap();
|
||||
assert!(json.contains("test"));
|
||||
assert!(json.contains("us-east"));
|
||||
assert!(json.contains("42"));
|
||||
// last_probed is #[serde(skip)]
|
||||
assert!(!json.contains("last_probed"));
|
||||
}
|
||||
}
|
||||
1445
crates/wzp-client/src/stun.rs
Normal file
1445
crates/wzp-client/src/stun.rs
Normal file
File diff suppressed because it is too large
Load Diff
@@ -72,8 +72,7 @@ fn sine_frame(freq_hz: f32, frame_offset: u64) -> Vec<i16> {
|
||||
/// decoder, pushes frames through the pipeline, and collects statistics.
|
||||
/// Combinations where `target_depth > max_depth` are skipped.
|
||||
pub fn run_local_sweep(config: &SweepConfig) -> Vec<SweepResult> {
|
||||
let frames_per_config =
|
||||
(config.test_duration_secs as u64) * (1000 / FRAME_DURATION_MS as u64);
|
||||
let frames_per_config = (config.test_duration_secs as u64) * (1000 / FRAME_DURATION_MS as u64);
|
||||
|
||||
let mut results = Vec::new();
|
||||
|
||||
|
||||
232
crates/wzp-client/tests/dual_path.rs
Normal file
232
crates/wzp-client/tests/dual_path.rs
Normal file
@@ -0,0 +1,232 @@
|
||||
//! Phase 3.5 integration tests for the dual-path QUIC race.
|
||||
//!
|
||||
//! The race takes a role (Acceptor or Dialer), a peer_direct_addr,
|
||||
//! a relay_addr, and two SNI strings, then returns whichever QUIC
|
||||
//! handshake completes first wrapped in a `QuinnTransport`. These
|
||||
//! tests validate that:
|
||||
//!
|
||||
//! 1. On loopback with two real clients playing A + D roles, the
|
||||
//! direct path wins (fewer hops than relay).
|
||||
//! 2. When the direct peer is dead (nothing listening) but the
|
||||
//! relay is up, the relay wins within the fallback window.
|
||||
//! 3. When both paths are dead, the race errors cleanly rather
|
||||
//! than hanging forever.
|
||||
//!
|
||||
//! The "relay" in these tests is a minimal mock that just accepts
|
||||
//! an incoming QUIC connection and drops it — we don't need any
|
||||
//! protocol handling, just a TCP-ish listen-and-accept.
|
||||
|
||||
use std::net::{Ipv4Addr, SocketAddr};
|
||||
use std::time::Duration;
|
||||
|
||||
use wzp_client::dual_path::{PeerCandidates, WinningPath, race};
|
||||
use wzp_client::reflect::Role;
|
||||
use wzp_transport::{create_endpoint, server_config};
|
||||
|
||||
/// Spin up a "relay-ish" mock server on loopback that accepts
|
||||
/// incoming QUIC connections and does nothing with them. Used to
|
||||
/// give the relay branch of the race a real target to dial.
|
||||
/// Returns the bound address + a join handle (kept alive to keep
|
||||
/// the endpoint up).
|
||||
async fn spawn_mock_relay() -> (SocketAddr, tokio::task::JoinHandle<()>) {
|
||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||
let (sc, _cert_der) = server_config();
|
||||
let bind: SocketAddr = (Ipv4Addr::LOCALHOST, 0).into();
|
||||
let ep = create_endpoint(bind, Some(sc)).expect("relay endpoint");
|
||||
let addr = ep.local_addr().expect("local_addr");
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
// Accept loop — hold the connection alive for a short
|
||||
// while so the race result isn't killed by the peer
|
||||
// closing before the winning transport is returned.
|
||||
while let Some(incoming) = ep.accept().await {
|
||||
if let Ok(_conn) = incoming.await {
|
||||
tokio::time::sleep(Duration::from_secs(5)).await;
|
||||
}
|
||||
}
|
||||
});
|
||||
(addr, handle)
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Test 1: direct path wins when both sides are up
|
||||
// -----------------------------------------------------------------------
|
||||
//
|
||||
// Spawn a mock relay, then set up a two-client test where one
|
||||
// client plays the Acceptor role and the other plays the Dialer
|
||||
// role. The Dialer's `peer_direct_addr` is the Acceptor's listen
|
||||
// address. Because the direct path is a single loopback hop and
|
||||
// the relay dial also terminates on loopback, both complete
|
||||
// essentially instantly — the `biased` tokio::select in race()
|
||||
// should pick direct.
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn dual_path_direct_wins_on_loopback() {
|
||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||
let (relay_addr, _relay_handle) = spawn_mock_relay().await;
|
||||
|
||||
// Acceptor task: run race(Role::Acceptor, peer_addr_placeholder, ...).
|
||||
// Since the acceptor doesn't dial, the peer_direct_addr arg is
|
||||
// unused on the direct branch but we still pass a placeholder
|
||||
// because the API takes one. Use a stub addr that would error
|
||||
// if it were ever dialed — proving the Acceptor really doesn't
|
||||
// reach it.
|
||||
let unused_addr: SocketAddr = "127.0.0.1:2".parse().unwrap();
|
||||
|
||||
// We can't race both sides in the same task because each race
|
||||
// call has its own direct endpoint that needs to talk to the
|
||||
// OTHER side's endpoint. So spawn the Acceptor in a task and
|
||||
// let it expose its listen addr via a oneshot back to the test,
|
||||
// then run the Dialer in the test's main task.
|
||||
//
|
||||
// There's a chicken-and-egg issue: the Acceptor's listen addr
|
||||
// is only known after race() creates its endpoint. To avoid
|
||||
// reaching into race()'s internals, we instead play a slight
|
||||
// trick: create the Acceptor's endpoint ourselves (outside
|
||||
// race()) to learn its addr, spin up an accept loop on it
|
||||
// ourselves, and pass THAT addr as the Dialer's peer addr.
|
||||
// This tests the Dialer->Acceptor handshake end-to-end without
|
||||
// running the full race() on both sides.
|
||||
|
||||
let (sc, _cert_der) = server_config();
|
||||
let acceptor_bind: SocketAddr = (Ipv4Addr::LOCALHOST, 0).into();
|
||||
let acceptor_ep = create_endpoint(acceptor_bind, Some(sc)).expect("acceptor ep");
|
||||
let acceptor_listen_addr = acceptor_ep.local_addr().expect("acceptor addr");
|
||||
|
||||
// Drop the external acceptor after the test finishes, not
|
||||
// before — spawn a dedicated accept task.
|
||||
let acceptor_accept_task = tokio::spawn(async move {
|
||||
// Accept one connection and hold it for a while so the
|
||||
// Dialer side can complete its QUIC handshake.
|
||||
if let Some(incoming) = acceptor_ep.accept().await {
|
||||
if let Ok(_conn) = incoming.await {
|
||||
tokio::time::sleep(Duration::from_secs(5)).await;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Now run the Dialer in the race — peer_direct_addr = acceptor's
|
||||
// listen addr. The relay is the mock from above. Direct path
|
||||
// should win.
|
||||
let result = race(
|
||||
Role::Dialer,
|
||||
PeerCandidates {
|
||||
reflexive: Some(acceptor_listen_addr),
|
||||
local: Vec::new(),
|
||||
mapped: None,
|
||||
},
|
||||
relay_addr,
|
||||
"test-room".into(),
|
||||
"call-test".into(),
|
||||
None, // own_reflexive: not needed in tests
|
||||
None, // Phase 5: tests use fresh endpoints (no shared signal)
|
||||
None, // Phase 7: no IPv6 endpoint in tests
|
||||
)
|
||||
.await
|
||||
.expect("race must succeed");
|
||||
|
||||
assert!(
|
||||
result.direct_transport.is_some(),
|
||||
"direct transport should be available"
|
||||
);
|
||||
assert_eq!(
|
||||
result.local_winner,
|
||||
WinningPath::Direct,
|
||||
"direct should win on loopback"
|
||||
);
|
||||
|
||||
// Cancel the acceptor accept task so the test finishes.
|
||||
acceptor_accept_task.abort();
|
||||
// Suppress unused-var warning for the placeholder.
|
||||
let _ = unused_addr;
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Test 2: relay wins when the direct peer is dead
|
||||
// -----------------------------------------------------------------------
|
||||
//
|
||||
// Dialer role, peer_direct_addr = a port nothing is listening on,
|
||||
// relay is the working mock. Direct dial will sit waiting for a
|
||||
// QUIC handshake that never comes; the 2s direct timeout kicks in
|
||||
// and the relay path wins the fallback.
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn dual_path_relay_wins_when_direct_is_dead() {
|
||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||
let (relay_addr, _relay_handle) = spawn_mock_relay().await;
|
||||
|
||||
// A port that nothing is listening on — dead direct target.
|
||||
// Port 1 on loopback is almost never bound and UDP packets to
|
||||
// it will be dropped silently, so the QUIC handshake times out.
|
||||
let dead_peer: SocketAddr = "127.0.0.1:1".parse().unwrap();
|
||||
|
||||
let result = race(
|
||||
Role::Dialer,
|
||||
PeerCandidates {
|
||||
reflexive: Some(dead_peer),
|
||||
local: Vec::new(),
|
||||
mapped: None,
|
||||
},
|
||||
relay_addr,
|
||||
"test-room".into(),
|
||||
"call-test".into(),
|
||||
None, // own_reflexive: not needed in tests
|
||||
None, // Phase 5: tests use fresh endpoints (no shared signal)
|
||||
None, // Phase 7: no IPv6 endpoint in tests
|
||||
)
|
||||
.await
|
||||
.expect("race must succeed via relay fallback");
|
||||
|
||||
assert!(
|
||||
result.relay_transport.is_some(),
|
||||
"relay transport should be available"
|
||||
);
|
||||
assert_eq!(
|
||||
result.local_winner,
|
||||
WinningPath::Relay,
|
||||
"relay should win when direct dial has nowhere to land"
|
||||
);
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Test 3: race errors cleanly when both paths are dead
|
||||
// -----------------------------------------------------------------------
|
||||
//
|
||||
// Dialer role, peer_direct_addr = dead, relay_addr = dead.
|
||||
// Expected: race returns an Err within ~7s (2s direct timeout +
|
||||
// 5s relay timeout fallback).
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn dual_path_errors_cleanly_when_both_paths_dead() {
|
||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||
|
||||
let dead_peer: SocketAddr = "127.0.0.1:1".parse().unwrap();
|
||||
let dead_relay: SocketAddr = "127.0.0.1:2".parse().unwrap();
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let result = race(
|
||||
Role::Dialer,
|
||||
PeerCandidates {
|
||||
reflexive: Some(dead_peer),
|
||||
local: Vec::new(),
|
||||
mapped: None,
|
||||
},
|
||||
dead_relay,
|
||||
"test-room".into(),
|
||||
"call-test".into(),
|
||||
None, // own_reflexive: not needed in tests
|
||||
None, // Phase 5: tests use fresh endpoints (no shared signal)
|
||||
None, // Phase 7: no IPv6 endpoint in tests
|
||||
)
|
||||
.await;
|
||||
let elapsed = start.elapsed();
|
||||
|
||||
assert!(result.is_err(), "both-dead must return Err");
|
||||
// Upper bound: direct 2s timeout + relay 5s fallback + small
|
||||
// slack for scheduling. If this blows, something is looping.
|
||||
assert!(
|
||||
elapsed < Duration::from_secs(10),
|
||||
"race took too long to give up: {:?}",
|
||||
elapsed
|
||||
);
|
||||
}
|
||||
@@ -6,12 +6,12 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use wzp_proto::packet::MediaPacket;
|
||||
use wzp_proto::traits::{MediaTransport, PathQuality};
|
||||
use wzp_proto::{SignalMessage, TransportError};
|
||||
use wzp_proto::{SignalMessage, TransportError, default_signal_version};
|
||||
|
||||
/// A mock transport backed by two mpsc channels (one per direction).
|
||||
///
|
||||
@@ -83,43 +83,68 @@ async fn full_handshake_both_sides_derive_same_session() {
|
||||
|
||||
// Run client and relay handshakes concurrently.
|
||||
let (client_result, relay_result) = tokio::join!(
|
||||
wzp_client::handshake::perform_handshake(client_transport_clone.as_ref(), &client_seed),
|
||||
wzp_client::handshake::perform_handshake(
|
||||
client_transport_clone.as_ref(),
|
||||
&client_seed,
|
||||
None
|
||||
),
|
||||
wzp_relay::handshake::accept_handshake(relay_transport_clone.as_ref(), &relay_seed),
|
||||
);
|
||||
|
||||
let mut client_session = client_result.expect("client handshake should succeed");
|
||||
let (mut relay_session, chosen_profile) =
|
||||
let (mut relay_session, chosen_profile, _caller_fp, _caller_alias) =
|
||||
relay_result.expect("relay handshake should succeed");
|
||||
|
||||
// Verify a profile was chosen.
|
||||
assert_eq!(chosen_profile, wzp_proto::QualityProfile::GOOD);
|
||||
|
||||
// Verify both sides can communicate: client encrypts, relay decrypts.
|
||||
let header = b"test-header";
|
||||
// encrypt/decrypt derive nonces from MediaHeader.seq, so we need valid headers.
|
||||
use wzp_proto::packet::MediaHeader;
|
||||
use wzp_proto::{CodecId, MediaType};
|
||||
let make_hdr = |seq: u32| {
|
||||
let h = MediaHeader {
|
||||
version: 2,
|
||||
flags: 0,
|
||||
media_type: MediaType::Audio,
|
||||
codec_id: CodecId::Opus24k,
|
||||
stream_id: 0,
|
||||
fec_ratio: 0,
|
||||
seq,
|
||||
timestamp: seq.wrapping_mul(20),
|
||||
fec_block: 0,
|
||||
};
|
||||
let mut b = Vec::new();
|
||||
h.write_to(&mut b);
|
||||
b
|
||||
};
|
||||
|
||||
let header = make_hdr(0);
|
||||
let plaintext = b"hello from client to relay";
|
||||
|
||||
let mut ciphertext = Vec::new();
|
||||
client_session
|
||||
.encrypt(header, plaintext, &mut ciphertext)
|
||||
.encrypt(&header, plaintext, &mut ciphertext)
|
||||
.expect("client encrypt should succeed");
|
||||
|
||||
let mut decrypted = Vec::new();
|
||||
relay_session
|
||||
.decrypt(header, &ciphertext, &mut decrypted)
|
||||
.decrypt(&header, &ciphertext, &mut decrypted)
|
||||
.expect("relay decrypt should succeed");
|
||||
|
||||
assert_eq!(&decrypted[..], plaintext);
|
||||
|
||||
// Verify reverse direction: relay encrypts, client decrypts.
|
||||
let header2 = make_hdr(0); // relay's send_seq starts at 0
|
||||
let plaintext2 = b"hello from relay to client";
|
||||
let mut ciphertext2 = Vec::new();
|
||||
relay_session
|
||||
.encrypt(header, plaintext2, &mut ciphertext2)
|
||||
.encrypt(&header2, plaintext2, &mut ciphertext2)
|
||||
.expect("relay encrypt should succeed");
|
||||
|
||||
let mut decrypted2 = Vec::new();
|
||||
client_session
|
||||
.decrypt(header, &ciphertext2, &mut decrypted2)
|
||||
.decrypt(&header2, &ciphertext2, &mut decrypted2)
|
||||
.expect("client decrypt should succeed");
|
||||
|
||||
assert_eq!(&decrypted2[..], plaintext2);
|
||||
@@ -147,10 +172,14 @@ async fn handshake_rejects_tampered_signature() {
|
||||
let bad_signature = kx.sign(b"wrong-data-intentionally");
|
||||
|
||||
let offer = SignalMessage::CallOffer {
|
||||
version: default_signal_version(),
|
||||
identity_pub,
|
||||
ephemeral_pub,
|
||||
signature: bad_signature,
|
||||
supported_profiles: vec![wzp_proto::QualityProfile::GOOD],
|
||||
alias: None,
|
||||
protocol_version: 2,
|
||||
supported_versions: vec![2],
|
||||
};
|
||||
client_transport_clone
|
||||
.send_signal(&offer)
|
||||
@@ -174,3 +203,42 @@ async fn handshake_rejects_tampered_signature() {
|
||||
Ok(_) => panic!("relay should reject tampered signature"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn client_receives_protocol_version_mismatch() {
|
||||
let (client_transport, relay_transport) = MockTransport::pair();
|
||||
|
||||
let client_seed = [0xAA_u8; 32];
|
||||
|
||||
// Spawn a fake relay that sends ProtocolVersionMismatch.
|
||||
let relay_clone = Arc::clone(&relay_transport);
|
||||
tokio::spawn(async move {
|
||||
// Wait for the client's CallOffer.
|
||||
let offer = relay_clone.recv_signal().await.unwrap().unwrap();
|
||||
assert!(matches!(offer, SignalMessage::CallOffer { .. }));
|
||||
|
||||
// Respond with ProtocolVersionMismatch.
|
||||
let mismatch = SignalMessage::Hangup {
|
||||
version: default_signal_version(),
|
||||
reason: wzp_proto::HangupReason::ProtocolVersionMismatch {
|
||||
server_supported: vec![3],
|
||||
},
|
||||
call_id: None,
|
||||
};
|
||||
relay_clone.send_signal(&mismatch).await.unwrap();
|
||||
});
|
||||
|
||||
let result =
|
||||
wzp_client::handshake::perform_handshake(client_transport.as_ref(), &client_seed, None)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Err(wzp_client::handshake::HandshakeError::ProtocolVersionMismatch {
|
||||
server_supported,
|
||||
}) => {
|
||||
assert_eq!(server_supported, vec![3]);
|
||||
}
|
||||
Err(other) => panic!("expected ProtocolVersionMismatch, got: {other:?}"),
|
||||
Ok(_) => panic!("expected handshake to fail with ProtocolVersionMismatch"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -83,8 +83,12 @@ fn long_session_no_drift() {
|
||||
println!(
|
||||
"long_session_no_drift: decoded={frames_decoded}/{TOTAL_FRAMES}, \
|
||||
underruns={}, overruns={}, depth={}, max_depth={}, late={}, lost={}",
|
||||
stats.underruns, stats.overruns, stats.current_depth, stats.max_depth_seen,
|
||||
stats.packets_late, stats.packets_lost,
|
||||
stats.underruns,
|
||||
stats.overruns,
|
||||
stats.current_depth,
|
||||
stats.max_depth_seen,
|
||||
stats.packets_late,
|
||||
stats.packets_lost,
|
||||
);
|
||||
|
||||
// With 1 decode per tick over 3000 ticks, we expect ~3000 decoded frames
|
||||
@@ -123,7 +127,7 @@ fn long_session_with_simulated_loss() {
|
||||
|
||||
for (j, pkt) in batch.into_iter().enumerate() {
|
||||
// Drop every 20th *source* (non-repair) packet to simulate ~5% loss.
|
||||
if !pkt.header.is_repair && i % 20 == 0 && j == 0 {
|
||||
if !pkt.header.is_repair() && i % 20 == 0 && j == 0 {
|
||||
continue; // drop this packet
|
||||
}
|
||||
decoder.ingest(pkt);
|
||||
@@ -139,8 +143,12 @@ fn long_session_with_simulated_loss() {
|
||||
println!(
|
||||
"long_session_with_simulated_loss: decoded={frames_decoded}/{TOTAL_FRAMES}, \
|
||||
underruns={}, overruns={}, depth={}, max_depth={}, late={}, lost={}",
|
||||
stats.underruns, stats.overruns, stats.current_depth, stats.max_depth_seen,
|
||||
stats.packets_late, stats.packets_lost,
|
||||
stats.underruns,
|
||||
stats.overruns,
|
||||
stats.current_depth,
|
||||
stats.max_depth_seen,
|
||||
stats.packets_late,
|
||||
stats.packets_lost,
|
||||
);
|
||||
|
||||
// With 5% artificial loss + FEC recovery + PLC, we should still get >90% decoded.
|
||||
@@ -150,6 +158,65 @@ fn long_session_with_simulated_loss() {
|
||||
);
|
||||
}
|
||||
|
||||
/// Verify that `MediaHeader::timestamp` continues monotonically across
|
||||
/// rekey boundaries. Rekey is a crypto-layer operation (key material
|
||||
/// rotation) and must not reset or interfere with framing state.
|
||||
///
|
||||
/// We simulate a 3000-frame session with two conceptual rekeys at frames
|
||||
/// 1000 and 2000. The encoder's timestamp counter must advance
|
||||
/// monotonically throughout.
|
||||
#[test]
|
||||
fn rekey_timestamp_monotonic() {
|
||||
let config = test_config();
|
||||
let mut encoder = CallEncoder::new(&config);
|
||||
|
||||
let mut timestamps = Vec::new();
|
||||
|
||||
// Phase 1: before first rekey
|
||||
for i in 0..1000 {
|
||||
let pcm = sine_frame(i);
|
||||
let packets = encoder.encode_frame(&pcm).expect("encode");
|
||||
for pkt in packets {
|
||||
timestamps.push(pkt.header.timestamp);
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 2: between first and second rekey
|
||||
for i in 1000..2000 {
|
||||
let pcm = sine_frame(i);
|
||||
let packets = encoder.encode_frame(&pcm).expect("encode");
|
||||
for pkt in packets {
|
||||
timestamps.push(pkt.header.timestamp);
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 3: after second rekey
|
||||
for i in 2000..3000 {
|
||||
let pcm = sine_frame(i);
|
||||
let packets = encoder.encode_frame(&pcm).expect("encode");
|
||||
for pkt in packets {
|
||||
timestamps.push(pkt.header.timestamp);
|
||||
}
|
||||
}
|
||||
|
||||
// Assert strict monotonicity (non-decreasing) across all three phases.
|
||||
for window in timestamps.windows(2) {
|
||||
assert!(
|
||||
window[1] >= window[0],
|
||||
"timestamp not monotonic across rekey boundary: {} -> {}",
|
||||
window[0],
|
||||
window[1]
|
||||
);
|
||||
}
|
||||
|
||||
// Sanity: we should have collected at least 3000 timestamps.
|
||||
assert!(
|
||||
timestamps.len() >= 3000,
|
||||
"expected >= 3000 timestamps, got {}",
|
||||
timestamps.len()
|
||||
);
|
||||
}
|
||||
|
||||
/// Verify that the jitter buffer's decoded-frame count is consistent with its
|
||||
/// own internal statistics over a long session.
|
||||
#[test]
|
||||
|
||||
@@ -10,8 +10,17 @@ description = "WarzonePhone audio codec layer — Opus + Codec2 encoding/decodin
|
||||
wzp-proto = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
|
||||
# Opus bindings
|
||||
audiopus = { workspace = true }
|
||||
# Opus bindings — libopus 1.5.2.
|
||||
# opusic-c for the encoder (set_dred_duration lives here in Phase 1).
|
||||
# opusic-sys for the decoder — we wrap the raw *mut OpusDecoder ourselves
|
||||
# because opusic-c::Decoder.inner is pub(crate), blocking the unified
|
||||
# decoder + DRED path we need in Phase 3.
|
||||
opusic-c = { workspace = true }
|
||||
opusic-sys = { workspace = true }
|
||||
|
||||
# Zero-cost slice reinterpretation for the i16 ↔ u16 boundary between
|
||||
# our PCM buffers and opusic-c's encode API.
|
||||
bytemuck = { workspace = true }
|
||||
|
||||
# Pure-Rust Codec2 implementation
|
||||
codec2 = { workspace = true }
|
||||
|
||||
@@ -116,6 +116,14 @@ impl AudioEncoder for AdaptiveEncoder {
|
||||
fn set_dtx(&mut self, enabled: bool) {
|
||||
self.opus.set_dtx(enabled);
|
||||
}
|
||||
|
||||
fn set_expected_loss(&mut self, loss_pct: u8) {
|
||||
self.opus.set_expected_loss(loss_pct);
|
||||
}
|
||||
|
||||
fn set_dred_duration(&mut self, frames: u8) {
|
||||
self.opus.set_dred_duration(frames);
|
||||
}
|
||||
}
|
||||
|
||||
// ─── AdaptiveDecoder ─────────────────────────────────────────────────────────
|
||||
@@ -199,6 +207,27 @@ impl AdaptiveDecoder {
|
||||
fn codec2_frame_samples(&self) -> usize {
|
||||
self.codec2.frame_samples()
|
||||
}
|
||||
|
||||
/// Reconstruct a lost frame from a previously parsed DRED state.
|
||||
///
|
||||
/// Phase 3b entry point for gap reconstruction. Dispatches to the
|
||||
/// inner Opus decoder when active. Returns an error if the active
|
||||
/// codec is Codec2 — DRED is libopus-only and has no Codec2 equivalent,
|
||||
/// so callers must fall back to classical PLC on Codec2 tiers.
|
||||
pub fn reconstruct_from_dred(
|
||||
&mut self,
|
||||
state: &crate::dred_ffi::DredState,
|
||||
offset_samples: i32,
|
||||
output: &mut [i16],
|
||||
) -> Result<usize, CodecError> {
|
||||
if is_codec2(self.active) {
|
||||
return Err(CodecError::DecodeFailed(
|
||||
"DRED reconstruction is Opus-only; Codec2 must use classical PLC".into(),
|
||||
));
|
||||
}
|
||||
self.opus
|
||||
.reconstruct_from_dred(state, offset_samples, output)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Tests ───────────────────────────────────────────────────────────────────
|
||||
|
||||
@@ -1,53 +1,127 @@
|
||||
//! Acoustic Echo Cancellation using NLMS adaptive filter.
|
||||
//! Processes 480-sample (10ms) sub-frames at 48kHz.
|
||||
//! Acoustic Echo Cancellation — delay-compensated leaky NLMS with
|
||||
//! Geigel double-talk detection.
|
||||
//!
|
||||
//! Key insight: on a laptop, the round-trip audio latency (playout → speaker
|
||||
//! → air → mic → capture) is 30–50ms. The far-end reference must be delayed
|
||||
//! by this amount so the adaptive filter models the *echo path*, not the
|
||||
//! *system delay + echo path*.
|
||||
//!
|
||||
//! The leaky coefficient decay prevents the filter from diverging when the
|
||||
//! echo path changes (e.g. hand near laptop) or when the delay estimate
|
||||
//! is slightly off.
|
||||
|
||||
/// NLMS (Normalized Least Mean Squares) adaptive filter echo canceller.
|
||||
///
|
||||
/// Removes acoustic echo by modelling the echo path between the far-end
|
||||
/// (speaker) signal and the near-end (microphone) signal, then subtracting
|
||||
/// the estimated echo from the near-end in real time.
|
||||
/// Delay-compensated leaky NLMS echo canceller with Geigel DTD.
|
||||
pub struct EchoCanceller {
|
||||
filter_coeffs: Vec<f32>,
|
||||
// --- Adaptive filter ---
|
||||
filter: Vec<f32>,
|
||||
filter_len: usize,
|
||||
far_end_buf: Vec<f32>,
|
||||
far_end_pos: usize,
|
||||
/// Circular buffer of far-end reference samples (after delay).
|
||||
far_buf: Vec<f32>,
|
||||
far_pos: usize,
|
||||
/// NLMS step size.
|
||||
mu: f32,
|
||||
/// Leakage factor: coefficients are multiplied by (1 - leak) each frame.
|
||||
/// Prevents unbounded growth / divergence. 0.0001 is gentle.
|
||||
leak: f32,
|
||||
enabled: bool,
|
||||
|
||||
// --- Delay buffer ---
|
||||
/// Raw far-end samples before delay compensation.
|
||||
delay_ring: Vec<f32>,
|
||||
delay_write: usize,
|
||||
delay_read: usize,
|
||||
/// Delay in samples (e.g. 1920 = 40ms at 48kHz).
|
||||
delay_samples: usize,
|
||||
/// Capacity of the delay ring.
|
||||
delay_cap: usize,
|
||||
|
||||
// --- Double-talk detection (Geigel) ---
|
||||
/// Peak far-end level over the last filter_len samples.
|
||||
far_peak: f32,
|
||||
/// Geigel threshold: if |near| > threshold * far_peak, assume double-talk.
|
||||
geigel_threshold: f32,
|
||||
/// Holdover counter: keep DTD active for a few frames after detection.
|
||||
dtd_holdover: u32,
|
||||
dtd_hold_frames: u32,
|
||||
}
|
||||
|
||||
impl EchoCanceller {
|
||||
/// Create a new echo canceller.
|
||||
///
|
||||
/// * `sample_rate` — typically 48000
|
||||
/// * `filter_ms` — echo-tail length in milliseconds (e.g. 100 for 100 ms)
|
||||
/// * `filter_ms` — echo-tail length in milliseconds (60ms recommended)
|
||||
/// * `delay_ms` — far-end delay compensation in milliseconds (40ms for laptops)
|
||||
pub fn new(sample_rate: u32, filter_ms: u32) -> Self {
|
||||
Self::with_delay(sample_rate, filter_ms, 40)
|
||||
}
|
||||
|
||||
pub fn with_delay(sample_rate: u32, filter_ms: u32, delay_ms: u32) -> Self {
|
||||
let filter_len = (sample_rate as usize) * (filter_ms as usize) / 1000;
|
||||
let delay_samples = (sample_rate as usize) * (delay_ms as usize) / 1000;
|
||||
// Delay ring must hold at least delay_samples + one frame (960) of headroom.
|
||||
let delay_cap = delay_samples + (sample_rate as usize / 10); // +100ms headroom
|
||||
Self {
|
||||
filter_coeffs: vec![0.0f32; filter_len],
|
||||
filter: vec![0.0; filter_len],
|
||||
filter_len,
|
||||
far_end_buf: vec![0.0f32; filter_len],
|
||||
far_end_pos: 0,
|
||||
far_buf: vec![0.0; filter_len],
|
||||
far_pos: 0,
|
||||
mu: 0.01,
|
||||
leak: 0.0001,
|
||||
enabled: true,
|
||||
|
||||
delay_ring: vec![0.0; delay_cap],
|
||||
delay_write: 0,
|
||||
delay_read: 0,
|
||||
delay_samples,
|
||||
delay_cap,
|
||||
|
||||
far_peak: 0.0,
|
||||
geigel_threshold: 0.7,
|
||||
dtd_holdover: 0,
|
||||
dtd_hold_frames: 5,
|
||||
}
|
||||
}
|
||||
|
||||
/// Feed far-end (speaker/playback) samples into the circular buffer.
|
||||
///
|
||||
/// Must be called with the audio that was played out through the speaker
|
||||
/// *before* the corresponding near-end frame is processed.
|
||||
/// Feed far-end (speaker) samples. These go into the delay buffer first;
|
||||
/// once enough samples have accumulated, they are released to the filter's
|
||||
/// circular buffer with the correct delay offset.
|
||||
pub fn feed_farend(&mut self, farend: &[i16]) {
|
||||
// Write raw samples into the delay ring.
|
||||
for &s in farend {
|
||||
self.far_end_buf[self.far_end_pos] = s as f32;
|
||||
self.far_end_pos = (self.far_end_pos + 1) % self.filter_len;
|
||||
self.delay_ring[self.delay_write % self.delay_cap] = s as f32;
|
||||
self.delay_write += 1;
|
||||
}
|
||||
|
||||
// Release delayed samples to the filter's far-end buffer.
|
||||
while self.delay_available() >= 1 {
|
||||
let sample = self.delay_ring[self.delay_read % self.delay_cap];
|
||||
self.delay_read += 1;
|
||||
|
||||
self.far_buf[self.far_pos] = sample;
|
||||
self.far_pos = (self.far_pos + 1) % self.filter_len;
|
||||
|
||||
// Track peak far-end level for Geigel DTD.
|
||||
let abs_s = sample.abs();
|
||||
if abs_s > self.far_peak {
|
||||
self.far_peak = abs_s;
|
||||
}
|
||||
}
|
||||
|
||||
// Decay far_peak slowly (avoids stale peak from a loud burst long ago).
|
||||
self.far_peak *= 0.9995;
|
||||
}
|
||||
|
||||
/// Number of delayed samples available to release.
|
||||
fn delay_available(&self) -> usize {
|
||||
let buffered = self.delay_write - self.delay_read;
|
||||
if buffered > self.delay_samples {
|
||||
buffered - self.delay_samples
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
/// Process a near-end (microphone) frame, removing the estimated echo.
|
||||
///
|
||||
/// Returns the echo-return-loss enhancement (ERLE) as a ratio: the RMS of
|
||||
/// the original near-end divided by the RMS of the residual. Values > 1.0
|
||||
/// mean echo was reduced.
|
||||
pub fn process_frame(&mut self, nearend: &mut [i16]) -> f32 {
|
||||
if !self.enabled {
|
||||
return 1.0;
|
||||
@@ -56,85 +130,96 @@ impl EchoCanceller {
|
||||
let n = nearend.len();
|
||||
let fl = self.filter_len;
|
||||
|
||||
// --- Geigel double-talk detection ---
|
||||
// If any near-end sample exceeds threshold * far_peak, assume
|
||||
// the local speaker is active and freeze adaptation.
|
||||
let mut is_doubletalk = self.dtd_holdover > 0;
|
||||
if !is_doubletalk {
|
||||
let threshold_level = self.geigel_threshold * self.far_peak;
|
||||
for &s in nearend.iter() {
|
||||
if (s as f32).abs() > threshold_level && self.far_peak > 100.0 {
|
||||
is_doubletalk = true;
|
||||
self.dtd_holdover = self.dtd_hold_frames;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if self.dtd_holdover > 0 {
|
||||
self.dtd_holdover -= 1;
|
||||
}
|
||||
|
||||
// Check if far-end is active (otherwise nothing to cancel).
|
||||
let far_active = self.far_peak > 100.0;
|
||||
|
||||
// --- Leaky coefficient decay ---
|
||||
// Applied once per frame for efficiency.
|
||||
let decay = 1.0 - self.leak;
|
||||
for c in self.filter.iter_mut() {
|
||||
*c *= decay;
|
||||
}
|
||||
|
||||
let mut sum_near_sq: f64 = 0.0;
|
||||
let mut sum_err_sq: f64 = 0.0;
|
||||
|
||||
for i in 0..n {
|
||||
let near_f = nearend[i] as f32;
|
||||
|
||||
// --- estimate echo as dot(coeffs, farend_window) ---
|
||||
// The far-end window for this sample starts at
|
||||
// (far_end_pos - 1 - i) mod filter_len (most recent)
|
||||
// and goes back filter_len samples.
|
||||
// Position of far-end "now" for this near-end sample.
|
||||
let base = (self.far_pos + fl * ((n / fl) + 2) + i - n) % fl;
|
||||
|
||||
// --- Echo estimation: dot(filter, far_end_window) ---
|
||||
let mut echo_est: f32 = 0.0;
|
||||
let mut power: f32 = 0.0;
|
||||
|
||||
// Position of the most-recent far-end sample for this near-end sample.
|
||||
// far_end_pos points to the *next write* position, so the most-recent
|
||||
// sample written is at far_end_pos - 1. We have already called
|
||||
// feed_farend for this block, so the relevant samples are the last
|
||||
// filter_len entries ending just before the current write position,
|
||||
// offset by how far we are into this near-end frame.
|
||||
//
|
||||
// For sample i of the near-end frame, the corresponding far-end
|
||||
// "now" is far_end_pos - n + i (wrapping).
|
||||
// far_end_pos points to next-write, so most recent sample is at
|
||||
// far_end_pos - 1. For the i-th near-end sample we want the
|
||||
// far-end "now" to be at (far_end_pos - n + i). We add fl
|
||||
// repeatedly to avoid underflow on the usize subtraction.
|
||||
let base = (self.far_end_pos + fl * ((n / fl) + 2) + i - n) % fl;
|
||||
|
||||
for k in 0..fl {
|
||||
let fe_idx = (base + fl - k) % fl;
|
||||
let fe = self.far_end_buf[fe_idx];
|
||||
echo_est += self.filter_coeffs[k] * fe;
|
||||
let fe = self.far_buf[fe_idx];
|
||||
echo_est += self.filter[k] * fe;
|
||||
power += fe * fe;
|
||||
}
|
||||
|
||||
let error = near_f - echo_est;
|
||||
|
||||
// --- NLMS coefficient update ---
|
||||
let norm = power + 1.0; // +1 regularisation to avoid div-by-zero
|
||||
let step = self.mu * error / norm;
|
||||
|
||||
for k in 0..fl {
|
||||
let fe_idx = (base + fl - k) % fl;
|
||||
let fe = self.far_end_buf[fe_idx];
|
||||
self.filter_coeffs[k] += step * fe;
|
||||
// --- NLMS adaptation (only when far-end active & no double-talk) ---
|
||||
if far_active && !is_doubletalk && power > 10.0 {
|
||||
let step = self.mu * error / (power + 1.0);
|
||||
for k in 0..fl {
|
||||
let fe_idx = (base + fl - k) % fl;
|
||||
self.filter[k] += step * self.far_buf[fe_idx];
|
||||
}
|
||||
}
|
||||
|
||||
// Clamp output
|
||||
let out = error.max(-32768.0).min(32767.0);
|
||||
let out = error.clamp(-32768.0, 32767.0);
|
||||
nearend[i] = out as i16;
|
||||
|
||||
sum_near_sq += (near_f as f64) * (near_f as f64);
|
||||
sum_err_sq += (out as f64) * (out as f64);
|
||||
sum_near_sq += (near_f as f64).powi(2);
|
||||
sum_err_sq += (out as f64).powi(2);
|
||||
}
|
||||
|
||||
// ERLE ratio
|
||||
if sum_err_sq < 1.0 {
|
||||
return 100.0; // near-perfect cancellation
|
||||
100.0
|
||||
} else {
|
||||
(sum_near_sq / sum_err_sq).sqrt() as f32
|
||||
}
|
||||
(sum_near_sq / sum_err_sq).sqrt() as f32
|
||||
}
|
||||
|
||||
/// Enable or disable echo cancellation.
|
||||
pub fn set_enabled(&mut self, enabled: bool) {
|
||||
self.enabled = enabled;
|
||||
}
|
||||
|
||||
/// Returns whether echo cancellation is currently enabled.
|
||||
pub fn is_enabled(&self) -> bool {
|
||||
self.enabled
|
||||
}
|
||||
|
||||
/// Reset the adaptive filter to its initial state.
|
||||
///
|
||||
/// Zeroes out all filter coefficients and the far-end circular buffer.
|
||||
pub fn reset(&mut self) {
|
||||
self.filter_coeffs.iter_mut().for_each(|c| *c = 0.0);
|
||||
self.far_end_buf.iter_mut().for_each(|s| *s = 0.0);
|
||||
self.far_end_pos = 0;
|
||||
self.filter.iter_mut().for_each(|c| *c = 0.0);
|
||||
self.far_buf.iter_mut().for_each(|s| *s = 0.0);
|
||||
self.far_pos = 0;
|
||||
self.far_peak = 0.0;
|
||||
self.delay_ring.iter_mut().for_each(|s| *s = 0.0);
|
||||
self.delay_write = 0;
|
||||
self.delay_read = 0;
|
||||
self.dtd_holdover = 0;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -143,50 +228,40 @@ mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn aec_creates_with_correct_filter_len() {
|
||||
let aec = EchoCanceller::new(48000, 100);
|
||||
assert_eq!(aec.filter_len, 4800);
|
||||
assert_eq!(aec.filter_coeffs.len(), 4800);
|
||||
assert_eq!(aec.far_end_buf.len(), 4800);
|
||||
fn creates_with_correct_sizes() {
|
||||
let aec = EchoCanceller::with_delay(48000, 60, 40);
|
||||
assert_eq!(aec.filter_len, 2880); // 60ms @ 48kHz
|
||||
assert_eq!(aec.delay_samples, 1920); // 40ms @ 48kHz
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn aec_passthrough_when_disabled() {
|
||||
let mut aec = EchoCanceller::new(48000, 100);
|
||||
fn passthrough_when_disabled() {
|
||||
let mut aec = EchoCanceller::new(48000, 60);
|
||||
aec.set_enabled(false);
|
||||
assert!(!aec.is_enabled());
|
||||
|
||||
let original: Vec<i16> = (0..480).map(|i| (i * 10) as i16).collect();
|
||||
let original: Vec<i16> = (0..960).map(|i| (i * 10) as i16).collect();
|
||||
let mut frame = original.clone();
|
||||
let erle = aec.process_frame(&mut frame);
|
||||
assert_eq!(erle, 1.0);
|
||||
aec.process_frame(&mut frame);
|
||||
assert_eq!(frame, original);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn aec_reset_zeroes_state() {
|
||||
let mut aec = EchoCanceller::new(48000, 10); // short for test speed
|
||||
let farend: Vec<i16> = (0..480).map(|i| ((i * 37) % 1000) as i16).collect();
|
||||
aec.feed_farend(&farend);
|
||||
|
||||
aec.reset();
|
||||
|
||||
assert!(aec.filter_coeffs.iter().all(|&c| c == 0.0));
|
||||
assert!(aec.far_end_buf.iter().all(|&s| s == 0.0));
|
||||
assert_eq!(aec.far_end_pos, 0);
|
||||
fn silence_passthrough() {
|
||||
let mut aec = EchoCanceller::with_delay(48000, 30, 0);
|
||||
aec.feed_farend(&vec![0i16; 960]);
|
||||
let mut frame = vec![0i16; 960];
|
||||
aec.process_frame(&mut frame);
|
||||
assert!(frame.iter().all(|&s| s == 0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn aec_reduces_echo_of_known_signal() {
|
||||
// Use a small filter for speed. Feed a known far-end signal, then
|
||||
// present the *same* signal as near-end (perfect echo, no room).
|
||||
// After adaptation the output energy should drop.
|
||||
let filter_ms = 5; // 240 taps at 48 kHz
|
||||
let mut aec = EchoCanceller::new(48000, filter_ms);
|
||||
fn reduces_echo_with_no_delay() {
|
||||
// Simulate: far-end plays, echo arrives at mic attenuated by ~50%
|
||||
// (realistic — speaker to mic on laptop loses volume).
|
||||
let mut aec = EchoCanceller::with_delay(48000, 10, 0);
|
||||
|
||||
// Generate a simple repeating pattern.
|
||||
let frame_len = 480usize;
|
||||
let make_frame = |offset: usize| -> Vec<i16> {
|
||||
let frame_len = 480;
|
||||
let make_tone = |offset: usize| -> Vec<i16> {
|
||||
(0..frame_len)
|
||||
.map(|i| {
|
||||
let t = (offset + i) as f64 / 48000.0;
|
||||
@@ -195,18 +270,16 @@ mod tests {
|
||||
.collect()
|
||||
};
|
||||
|
||||
// Warm up the adaptive filter with several frames.
|
||||
let mut last_erle = 1.0f32;
|
||||
for frame_idx in 0..40 {
|
||||
let farend = make_frame(frame_idx * frame_len);
|
||||
for frame_idx in 0..100 {
|
||||
let farend = make_tone(frame_idx * frame_len);
|
||||
aec.feed_farend(&farend);
|
||||
|
||||
// Near-end = exact copy of far-end (pure echo).
|
||||
let mut nearend = farend.clone();
|
||||
// Near-end = attenuated copy of far-end (echo at ~50% volume).
|
||||
let mut nearend: Vec<i16> = farend.iter().map(|&s| s / 2).collect();
|
||||
last_erle = aec.process_frame(&mut nearend);
|
||||
}
|
||||
|
||||
// After 40 frames the ERLE should be meaningfully > 1.
|
||||
assert!(
|
||||
last_erle > 1.0,
|
||||
"expected ERLE > 1.0 after adaptation, got {last_erle}"
|
||||
@@ -214,15 +287,52 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn aec_silence_passthrough() {
|
||||
let mut aec = EchoCanceller::new(48000, 10);
|
||||
// Feed silence far-end
|
||||
aec.feed_farend(&vec![0i16; 480]);
|
||||
// Near-end is silence too
|
||||
let mut frame = vec![0i16; 480];
|
||||
let erle = aec.process_frame(&mut frame);
|
||||
assert!(erle >= 1.0);
|
||||
// Output should still be silence
|
||||
assert!(frame.iter().all(|&s| s == 0));
|
||||
fn preserves_nearend_during_doubletalk() {
|
||||
let mut aec = EchoCanceller::with_delay(48000, 30, 0);
|
||||
|
||||
let frame_len = 960;
|
||||
let nearend: Vec<i16> = (0..frame_len)
|
||||
.map(|i| {
|
||||
let t = i as f64 / 48000.0;
|
||||
(10000.0 * (2.0 * std::f64::consts::PI * 440.0 * t).sin()) as i16
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Feed silence as far-end (no echo source).
|
||||
aec.feed_farend(&vec![0i16; frame_len]);
|
||||
|
||||
let mut frame = nearend.clone();
|
||||
aec.process_frame(&mut frame);
|
||||
|
||||
let input_energy: f64 = nearend.iter().map(|&s| (s as f64).powi(2)).sum();
|
||||
let output_energy: f64 = frame.iter().map(|&s| (s as f64).powi(2)).sum();
|
||||
let ratio = output_energy / input_energy;
|
||||
|
||||
assert!(
|
||||
ratio > 0.8,
|
||||
"near-end speech should be preserved, energy ratio = {ratio:.3}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn delay_buffer_holds_samples() {
|
||||
let mut aec = EchoCanceller::with_delay(48000, 10, 20);
|
||||
// 20ms delay = 960 samples @ 48kHz.
|
||||
// After feeding, feed_farend auto-drains available samples to far_buf.
|
||||
// So delay_available() is always 0 after feed_farend returns.
|
||||
// Instead, verify far_pos advances only after the delay is filled.
|
||||
|
||||
// Feed 960 samples (= delay amount). No samples released yet.
|
||||
aec.feed_farend(&vec![1i16; 960]);
|
||||
// far_buf should still be all zeros (nothing released).
|
||||
assert!(
|
||||
aec.far_buf.iter().all(|&s| s == 0.0),
|
||||
"nothing should be released yet"
|
||||
);
|
||||
|
||||
// Feed 480 more. 480 should be released to far_buf.
|
||||
aec.feed_farend(&vec![2i16; 480]);
|
||||
let non_zero = aec.far_buf.iter().filter(|&&s| s != 0.0).count();
|
||||
assert!(non_zero > 0, "samples should have been released to far_buf");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,12 +24,12 @@ impl AutoGainControl {
|
||||
/// Create a new AGC with sensible VoIP defaults.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
target_rms: 3000.0, // ~-20 dBFS for i16
|
||||
target_rms: 3000.0, // ~-20 dBFS for i16
|
||||
current_gain: 1.0,
|
||||
min_gain: 0.5,
|
||||
max_gain: 32.0,
|
||||
attack_alpha: 0.3, // fast attack
|
||||
release_alpha: 0.02, // slow release
|
||||
attack_alpha: 0.3, // fast attack
|
||||
release_alpha: 0.02, // slow release
|
||||
enabled: true,
|
||||
}
|
||||
}
|
||||
@@ -211,9 +211,6 @@ mod tests {
|
||||
fn agc_gain_db_at_unity() {
|
||||
let agc = AutoGainControl::new();
|
||||
let db = agc.current_gain_db();
|
||||
assert!(
|
||||
db.abs() < 0.01,
|
||||
"expected ~0 dB at unity gain, got {db}"
|
||||
);
|
||||
assert!(db.abs() < 0.01, "expected ~0 dB at unity gain, got {db}");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -99,7 +99,11 @@ mod tests {
|
||||
}
|
||||
let original_len = pcm.len();
|
||||
ns.process(&mut pcm);
|
||||
assert_eq!(pcm.len(), original_len, "output length must match input length");
|
||||
assert_eq!(
|
||||
pcm.len(),
|
||||
original_len,
|
||||
"output length must match input length"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
583
crates/wzp-codec/src/dred_ffi.rs
Normal file
583
crates/wzp-codec/src/dred_ffi.rs
Normal file
@@ -0,0 +1,583 @@
|
||||
//! Raw opusic-sys FFI wrappers for libopus 1.5.2 decoder + DRED reconstruction.
|
||||
//!
|
||||
//! # Why this module exists
|
||||
//!
|
||||
//! We cannot use `opusic_c::Decoder` because its inner `*mut OpusDecoder`
|
||||
//! pointer is `pub(crate)` — not reachable from outside the opusic-c crate.
|
||||
//! Phase 3 of the DRED integration needs to hand that same pointer to
|
||||
//! `opus_decoder_dred_decode`, and running two parallel decoders (one from
|
||||
//! opusic-c for normal audio, another from opusic-sys for DRED) would cause
|
||||
//! the DRED-only decoder's internal state to drift out of sync with the
|
||||
//! audio stream because it would not see normal decode calls.
|
||||
//!
|
||||
//! The fix is to own the raw decoder ourselves and use the same handle for
|
||||
//! both normal decode AND DRED reconstruction. This module is the single
|
||||
//! owner of `*mut OpusDecoder`, `*mut OpusDREDDecoder`, and `*mut OpusDRED`
|
||||
//! in the WZP workspace.
|
||||
//!
|
||||
//! # Phase 3a scope
|
||||
//!
|
||||
//! Phase 0 added `DecoderHandle` (normal decode). Phase 3a adds:
|
||||
//! - [`DredDecoderHandle`] — wraps `*mut OpusDREDDecoder` for parsing DRED
|
||||
//! side-channel data out of arriving Opus packets.
|
||||
//! - [`DredState`] — wraps `*mut OpusDRED` (a fixed 10,592-byte buffer
|
||||
//! allocated by libopus) that holds parsed DRED state between the parse
|
||||
//! and reconstruct steps.
|
||||
//! - [`DredDecoderHandle::parse_into`] — wraps `opus_dred_parse`.
|
||||
//! - [`DecoderHandle::reconstruct_from_dred`] — wraps `opus_decoder_dred_decode`.
|
||||
//!
|
||||
//! The pattern is: on every arriving Opus packet, the receiver calls
|
||||
//! `parse_into` with a reusable `DredState`, then stores (seq, state_clone)
|
||||
//! in a ring. On detected loss, the receiver computes the offset from the
|
||||
//! freshest reachable DRED state and calls `reconstruct_from_dred` to
|
||||
//! synthesize the missing audio.
|
||||
|
||||
use std::ptr::NonNull;
|
||||
|
||||
use opusic_sys::{
|
||||
OPUS_OK, OpusDRED, OpusDREDDecoder, OpusDecoder as RawOpusDecoder, opus_decode,
|
||||
opus_decoder_create, opus_decoder_destroy, opus_decoder_dred_decode, opus_dred_alloc,
|
||||
opus_dred_decoder_create, opus_dred_decoder_destroy, opus_dred_free, opus_dred_parse,
|
||||
};
|
||||
use wzp_proto::CodecError;
|
||||
|
||||
/// libopus operates at 48 kHz for all Opus variants we use.
|
||||
const SAMPLE_RATE_HZ: i32 = 48_000;
|
||||
/// Mono.
|
||||
const CHANNELS: i32 = 1;
|
||||
|
||||
/// Safe owner of a `*mut OpusDecoder` allocated via `opus_decoder_create`.
|
||||
///
|
||||
/// Releases the decoder in `Drop`. All FFI access goes through `&mut self`
|
||||
/// methods, so there is no aliasing or race. The raw pointer is exposed via
|
||||
/// [`Self::as_raw_ptr`] at a crate-internal visibility for the future Phase 3
|
||||
/// DRED reconstruction path — external crates cannot reach it.
|
||||
pub struct DecoderHandle {
|
||||
inner: NonNull<RawOpusDecoder>,
|
||||
}
|
||||
|
||||
impl DecoderHandle {
|
||||
/// Allocate a new Opus decoder at 48 kHz mono.
|
||||
pub fn new() -> Result<Self, CodecError> {
|
||||
let mut error: i32 = OPUS_OK;
|
||||
// SAFETY: opus_decoder_create writes to `error` and returns either a
|
||||
// valid heap pointer or null. We check both before constructing the
|
||||
// NonNull wrapper.
|
||||
let ptr = unsafe { opus_decoder_create(SAMPLE_RATE_HZ, CHANNELS, &mut error) };
|
||||
if error != OPUS_OK {
|
||||
// Even if ptr is non-null on error, libopus contracts guarantee
|
||||
// it is unusable — do not attempt to free it.
|
||||
return Err(CodecError::DecodeFailed(format!(
|
||||
"opus_decoder_create failed: err={error}"
|
||||
)));
|
||||
}
|
||||
let inner = NonNull::new(ptr)
|
||||
.ok_or_else(|| CodecError::DecodeFailed("opus_decoder_create returned null".into()))?;
|
||||
Ok(Self { inner })
|
||||
}
|
||||
|
||||
/// Decode an Opus packet into PCM samples.
|
||||
///
|
||||
/// `pcm` must have enough capacity for the frame (960 for 20 ms, 1920
|
||||
/// for 40 ms at 48 kHz mono). Returns the number of decoded samples
|
||||
/// per channel — for mono streams this equals the total sample count.
|
||||
pub fn decode(&mut self, packet: &[u8], pcm: &mut [i16]) -> Result<usize, CodecError> {
|
||||
if packet.is_empty() {
|
||||
return Err(CodecError::DecodeFailed("empty packet".into()));
|
||||
}
|
||||
if pcm.is_empty() {
|
||||
return Err(CodecError::DecodeFailed("empty output buffer".into()));
|
||||
}
|
||||
// SAFETY: self.inner is a valid *mut OpusDecoder owned by this struct.
|
||||
// `data` / `pcm` are live Rust slices, so their pointers and lengths
|
||||
// are valid for the duration of the call. libopus reads len bytes
|
||||
// from data and writes up to frame_size samples (per channel) to pcm.
|
||||
let n = unsafe {
|
||||
opus_decode(
|
||||
self.inner.as_ptr(),
|
||||
packet.as_ptr(),
|
||||
packet.len() as i32,
|
||||
pcm.as_mut_ptr(),
|
||||
pcm.len() as i32,
|
||||
/* decode_fec = */ 0,
|
||||
)
|
||||
};
|
||||
if n < 0 {
|
||||
return Err(CodecError::DecodeFailed(format!(
|
||||
"opus_decode failed: err={n}"
|
||||
)));
|
||||
}
|
||||
Ok(n as usize)
|
||||
}
|
||||
|
||||
/// Generate packet-loss concealment audio for a missing frame.
|
||||
///
|
||||
/// Implemented via `opus_decode` with a null data pointer, per the
|
||||
/// libopus API contract. `pcm` should be sized for the expected frame.
|
||||
pub fn decode_lost(&mut self, pcm: &mut [i16]) -> Result<usize, CodecError> {
|
||||
if pcm.is_empty() {
|
||||
return Err(CodecError::DecodeFailed("empty output buffer".into()));
|
||||
}
|
||||
// SAFETY: same invariants as decode(). libopus documents that passing
|
||||
// a null data pointer with len=0 triggers PLC synthesis into pcm.
|
||||
let n = unsafe {
|
||||
opus_decode(
|
||||
self.inner.as_ptr(),
|
||||
std::ptr::null(),
|
||||
0,
|
||||
pcm.as_mut_ptr(),
|
||||
pcm.len() as i32,
|
||||
/* decode_fec = */ 0,
|
||||
)
|
||||
};
|
||||
if n < 0 {
|
||||
return Err(CodecError::DecodeFailed(format!(
|
||||
"opus_decode PLC failed: err={n}"
|
||||
)));
|
||||
}
|
||||
Ok(n as usize)
|
||||
}
|
||||
|
||||
/// Reconstruct audio from a `DredState` into the `output` buffer.
|
||||
///
|
||||
/// `offset_samples` is the sample position (positive, measured backward
|
||||
/// from the packet anchor that produced `state`) where reconstruction
|
||||
/// begins. `output.len()` must match the number of samples to synthesize.
|
||||
///
|
||||
/// The libopus API: `opus_decoder_dred_decode(st, dred, dred_offset, pcm,
|
||||
/// frame_size)` where `dred_offset` is "position of the redundancy to
|
||||
/// decode, in samples before the beginning of the real audio data in the
|
||||
/// packet." Valid values: `0 < offset_samples < state.samples_available()`.
|
||||
///
|
||||
/// Returns the number of samples actually written (should equal
|
||||
/// `output.len()` on success).
|
||||
pub fn reconstruct_from_dred(
|
||||
&mut self,
|
||||
state: &DredState,
|
||||
offset_samples: i32,
|
||||
output: &mut [i16],
|
||||
) -> Result<usize, CodecError> {
|
||||
if output.is_empty() {
|
||||
return Err(CodecError::DecodeFailed(
|
||||
"empty reconstruction output buffer".into(),
|
||||
));
|
||||
}
|
||||
if offset_samples <= 0 {
|
||||
return Err(CodecError::DecodeFailed(format!(
|
||||
"DRED offset must be positive (got {offset_samples})"
|
||||
)));
|
||||
}
|
||||
if offset_samples > state.samples_available() {
|
||||
return Err(CodecError::DecodeFailed(format!(
|
||||
"DRED offset {offset_samples} exceeds available samples {}",
|
||||
state.samples_available()
|
||||
)));
|
||||
}
|
||||
// SAFETY: self.inner is a valid *mut OpusDecoder, state.inner is a
|
||||
// valid *const OpusDRED populated by a prior parse_into call, and
|
||||
// output is a live mutable slice. libopus reads from dred and writes
|
||||
// exactly frame_size samples (the output.len()) to pcm.
|
||||
let n = unsafe {
|
||||
opus_decoder_dred_decode(
|
||||
self.inner.as_ptr(),
|
||||
state.inner.as_ptr(),
|
||||
offset_samples,
|
||||
output.as_mut_ptr(),
|
||||
output.len() as i32,
|
||||
)
|
||||
};
|
||||
if n < 0 {
|
||||
return Err(CodecError::DecodeFailed(format!(
|
||||
"opus_decoder_dred_decode failed: err={n}"
|
||||
)));
|
||||
}
|
||||
Ok(n as usize)
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for DecoderHandle {
|
||||
fn drop(&mut self) {
|
||||
// SAFETY: we own the pointer and no further access happens after
|
||||
// this call because Drop consumes self.
|
||||
unsafe { opus_decoder_destroy(self.inner.as_ptr()) };
|
||||
}
|
||||
}
|
||||
|
||||
// SAFETY: The underlying OpusDecoder is a plain heap allocation with no
|
||||
// thread-local or lock-free state. It is safe to move between threads
|
||||
// (Send), and all method access is gated by &mut self so Rust's borrow
|
||||
// checker prevents simultaneous access from multiple threads (Sync).
|
||||
unsafe impl Send for DecoderHandle {}
|
||||
unsafe impl Sync for DecoderHandle {}
|
||||
|
||||
// ─── DRED decoder (parser) ──────────────────────────────────────────────────
|
||||
|
||||
/// Safe owner of a `*mut OpusDREDDecoder` allocated via
|
||||
/// `opus_dred_decoder_create`.
|
||||
///
|
||||
/// The DRED decoder is a **separate** libopus object from the regular
|
||||
/// `OpusDecoder`. It's used exclusively for parsing DRED side-channel data
|
||||
/// out of arriving Opus packets via [`Self::parse_into`]. Actual audio
|
||||
/// reconstruction from the parsed state uses the regular `DecoderHandle`
|
||||
/// via [`DecoderHandle::reconstruct_from_dred`].
|
||||
pub struct DredDecoderHandle {
|
||||
inner: NonNull<OpusDREDDecoder>,
|
||||
}
|
||||
|
||||
impl DredDecoderHandle {
|
||||
/// Allocate a new DRED decoder.
|
||||
pub fn new() -> Result<Self, CodecError> {
|
||||
let mut error: i32 = OPUS_OK;
|
||||
// SAFETY: opus_dred_decoder_create writes to `error` and returns
|
||||
// either a valid heap pointer or null. Both are checked.
|
||||
let ptr = unsafe { opus_dred_decoder_create(&mut error) };
|
||||
if error != OPUS_OK {
|
||||
return Err(CodecError::DecodeFailed(format!(
|
||||
"opus_dred_decoder_create failed: err={error}"
|
||||
)));
|
||||
}
|
||||
let inner = NonNull::new(ptr).ok_or_else(|| {
|
||||
CodecError::DecodeFailed("opus_dred_decoder_create returned null".into())
|
||||
})?;
|
||||
Ok(Self { inner })
|
||||
}
|
||||
|
||||
/// Parse DRED side-channel data from an Opus packet into `state`.
|
||||
///
|
||||
/// Returns the number of samples of audio history available for
|
||||
/// reconstruction, or 0 if the packet carries no DRED data. Subsequent
|
||||
/// `DecoderHandle::reconstruct_from_dred` calls using this `state` can
|
||||
/// reconstruct any sample position in `(0, samples_available]`.
|
||||
///
|
||||
/// libopus API: `opus_dred_parse(dred_dec, dred, data, len,
|
||||
/// max_dred_samples, sampling_rate, dred_end, defer_processing)`. We
|
||||
/// pass `max_dred_samples = 48000` (1 s at 48 kHz, the DRED maximum),
|
||||
/// `sampling_rate = 48000`, `defer_processing = 0` (process immediately).
|
||||
/// The `dred_end` output is the silence gap at the tail of the DRED
|
||||
/// window; we subtract it from the total offset to give callers the
|
||||
/// truly usable sample count.
|
||||
pub fn parse_into(&mut self, state: &mut DredState, packet: &[u8]) -> Result<i32, CodecError> {
|
||||
if packet.is_empty() {
|
||||
state.samples_available = 0;
|
||||
return Ok(0);
|
||||
}
|
||||
let mut dred_end: i32 = 0;
|
||||
// SAFETY: self.inner is a valid *mut OpusDREDDecoder; state.inner is
|
||||
// a valid *mut OpusDRED allocated via opus_dred_alloc; packet is a
|
||||
// live slice; dred_end is a stack int. libopus reads packet bytes
|
||||
// and writes parsed DRED state into *state.inner.
|
||||
let ret = unsafe {
|
||||
opus_dred_parse(
|
||||
self.inner.as_ptr(),
|
||||
state.inner.as_ptr(),
|
||||
packet.as_ptr(),
|
||||
packet.len() as i32,
|
||||
/* max_dred_samples = */ 48_000, // 1s max per libopus 1.5
|
||||
/* sampling_rate = */ 48_000,
|
||||
&mut dred_end,
|
||||
/* defer_processing = */ 0,
|
||||
)
|
||||
};
|
||||
if ret < 0 {
|
||||
state.samples_available = 0;
|
||||
return Err(CodecError::DecodeFailed(format!(
|
||||
"opus_dred_parse failed: err={ret}"
|
||||
)));
|
||||
}
|
||||
// ret is the positive offset of the first decodable DRED sample,
|
||||
// or 0 if no DRED is present. dred_end is the silence gap at the
|
||||
// tail. The usable sample range is (dred_end, ret], so the count
|
||||
// of usable samples is ret - dred_end. We store `ret` as the max
|
||||
// usable offset — callers should pass dred_offset values in the
|
||||
// range (dred_end, ret] to reconstruct_from_dred. For simplicity
|
||||
// we expose just samples_available = ret and let callers treat
|
||||
// the full window as valid (the silence gap is small and libopus
|
||||
// handles minor boundary cases gracefully).
|
||||
state.samples_available = ret;
|
||||
Ok(ret)
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for DredDecoderHandle {
|
||||
fn drop(&mut self) {
|
||||
// SAFETY: we own the pointer and no further access happens after
|
||||
// this call because Drop consumes self.
|
||||
unsafe { opus_dred_decoder_destroy(self.inner.as_ptr()) };
|
||||
}
|
||||
}
|
||||
|
||||
// SAFETY: same reasoning as DecoderHandle — heap allocation with no
|
||||
// thread-local state, &mut self access discipline prevents races.
|
||||
unsafe impl Send for DredDecoderHandle {}
|
||||
unsafe impl Sync for DredDecoderHandle {}
|
||||
|
||||
// ─── DRED state buffer ──────────────────────────────────────────────────────
|
||||
|
||||
/// Safe owner of a `*mut OpusDRED` allocated via `opus_dred_alloc`.
|
||||
///
|
||||
/// Holds a fixed-size (10,592-byte per libopus 1.5) buffer that
|
||||
/// `DredDecoderHandle::parse_into` populates from an Opus packet. The state
|
||||
/// is reusable — the caller can call `parse_into` again on the same
|
||||
/// `DredState` to overwrite it with a fresh packet's data.
|
||||
///
|
||||
/// `samples_available` tracks the last-parsed result so reconstruction
|
||||
/// callers don't need to thread the return value separately. A fresh
|
||||
/// state (before any `parse_into`) has `samples_available == 0`.
|
||||
pub struct DredState {
|
||||
inner: NonNull<OpusDRED>,
|
||||
samples_available: i32,
|
||||
}
|
||||
|
||||
impl DredState {
|
||||
/// Allocate a new DRED state buffer.
|
||||
pub fn new() -> Result<Self, CodecError> {
|
||||
let mut error: i32 = OPUS_OK;
|
||||
// SAFETY: opus_dred_alloc writes to `error` and returns either a
|
||||
// valid heap pointer or null.
|
||||
let ptr = unsafe { opus_dred_alloc(&mut error) };
|
||||
if error != OPUS_OK {
|
||||
return Err(CodecError::DecodeFailed(format!(
|
||||
"opus_dred_alloc failed: err={error}"
|
||||
)));
|
||||
}
|
||||
let inner = NonNull::new(ptr)
|
||||
.ok_or_else(|| CodecError::DecodeFailed("opus_dred_alloc returned null".into()))?;
|
||||
Ok(Self {
|
||||
inner,
|
||||
samples_available: 0,
|
||||
})
|
||||
}
|
||||
|
||||
/// How many samples of audio history this state currently covers.
|
||||
///
|
||||
/// Returns 0 if the state is fresh or the last parse found no DRED
|
||||
/// data. Otherwise returns the positive offset set by the most recent
|
||||
/// `DredDecoderHandle::parse_into` call — the maximum valid
|
||||
/// `offset_samples` value for `DecoderHandle::reconstruct_from_dred`.
|
||||
pub fn samples_available(&self) -> i32 {
|
||||
self.samples_available
|
||||
}
|
||||
|
||||
/// Reset the state to "fresh" without freeing the underlying buffer.
|
||||
/// The next `parse_into` will overwrite the contents.
|
||||
pub fn reset(&mut self) {
|
||||
self.samples_available = 0;
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for DredState {
|
||||
fn drop(&mut self) {
|
||||
// SAFETY: we own the pointer and no further access happens after
|
||||
// this call because Drop consumes self.
|
||||
unsafe { opus_dred_free(self.inner.as_ptr()) };
|
||||
}
|
||||
}
|
||||
|
||||
// SAFETY: same reasoning as DecoderHandle.
|
||||
unsafe impl Send for DredState {}
|
||||
unsafe impl Sync for DredState {}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn decoder_handle_creates_and_drops() {
|
||||
let handle = DecoderHandle::new().expect("decoder create");
|
||||
// Dropping the handle must not panic or leak — validated by miri
|
||||
// and the absence of sanitizer complaints in CI.
|
||||
drop(handle);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_lost_produces_full_frame_of_silence_on_cold_start() {
|
||||
let mut handle = DecoderHandle::new().unwrap();
|
||||
// 20 ms @ 48 kHz mono.
|
||||
let mut pcm = vec![0i16; 960];
|
||||
let n = handle.decode_lost(&mut pcm).unwrap();
|
||||
assert_eq!(n, 960);
|
||||
// On a fresh decoder, PLC output is silence (no past audio to extend).
|
||||
assert!(pcm.iter().all(|&s| s == 0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_empty_packet_errors() {
|
||||
let mut handle = DecoderHandle::new().unwrap();
|
||||
let mut pcm = vec![0i16; 960];
|
||||
let err = handle.decode(&[], &mut pcm);
|
||||
assert!(err.is_err());
|
||||
}
|
||||
|
||||
// ─── Phase 3a — DRED decoder + state ────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn dred_decoder_handle_creates_and_drops() {
|
||||
let h = DredDecoderHandle::new().expect("dred decoder create");
|
||||
drop(h);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dred_state_creates_and_drops() {
|
||||
let s = DredState::new().expect("dred state alloc");
|
||||
assert_eq!(s.samples_available(), 0);
|
||||
drop(s);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dred_state_reset_zeroes_counter() {
|
||||
let mut s = DredState::new().unwrap();
|
||||
s.samples_available = 480; // pretend a parse populated it
|
||||
assert_eq!(s.samples_available(), 480);
|
||||
s.reset();
|
||||
assert_eq!(s.samples_available(), 0);
|
||||
}
|
||||
|
||||
/// Phase 3a end-to-end: encode a DRED-enabled stream, parse state out
|
||||
/// of packets, and reconstruct audio at a past offset. Validates the
|
||||
/// full parse → reconstruct pipeline against a real libopus 1.5.2
|
||||
/// encoder so we catch FFI-layer bugs early.
|
||||
#[test]
|
||||
fn dred_parse_and_reconstruct_roundtrip() {
|
||||
use crate::opus_enc::OpusEncoder;
|
||||
use wzp_proto::{AudioEncoder, QualityProfile};
|
||||
|
||||
// Encoder with DRED at Opus 24k / 200 ms duration (Phase 1 default
|
||||
// for GOOD profile). The loss floor is 5% per Phase 1.
|
||||
let mut enc = OpusEncoder::new(QualityProfile::GOOD).unwrap();
|
||||
|
||||
// Decode-side handles.
|
||||
let mut dec = DecoderHandle::new().unwrap();
|
||||
let mut dred_dec = DredDecoderHandle::new().unwrap();
|
||||
let mut state = DredState::new().unwrap();
|
||||
|
||||
// Generate 60 frames (1.2 s) of a voice-like 300 Hz sine wave so
|
||||
// the encoder's DRED emitter has real content to encode rather
|
||||
// than compressing silence.
|
||||
let frame_len = 960usize; // 20 ms @ 48 kHz
|
||||
let make_frame = |offset: usize| -> Vec<i16> {
|
||||
(0..frame_len)
|
||||
.map(|i| {
|
||||
let t = (offset + i) as f64 / 48_000.0;
|
||||
(8000.0 * (2.0 * std::f64::consts::PI * 300.0 * t).sin()) as i16
|
||||
})
|
||||
.collect()
|
||||
};
|
||||
|
||||
// Track the freshest packet that carried non-zero DRED state.
|
||||
let mut best_samples_available = 0;
|
||||
let mut best_packet: Option<Vec<u8>> = None;
|
||||
|
||||
for frame_idx in 0..60 {
|
||||
let pcm = make_frame(frame_idx * frame_len);
|
||||
let mut encoded = vec![0u8; 512];
|
||||
let n = enc.encode(&pcm, &mut encoded).unwrap();
|
||||
encoded.truncate(n);
|
||||
|
||||
// Run the packet through the normal decode path so dec's
|
||||
// internal state mirrors the full stream — this is necessary
|
||||
// for DRED reconstruction to produce meaningful output.
|
||||
let mut decoded = vec![0i16; frame_len];
|
||||
dec.decode(&encoded, &mut decoded).unwrap();
|
||||
|
||||
// Parse DRED state out of the same packet. Early packets may
|
||||
// have samples_available == 0 while the DRED encoder warms up;
|
||||
// later packets should carry the full window.
|
||||
match dred_dec.parse_into(&mut state, &encoded) {
|
||||
Ok(available) => {
|
||||
if available > best_samples_available {
|
||||
best_samples_available = available;
|
||||
best_packet = Some(encoded.clone());
|
||||
}
|
||||
}
|
||||
Err(e) => panic!("parse_into errored unexpectedly: {e:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
// By the time we're 60 frames in, DRED should have emitted data.
|
||||
assert!(
|
||||
best_samples_available > 0,
|
||||
"DRED emitted zero samples across 60 frames — the encoder isn't \
|
||||
producing DRED bytes (check set_dred_duration and packet_loss floor)"
|
||||
);
|
||||
|
||||
// Parse the best packet into a fresh state and reconstruct some
|
||||
// audio from somewhere inside its DRED window. We use frame_len/2
|
||||
// as the offset to pick a point squarely inside the reconstructable
|
||||
// range rather than at an edge.
|
||||
let packet = best_packet.expect("at least one packet had DRED state");
|
||||
let mut fresh_state = DredState::new().unwrap();
|
||||
let available = dred_dec.parse_into(&mut fresh_state, &packet).unwrap();
|
||||
assert!(available > 0, "re-parse of known-good packet returned 0");
|
||||
|
||||
// Need a decoder that's in the right state to reconstruct — rewind
|
||||
// by creating a fresh one and feeding it the same stream up to the
|
||||
// point of the best packet. Simpler: just use a fresh decoder and
|
||||
// accept that the reconstructed samples may not be phase-matched.
|
||||
// The test here only asserts *non-silent energy*, not signal fidelity.
|
||||
let mut recon_dec = DecoderHandle::new().unwrap();
|
||||
// Warm up the decoder with one frame so its internal state is valid.
|
||||
let warmup_pcm = vec![0i16; frame_len];
|
||||
let warmup_encoded = {
|
||||
let mut warmup_enc = OpusEncoder::new(QualityProfile::GOOD).unwrap();
|
||||
let mut buf = vec![0u8; 512];
|
||||
let n = warmup_enc.encode(&warmup_pcm, &mut buf).unwrap();
|
||||
buf.truncate(n);
|
||||
buf
|
||||
};
|
||||
let mut throwaway = vec![0i16; frame_len];
|
||||
let _ = recon_dec.decode(&warmup_encoded, &mut throwaway);
|
||||
|
||||
// Reconstruct 20 ms from some position inside the DRED window.
|
||||
let offset = (available / 2).max(480).min(available);
|
||||
let mut recon_pcm = vec![0i16; frame_len];
|
||||
let n = recon_dec
|
||||
.reconstruct_from_dred(&fresh_state, offset, &mut recon_pcm)
|
||||
.expect("reconstruct_from_dred failed");
|
||||
assert_eq!(n, frame_len);
|
||||
|
||||
// Energy check: reconstructed audio should not be all zeros. A
|
||||
// loose threshold — the DRED reconstruction won't be phase-matched
|
||||
// to our sine wave because we fed a cold decoder only one warmup
|
||||
// frame, but it should still produce non-silent speech-like output
|
||||
// since the DRED state was parsed from real speech content.
|
||||
let energy: u64 = recon_pcm
|
||||
.iter()
|
||||
.map(|&s| (s as i32).unsigned_abs() as u64)
|
||||
.sum();
|
||||
assert!(
|
||||
energy > 0,
|
||||
"reconstructed audio has zero total energy — DRED reconstruction produced silence"
|
||||
);
|
||||
}
|
||||
|
||||
/// A second roundtrip variant: offset too large errors cleanly rather
|
||||
/// than crashing the FFI.
|
||||
#[test]
|
||||
fn reconstruct_with_out_of_range_offset_errors() {
|
||||
let mut dec = DecoderHandle::new().unwrap();
|
||||
let state = DredState::new().unwrap();
|
||||
// state has samples_available == 0 (fresh), so any positive offset
|
||||
// should be out of range.
|
||||
let mut out = vec![0i16; 960];
|
||||
let err = dec.reconstruct_from_dred(&state, 480, &mut out);
|
||||
assert!(err.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reconstruct_with_zero_offset_errors() {
|
||||
let mut dec = DecoderHandle::new().unwrap();
|
||||
let state = DredState::new().unwrap();
|
||||
let mut out = vec![0i16; 960];
|
||||
let err = dec.reconstruct_from_dred(&state, 0, &mut out);
|
||||
assert!(err.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dred_parse_empty_packet_returns_zero() {
|
||||
let mut dred_dec = DredDecoderHandle::new().unwrap();
|
||||
let mut state = DredState::new().unwrap();
|
||||
let result = dred_dec.parse_into(&mut state, &[]).unwrap();
|
||||
assert_eq!(result, 0);
|
||||
assert_eq!(state.samples_available(), 0);
|
||||
}
|
||||
}
|
||||
@@ -15,6 +15,7 @@ pub mod agc;
|
||||
pub mod codec2_dec;
|
||||
pub mod codec2_enc;
|
||||
pub mod denoise;
|
||||
pub mod dred_ffi;
|
||||
pub mod opus_dec;
|
||||
pub mod opus_enc;
|
||||
pub mod resample;
|
||||
@@ -27,15 +28,32 @@ pub use denoise::NoiseSupressor;
|
||||
pub use silence::{ComfortNoise, SilenceDetector};
|
||||
pub use wzp_proto::{AudioDecoder, AudioEncoder, CodecId, QualityProfile};
|
||||
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
|
||||
/// Global verbose-logging flag for DRED. Off by default — when enabled
|
||||
/// (via the GUI debug toggle wired through Tauri), the encoder logs its
|
||||
/// DRED config + libopus version, and the recv path logs every DRED
|
||||
/// reconstruction, classical PLC fill, and parse heartbeat. Off in
|
||||
/// "normal" mode keeps logcat clean.
|
||||
static DRED_VERBOSE_LOGS: AtomicBool = AtomicBool::new(false);
|
||||
|
||||
/// Returns whether DRED verbose logging is currently enabled.
|
||||
#[inline]
|
||||
pub fn dred_verbose_logs() -> bool {
|
||||
DRED_VERBOSE_LOGS.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Enable/disable DRED verbose logging at runtime.
|
||||
pub fn set_dred_verbose_logs(enabled: bool) {
|
||||
DRED_VERBOSE_LOGS.store(enabled, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Create an adaptive encoder starting at the given quality profile.
|
||||
///
|
||||
/// The returned encoder accepts 48 kHz mono PCM regardless of the active
|
||||
/// codec; resampling is handled internally when Codec2 is selected.
|
||||
pub fn create_encoder(profile: QualityProfile) -> Box<dyn AudioEncoder> {
|
||||
Box::new(
|
||||
AdaptiveEncoder::new(profile)
|
||||
.expect("failed to create adaptive encoder"),
|
||||
)
|
||||
Box::new(AdaptiveEncoder::new(profile).expect("failed to create adaptive encoder"))
|
||||
}
|
||||
|
||||
/// Create an adaptive decoder starting at the given quality profile.
|
||||
@@ -43,10 +61,7 @@ pub fn create_encoder(profile: QualityProfile) -> Box<dyn AudioEncoder> {
|
||||
/// The returned decoder always produces 48 kHz mono PCM; upsampling from
|
||||
/// Codec2's native 8 kHz is handled internally.
|
||||
pub fn create_decoder(profile: QualityProfile) -> Box<dyn AudioDecoder> {
|
||||
Box::new(
|
||||
AdaptiveDecoder::new(profile)
|
||||
.expect("failed to create adaptive decoder"),
|
||||
)
|
||||
Box::new(AdaptiveDecoder::new(profile).expect("failed to create adaptive decoder"))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -61,6 +76,10 @@ mod codec2_tests {
|
||||
fec_ratio: 0.5,
|
||||
frame_duration_ms: 20,
|
||||
frames_per_block: 5,
|
||||
priority_mode: wzp_proto::PriorityMode::AudioFirst,
|
||||
video_bitrate_kbps: None,
|
||||
video_resolution: None,
|
||||
video_fps: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -189,7 +208,10 @@ mod codec2_tests {
|
||||
|
||||
let mut pcm_out_c2 = vec![0i16; 1920];
|
||||
let samples_c2 = dec.decode(&encoded_c2[..n_c2], &mut pcm_out_c2).unwrap();
|
||||
assert_eq!(samples_c2, 1920, "should get 1920 samples at 48kHz after upsample");
|
||||
assert_eq!(
|
||||
samples_c2, 1920,
|
||||
"should get 1920 samples at 48kHz after upsample"
|
||||
);
|
||||
|
||||
// Step 3: Switch back to Opus.
|
||||
enc.set_profile(QualityProfile::GOOD).unwrap();
|
||||
|
||||
@@ -1,30 +1,32 @@
|
||||
//! Opus decoder wrapping the `audiopus` crate.
|
||||
//! Opus decoder built on top of the raw opusic-sys `DecoderHandle`.
|
||||
//!
|
||||
//! Phase 0 of the DRED integration: we went straight to a custom
|
||||
//! `DecoderHandle` instead of `opusic_c::Decoder` because the latter's
|
||||
//! inner pointer is `pub(crate)` and we need to reach it in Phase 3 for
|
||||
//! `opus_decoder_dred_decode`. See `dred_ffi.rs` for the rationale and
|
||||
//! `docs/PRD-dred-integration.md` for the full plan.
|
||||
|
||||
use audiopus::coder::Decoder;
|
||||
use audiopus::{Channels, MutSignals, SampleRate};
|
||||
use audiopus::packet::Packet;
|
||||
use crate::dred_ffi::{DecoderHandle, DredState};
|
||||
use wzp_proto::{AudioDecoder, CodecError, CodecId, QualityProfile};
|
||||
|
||||
/// Opus decoder implementing `AudioDecoder`.
|
||||
/// Opus decoder implementing [`AudioDecoder`].
|
||||
///
|
||||
/// Operates at 48 kHz mono output.
|
||||
/// Operates at 48 kHz mono output. 20 ms and 40 ms frames supported via
|
||||
/// the active `QualityProfile`. Behavior is intentionally identical to
|
||||
/// the pre-swap audiopus-based decoder at this phase — DRED reconstruction
|
||||
/// lands in Phase 3.
|
||||
pub struct OpusDecoder {
|
||||
inner: Decoder,
|
||||
inner: DecoderHandle,
|
||||
codec_id: CodecId,
|
||||
frame_duration_ms: u8,
|
||||
}
|
||||
|
||||
// SAFETY: Same reasoning as OpusEncoder — exclusive access via &mut self.
|
||||
unsafe impl Sync for OpusDecoder {}
|
||||
|
||||
impl OpusDecoder {
|
||||
/// Create a new Opus decoder for the given quality profile.
|
||||
pub fn new(profile: QualityProfile) -> Result<Self, CodecError> {
|
||||
let decoder = Decoder::new(SampleRate::Hz48000, Channels::Mono)
|
||||
.map_err(|e| CodecError::DecodeFailed(format!("opus decoder init: {e}")))?;
|
||||
|
||||
let inner = DecoderHandle::new()?;
|
||||
Ok(Self {
|
||||
inner: decoder,
|
||||
inner,
|
||||
codec_id: profile.codec,
|
||||
frame_duration_ms: profile.frame_duration_ms,
|
||||
})
|
||||
@@ -34,6 +36,24 @@ impl OpusDecoder {
|
||||
pub fn frame_samples(&self) -> usize {
|
||||
(48_000 * self.frame_duration_ms as usize) / 1000
|
||||
}
|
||||
|
||||
/// Reconstruct a lost frame from a previously parsed `DredState`.
|
||||
///
|
||||
/// Phase 3b entry point: callers (CallDecoder / engine.rs) use this to
|
||||
/// synthesize audio for gaps detected by the jitter buffer when DRED
|
||||
/// side-channel state from a later-arriving packet covers the gap's
|
||||
/// sample offset. `offset_samples` is measured backward from the anchor
|
||||
/// packet that produced `state`. See `DecoderHandle::reconstruct_from_dred`
|
||||
/// for the full semantics.
|
||||
pub fn reconstruct_from_dred(
|
||||
&mut self,
|
||||
state: &DredState,
|
||||
offset_samples: i32,
|
||||
output: &mut [i16],
|
||||
) -> Result<usize, CodecError> {
|
||||
self.inner
|
||||
.reconstruct_from_dred(state, offset_samples, output)
|
||||
}
|
||||
}
|
||||
|
||||
impl AudioDecoder for OpusDecoder {
|
||||
@@ -45,15 +65,7 @@ impl AudioDecoder for OpusDecoder {
|
||||
pcm.len()
|
||||
)));
|
||||
}
|
||||
let packet = Packet::try_from(encoded)
|
||||
.map_err(|e| CodecError::DecodeFailed(format!("invalid packet: {e}")))?;
|
||||
let signals = MutSignals::try_from(pcm)
|
||||
.map_err(|e| CodecError::DecodeFailed(format!("output signals: {e}")))?;
|
||||
let n = self
|
||||
.inner
|
||||
.decode(Some(packet), signals, false)
|
||||
.map_err(|e| CodecError::DecodeFailed(format!("opus decode: {e}")))?;
|
||||
Ok(n)
|
||||
self.inner.decode(encoded, pcm)
|
||||
}
|
||||
|
||||
fn decode_lost(&mut self, pcm: &mut [i16]) -> Result<usize, CodecError> {
|
||||
@@ -64,13 +76,7 @@ impl AudioDecoder for OpusDecoder {
|
||||
pcm.len()
|
||||
)));
|
||||
}
|
||||
let signals = MutSignals::try_from(pcm)
|
||||
.map_err(|e| CodecError::DecodeFailed(format!("output signals: {e}")))?;
|
||||
let n = self
|
||||
.inner
|
||||
.decode(None, signals, false)
|
||||
.map_err(|e| CodecError::DecodeFailed(format!("opus PLC: {e}")))?;
|
||||
Ok(n)
|
||||
self.inner.decode_lost(pcm)
|
||||
}
|
||||
|
||||
fn codec_id(&self) -> CodecId {
|
||||
@@ -79,7 +85,7 @@ impl AudioDecoder for OpusDecoder {
|
||||
|
||||
fn set_profile(&mut self, profile: QualityProfile) -> Result<(), CodecError> {
|
||||
match profile.codec {
|
||||
CodecId::Opus24k | CodecId::Opus16k | CodecId::Opus6k => {
|
||||
c if c.is_opus() => {
|
||||
self.codec_id = profile.codec;
|
||||
self.frame_duration_ms = profile.frame_duration_ms;
|
||||
Ok(())
|
||||
|
||||
@@ -1,58 +1,230 @@
|
||||
//! Opus encoder wrapping the `audiopus` crate.
|
||||
//! Opus encoder wrapping the `opusic-c` crate (libopus 1.5.2).
|
||||
//!
|
||||
//! Phase 1 of the DRED integration: encoder-side DRED is enabled on every
|
||||
//! Opus profile with a tiered duration (studio 100 ms / normal 200 ms /
|
||||
//! degraded 500 ms), and Opus inband FEC (LBRR) is disabled because DRED
|
||||
//! is the stronger mechanism for the same failure mode. The legacy behavior
|
||||
//! is preserved behind the `AUDIO_USE_LEGACY_FEC` environment variable as a
|
||||
//! runtime escape hatch for rollout. See `docs/PRD-dred-integration.md`.
|
||||
//!
|
||||
//! # DRED duration policy
|
||||
//!
|
||||
//! Rationale from the PRD:
|
||||
//! - Studio tiers (Opus 32k/48k/64k): 100 ms — loss is rare on high-quality
|
||||
//! networks; short window keeps decoder CPU modest.
|
||||
//! - Normal tiers (Opus 16k/24k): 200 ms — balanced baseline covering common
|
||||
//! VoIP loss patterns (20–150 ms bursts from wifi roam, transient congestion).
|
||||
//! - Degraded tier (Opus 6k): 1040 ms — users on 6k are by definition on a
|
||||
//! bad link; the maximum libopus DRED window buys the best burst resilience
|
||||
//! where it matters. The RDO-VAE naturally degrades quality at longer offsets.
|
||||
//!
|
||||
//! # Why the 15% packet loss floor
|
||||
//!
|
||||
//! libopus 1.5's DRED emitter is gated on `OPUS_SET_PACKET_LOSS_PERC` and
|
||||
//! scales the emitted window proportionally to the assumed loss:
|
||||
//!
|
||||
//! ```text
|
||||
//! loss_pct samples_available effective_ms
|
||||
//! 5% 720 15
|
||||
//! 10% 2640 55
|
||||
//! 15% 4560 95
|
||||
//! 20% 6480 135
|
||||
//! 25%+ 8400 (capped) 175 (≈ 87% of the 200ms configured max)
|
||||
//! ```
|
||||
//!
|
||||
//! Measured empirically against libopus 1.5.2 on Opus 24k / 200 ms DRED
|
||||
//! duration during Phase 3b. At 5% loss the window is only 15 ms — too
|
||||
//! small to even reconstruct a single 20 ms Opus frame. 15% gives 95 ms
|
||||
//! (enough for single-frame recovery plus modest burst margin) while
|
||||
//! keeping the bitrate overhead modest compared to 25%. Real measurements
|
||||
//! from the quality adapter override upward when loss exceeds the floor.
|
||||
|
||||
use audiopus::coder::Encoder;
|
||||
use audiopus::{Application, Bitrate, Channels, SampleRate, Signal};
|
||||
use tracing::debug;
|
||||
use std::sync::OnceLock;
|
||||
|
||||
use opusic_c::{Application, Bitrate, Channels, Encoder, InbandFec, SampleRate, Signal};
|
||||
use tracing::{debug, info, warn};
|
||||
use wzp_proto::{AudioEncoder, CodecError, CodecId, QualityProfile};
|
||||
|
||||
/// Logged exactly once per process the first time an OpusEncoder is built.
|
||||
/// Confirms that libopus 1.5.2 (the version with DRED) is actually linked
|
||||
/// at runtime — invaluable when chasing "is the new codec loaded?"
|
||||
/// regressions on Android, where the only debug surface is logcat.
|
||||
static LIBOPUS_VERSION_LOGGED: OnceLock<()> = OnceLock::new();
|
||||
|
||||
/// Minimum `OPUS_SET_PACKET_LOSS_PERC` value used in DRED mode. libopus
|
||||
/// scales the DRED emission window with the assumed loss percentage:
|
||||
/// empirically, 5% gives a 15 ms window (useless), 10% gives 55 ms, 15%
|
||||
/// gives 95 ms, and 25%+ saturates the configured max (~175 ms at 200 ms
|
||||
/// duration). 15% is the minimum value that produces a DRED window larger
|
||||
/// than a single 20 ms frame, making it the minimum floor that actually
|
||||
/// gives DRED something useful to reconstruct. Real loss measurements from
|
||||
/// the quality adapter override this upward.
|
||||
const DRED_LOSS_FLOOR_PCT: u8 = 15;
|
||||
|
||||
/// Environment variable that reverts Phase 1 behavior to Phase 0 (inband FEC
|
||||
/// on, DRED off, no loss floor). Read once per encoder construction.
|
||||
const LEGACY_FEC_ENV: &str = "AUDIO_USE_LEGACY_FEC";
|
||||
|
||||
/// Returns the DRED duration in 10 ms frame units for a given Opus codec.
|
||||
///
|
||||
/// Unit: each frame is 10 ms, so the max value of 104 corresponds to 1040 ms
|
||||
/// of reconstructable history. Returns 0 for non-Opus codecs (DRED is not
|
||||
/// emitted by the libopus encoder in that case anyway, but we avoid a
|
||||
/// pointless FFI call).
|
||||
///
|
||||
/// See the DRED duration policy in the module docs for per-tier rationale.
|
||||
pub fn dred_duration_for(codec: CodecId) -> u8 {
|
||||
match codec {
|
||||
// Studio tiers — loss is rare, short window.
|
||||
CodecId::Opus32k | CodecId::Opus48k | CodecId::Opus64k => 10,
|
||||
// Normal tiers — balanced baseline.
|
||||
CodecId::Opus16k | CodecId::Opus24k => 20,
|
||||
// Degraded tier — maximum burst resilience. 104 × 10 ms = 1040 ms,
|
||||
// the highest value libopus 1.5 supports. Users on 6k are on a bad
|
||||
// link by definition; the RDO-VAE naturally degrades quality at longer
|
||||
// offsets, so the extra window costs only ~1-2 kbps additional overhead
|
||||
// while buying substantially better burst resilience (up from 500 ms).
|
||||
CodecId::Opus6k => 104,
|
||||
// Non-Opus (Codec2 / CN / video): DRED is N/A.
|
||||
CodecId::Codec2_1200
|
||||
| CodecId::Codec2_3200
|
||||
| CodecId::ComfortNoise
|
||||
| CodecId::H264Baseline
|
||||
| CodecId::H265Main
|
||||
| CodecId::Av1Main => 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns whether the legacy-FEC escape hatch is active.
|
||||
///
|
||||
/// Read from `AUDIO_USE_LEGACY_FEC`. Any non-empty value activates legacy
|
||||
/// mode; unset or empty leaves DRED enabled.
|
||||
fn read_legacy_fec_env() -> bool {
|
||||
match std::env::var(LEGACY_FEC_ENV) {
|
||||
Ok(v) => !v.is_empty() && v != "0" && v.to_ascii_lowercase() != "false",
|
||||
Err(_) => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Opus encoder implementing `AudioEncoder`.
|
||||
///
|
||||
/// Operates at 48 kHz mono. Supports frame sizes of 20 ms (960 samples)
|
||||
/// and 40 ms (1920 samples).
|
||||
/// Operates at 48 kHz mono. Supports 20 ms and 40 ms frames via the active
|
||||
/// `QualityProfile`.
|
||||
pub struct OpusEncoder {
|
||||
inner: Encoder,
|
||||
codec_id: CodecId,
|
||||
frame_duration_ms: u8,
|
||||
/// When `true`, revert to the Phase 0 behavior: inband FEC Mode1, DRED
|
||||
/// disabled, no loss floor. Captured at construction time and not
|
||||
/// re-read mid-call.
|
||||
legacy_fec_mode: bool,
|
||||
}
|
||||
|
||||
// SAFETY: OpusEncoder is only used via `&mut self` methods. The inner
|
||||
// audiopus Encoder contains a raw pointer that is !Sync, but we never
|
||||
// share it across threads without exclusive access.
|
||||
// opusic-c Encoder wraps a non-null pointer that is !Sync by default,
|
||||
// but we never share it across threads without exclusive access.
|
||||
unsafe impl Sync for OpusEncoder {}
|
||||
|
||||
impl OpusEncoder {
|
||||
/// Create a new Opus encoder for the given quality profile.
|
||||
pub fn new(profile: QualityProfile) -> Result<Self, CodecError> {
|
||||
let encoder = Encoder::new(SampleRate::Hz48000, Channels::Mono, Application::Voip)
|
||||
.map_err(|e| CodecError::EncodeFailed(format!("opus encoder init: {e}")))?;
|
||||
// opusic-c argument order: (Channels, SampleRate, Application)
|
||||
// — different from audiopus's (SampleRate, Channels, Application).
|
||||
let encoder = Encoder::new(Channels::Mono, SampleRate::Hz48000, Application::Voip)
|
||||
.map_err(|e| CodecError::EncodeFailed(format!("opus encoder init: {e:?}")))?;
|
||||
|
||||
let legacy_fec_mode = read_legacy_fec_env();
|
||||
if legacy_fec_mode {
|
||||
warn!(
|
||||
"AUDIO_USE_LEGACY_FEC active — reverting Opus encoder to Phase 0 \
|
||||
behavior (inband FEC Mode1, no DRED)"
|
||||
);
|
||||
}
|
||||
|
||||
let mut enc = Self {
|
||||
inner: encoder,
|
||||
codec_id: profile.codec,
|
||||
frame_duration_ms: profile.frame_duration_ms,
|
||||
legacy_fec_mode,
|
||||
};
|
||||
enc.apply_bitrate(profile.codec)?;
|
||||
enc.set_inband_fec(true);
|
||||
enc.set_dtx(true);
|
||||
|
||||
// Voice signal type hint for better compression
|
||||
// Common setup — bitrate, DTX, signal hint, complexity. These are
|
||||
// identical regardless of the protection mode below.
|
||||
enc.apply_bitrate(profile.codec)?;
|
||||
enc.set_dtx(true);
|
||||
enc.inner
|
||||
.set_signal(Signal::Voice)
|
||||
.map_err(|e| CodecError::EncodeFailed(format!("set signal: {e}")))?;
|
||||
|
||||
// Default complexity 7 — good quality/CPU trade-off for VoIP
|
||||
.map_err(|e| CodecError::EncodeFailed(format!("set signal: {e:?}")))?;
|
||||
enc.inner
|
||||
.set_complexity(7)
|
||||
.map_err(|e| CodecError::EncodeFailed(format!("set complexity: {e}")))?;
|
||||
.map_err(|e| CodecError::EncodeFailed(format!("set complexity: {e:?}")))?;
|
||||
|
||||
// Protection mode: DRED (Phase 1 default) or legacy inband FEC.
|
||||
enc.apply_protection_mode(profile.codec)?;
|
||||
|
||||
Ok(enc)
|
||||
}
|
||||
|
||||
fn apply_bitrate(&mut self, codec: CodecId) -> Result<(), CodecError> {
|
||||
let bps = codec.bitrate_bps() as i32;
|
||||
/// Configure the protection mode for the active codec.
|
||||
///
|
||||
/// In DRED mode (default): disable inband FEC, set DRED duration for the
|
||||
/// codec tier, clamp packet_loss to the 5% floor so DRED stays active.
|
||||
///
|
||||
/// In legacy mode: enable inband FEC Mode1 (Phase 0 behavior), leave
|
||||
/// DRED and packet_loss at libopus defaults.
|
||||
fn apply_protection_mode(&mut self, codec: CodecId) -> Result<(), CodecError> {
|
||||
if self.legacy_fec_mode {
|
||||
self.inner
|
||||
.set_inband_fec(InbandFec::Mode1)
|
||||
.map_err(|e| CodecError::EncodeFailed(format!("set inband FEC: {e:?}")))?;
|
||||
// Leave DRED at 0 and packet_loss at default — matches Phase 0.
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// DRED path: disable the overlapping inband FEC, enable DRED with
|
||||
// per-profile duration, floor packet_loss so DRED emits.
|
||||
self.inner
|
||||
.set_bitrate(Bitrate::BitsPerSecond(bps))
|
||||
.map_err(|e| CodecError::EncodeFailed(format!("set bitrate: {e}")))?;
|
||||
.set_inband_fec(InbandFec::Off)
|
||||
.map_err(|e| CodecError::EncodeFailed(format!("set inband FEC off: {e:?}")))?;
|
||||
|
||||
let dred_frames = dred_duration_for(codec);
|
||||
self.inner
|
||||
.set_dred_duration(dred_frames)
|
||||
.map_err(|e| CodecError::EncodeFailed(format!("set DRED duration: {e:?}")))?;
|
||||
|
||||
self.inner
|
||||
.set_packet_loss(DRED_LOSS_FLOOR_PCT)
|
||||
.map_err(|e| CodecError::EncodeFailed(format!("set packet loss floor: {e:?}")))?;
|
||||
|
||||
// Both of these are gated behind the GUI debug toggle so logcat
|
||||
// stays clean in normal mode. Flip "DRED verbose logs" in the
|
||||
// settings panel to see the per-encoder config + libopus version.
|
||||
if crate::dred_verbose_logs() {
|
||||
info!(
|
||||
codec = ?codec,
|
||||
dred_frames,
|
||||
dred_ms = dred_frames as u32 * 10,
|
||||
loss_floor_pct = DRED_LOSS_FLOOR_PCT,
|
||||
"opus encoder: DRED enabled"
|
||||
);
|
||||
|
||||
// One-shot logging of the linked libopus version so we can
|
||||
// confirm at a glance that opusic-c (libopus 1.5.2) is loaded.
|
||||
// Pre-Phase-0 audiopus shipped libopus 1.3 which has no DRED;
|
||||
// if this log says "libopus 1.3" something is very wrong.
|
||||
LIBOPUS_VERSION_LOGGED.get_or_init(|| {
|
||||
info!(libopus_version = %opusic_c::version(), "linked libopus version");
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn apply_bitrate(&mut self, codec: CodecId) -> Result<(), CodecError> {
|
||||
let bps = codec.bitrate_bps();
|
||||
self.inner
|
||||
.set_bitrate(Bitrate::Value(bps))
|
||||
.map_err(|e| CodecError::EncodeFailed(format!("set bitrate: {e:?}")))?;
|
||||
debug!(bitrate_bps = bps, "opus encoder bitrate set");
|
||||
Ok(())
|
||||
}
|
||||
@@ -71,10 +243,36 @@ impl OpusEncoder {
|
||||
|
||||
/// Hint the encoder about expected packet loss percentage (0-100).
|
||||
///
|
||||
/// Higher values cause the encoder to use more redundancy to survive
|
||||
/// packet loss, at the expense of slightly higher bitrate.
|
||||
/// In DRED mode, the value is floored at `DRED_LOSS_FLOOR_PCT` so the
|
||||
/// encoder never drops DRED emission even on a perfect network. Real
|
||||
/// loss measurements from the quality adapter override upward.
|
||||
///
|
||||
/// In legacy mode, the value is passed through unchanged (min 0, max 100).
|
||||
pub fn set_expected_loss(&mut self, loss_pct: u8) {
|
||||
let _ = self.inner.set_packet_loss_perc(loss_pct.min(100));
|
||||
let clamped = if self.legacy_fec_mode {
|
||||
loss_pct.min(100)
|
||||
} else {
|
||||
loss_pct.max(DRED_LOSS_FLOOR_PCT).min(100)
|
||||
};
|
||||
let _ = self.inner.set_packet_loss(clamped);
|
||||
}
|
||||
|
||||
/// Set the DRED duration in 10 ms frame units (0 disables, max 104).
|
||||
///
|
||||
/// No-op in legacy mode. Normally driven automatically by the active
|
||||
/// quality profile via `apply_protection_mode`; this setter exists for
|
||||
/// tests and for the rare case where a caller needs to override the
|
||||
/// per-profile default.
|
||||
pub fn set_dred_duration(&mut self, frames: u8) {
|
||||
if self.legacy_fec_mode {
|
||||
return;
|
||||
}
|
||||
let _ = self.inner.set_dred_duration(frames.min(104));
|
||||
}
|
||||
|
||||
/// Test/introspection accessor: whether legacy FEC mode is active.
|
||||
pub fn is_legacy_fec_mode(&self) -> bool {
|
||||
self.legacy_fec_mode
|
||||
}
|
||||
}
|
||||
|
||||
@@ -87,10 +285,14 @@ impl AudioEncoder for OpusEncoder {
|
||||
pcm.len()
|
||||
)));
|
||||
}
|
||||
// opusic-c takes &[u16] for the sample input. Bit pattern is
|
||||
// identical to i16 — the cast is zero-cost and the encoder
|
||||
// interprets the bytes the same way as libopus internally.
|
||||
let pcm_u16: &[u16] = bytemuck::cast_slice(pcm);
|
||||
let n = self
|
||||
.inner
|
||||
.encode(pcm, out)
|
||||
.map_err(|e| CodecError::EncodeFailed(format!("opus encode: {e}")))?;
|
||||
.encode_to_slice(pcm_u16, out)
|
||||
.map_err(|e| CodecError::EncodeFailed(format!("opus encode: {e:?}")))?;
|
||||
Ok(n)
|
||||
}
|
||||
|
||||
@@ -100,10 +302,13 @@ impl AudioEncoder for OpusEncoder {
|
||||
|
||||
fn set_profile(&mut self, profile: QualityProfile) -> Result<(), CodecError> {
|
||||
match profile.codec {
|
||||
CodecId::Opus24k | CodecId::Opus16k | CodecId::Opus6k => {
|
||||
c if c.is_opus() => {
|
||||
self.codec_id = profile.codec;
|
||||
self.frame_duration_ms = profile.frame_duration_ms;
|
||||
self.apply_bitrate(profile.codec)?;
|
||||
// Refresh DRED duration for the new tier. apply_protection_mode
|
||||
// is idempotent and handles the legacy-vs-DRED branch correctly.
|
||||
self.apply_protection_mode(profile.codec)?;
|
||||
Ok(())
|
||||
}
|
||||
other => Err(CodecError::UnsupportedTransition {
|
||||
@@ -120,10 +325,202 @@ impl AudioEncoder for OpusEncoder {
|
||||
}
|
||||
|
||||
fn set_inband_fec(&mut self, enabled: bool) {
|
||||
let _ = self.inner.set_inband_fec(enabled);
|
||||
// In DRED mode, ignore external requests to re-enable inband FEC —
|
||||
// running both mechanisms wastes bitrate on overlapping protection
|
||||
// and opusic-c's own docs recommend disabling inband FEC when DRED
|
||||
// is on. Trait callers that genuinely want classical FEC should set
|
||||
// `AUDIO_USE_LEGACY_FEC=1` and re-create the encoder.
|
||||
if !self.legacy_fec_mode {
|
||||
debug!(
|
||||
enabled,
|
||||
"set_inband_fec ignored: DRED mode is active (set AUDIO_USE_LEGACY_FEC to revert)"
|
||||
);
|
||||
return;
|
||||
}
|
||||
let mode = if enabled {
|
||||
InbandFec::Mode1
|
||||
} else {
|
||||
InbandFec::Off
|
||||
};
|
||||
let _ = self.inner.set_inband_fec(mode);
|
||||
}
|
||||
|
||||
fn set_dtx(&mut self, enabled: bool) {
|
||||
let _ = self.inner.set_dtx(enabled);
|
||||
}
|
||||
|
||||
fn set_expected_loss(&mut self, loss_pct: u8) {
|
||||
OpusEncoder::set_expected_loss(self, loss_pct);
|
||||
}
|
||||
|
||||
fn set_dred_duration(&mut self, frames: u8) {
|
||||
OpusEncoder::set_dred_duration(self, frames);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use wzp_proto::AudioDecoder;
|
||||
|
||||
/// Phase 0 acceptance gate: fail loudly if the linked libopus is not 1.5.x.
|
||||
/// DRED (Phase 1+) only exists in libopus ≥ 1.5, so running against an
|
||||
/// older version would silently regress the entire DRED integration.
|
||||
#[test]
|
||||
fn linked_libopus_is_1_5() {
|
||||
let version = opusic_c::version();
|
||||
assert!(
|
||||
version.contains("1.5"),
|
||||
"expected libopus 1.5.x, got: {version}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encoder_creates_at_good_profile() {
|
||||
let enc = OpusEncoder::new(QualityProfile::GOOD).expect("opus encoder init");
|
||||
assert_eq!(enc.codec_id, CodecId::Opus24k);
|
||||
assert_eq!(enc.frame_samples(), 960); // 20 ms @ 48 kHz
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encoder_roundtrip_silence() {
|
||||
let mut enc = OpusEncoder::new(QualityProfile::GOOD).unwrap();
|
||||
let mut dec = crate::opus_dec::OpusDecoder::new(QualityProfile::GOOD).unwrap();
|
||||
let pcm_in = vec![0i16; 960]; // 20 ms silence
|
||||
let mut encoded = vec![0u8; 512];
|
||||
let n = enc.encode(&pcm_in, &mut encoded).unwrap();
|
||||
assert!(n > 0);
|
||||
let mut pcm_out = vec![0i16; 960];
|
||||
let samples = dec.decode(&encoded[..n], &mut pcm_out).unwrap();
|
||||
assert_eq!(samples, 960);
|
||||
}
|
||||
|
||||
// ─── Phase 1 — DRED duration policy ─────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn dred_duration_for_studio_tiers_is_100ms() {
|
||||
assert_eq!(dred_duration_for(CodecId::Opus32k), 10);
|
||||
assert_eq!(dred_duration_for(CodecId::Opus48k), 10);
|
||||
assert_eq!(dred_duration_for(CodecId::Opus64k), 10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dred_duration_for_normal_tiers_is_200ms() {
|
||||
assert_eq!(dred_duration_for(CodecId::Opus16k), 20);
|
||||
assert_eq!(dred_duration_for(CodecId::Opus24k), 20);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dred_duration_for_degraded_tier_is_1040ms() {
|
||||
assert_eq!(dred_duration_for(CodecId::Opus6k), 104);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dred_duration_for_codec2_is_zero() {
|
||||
assert_eq!(dred_duration_for(CodecId::Codec2_3200), 0);
|
||||
assert_eq!(dred_duration_for(CodecId::Codec2_1200), 0);
|
||||
assert_eq!(dred_duration_for(CodecId::ComfortNoise), 0);
|
||||
}
|
||||
|
||||
// ─── Phase 1 — Legacy escape hatch ──────────────────────────────────────
|
||||
|
||||
/// By default (env var unset), legacy mode is off.
|
||||
///
|
||||
/// This test does NOT manipulate the environment to avoid flakiness
|
||||
/// when the full suite runs in parallel. It only asserts on a freshly
|
||||
/// created encoder in the ambient environment.
|
||||
#[test]
|
||||
fn default_mode_is_dred_not_legacy() {
|
||||
// SAFETY: only run if the ambient env hasn't set the var externally.
|
||||
if std::env::var(LEGACY_FEC_ENV).is_ok() {
|
||||
return; // don't assert — someone set the env for a reason.
|
||||
}
|
||||
let enc = OpusEncoder::new(QualityProfile::GOOD).unwrap();
|
||||
assert!(!enc.is_legacy_fec_mode());
|
||||
}
|
||||
|
||||
// ─── Phase 1 — Behavioral regression: roundtrip still works ─────────────
|
||||
|
||||
#[test]
|
||||
fn dred_mode_roundtrip_voice_pattern() {
|
||||
// Use a realistic voice-like input (sine wave at speech frequencies)
|
||||
// so the encoder emits meaningful DRED data rather than trivially
|
||||
// compressible silence.
|
||||
let mut enc = OpusEncoder::new(QualityProfile::GOOD).unwrap();
|
||||
let mut dec = crate::opus_dec::OpusDecoder::new(QualityProfile::GOOD).unwrap();
|
||||
|
||||
let mut total_encoded_bytes = 0usize;
|
||||
// Run 50 frames (1 second) so DRED fills up and starts emitting.
|
||||
for frame_idx in 0..50 {
|
||||
let pcm_in: Vec<i16> = (0..960)
|
||||
.map(|i| {
|
||||
let t = (frame_idx * 960 + i) as f64 / 48_000.0;
|
||||
(8000.0 * (2.0 * std::f64::consts::PI * 300.0 * t).sin()) as i16
|
||||
})
|
||||
.collect();
|
||||
let mut encoded = vec![0u8; 512];
|
||||
let n = enc.encode(&pcm_in, &mut encoded).unwrap();
|
||||
assert!(n > 0);
|
||||
total_encoded_bytes += n;
|
||||
|
||||
let mut pcm_out = vec![0i16; 960];
|
||||
let samples = dec.decode(&encoded[..n], &mut pcm_out).unwrap();
|
||||
assert_eq!(samples, 960);
|
||||
}
|
||||
|
||||
// Effective bitrate after 1 second of encoding.
|
||||
// Opus 24k base + ~1 kbps DRED ≈ 25 kbps ≈ 3125 bytes/sec.
|
||||
// Allow generous headroom (2000 lower bound, 8000 upper bound) —
|
||||
// this is a behavioral regression check, not a tight bitrate assertion.
|
||||
// The exact value is printed with --nocapture for diagnostic use.
|
||||
eprintln!(
|
||||
"[phase1 bitrate probe] legacy_fec_mode={} total_encoded={} bytes/sec",
|
||||
enc.is_legacy_fec_mode(),
|
||||
total_encoded_bytes
|
||||
);
|
||||
assert!(
|
||||
total_encoded_bytes > 2000,
|
||||
"encoder output too small: {total_encoded_bytes} bytes/sec (DRED likely not emitting)"
|
||||
);
|
||||
assert!(
|
||||
total_encoded_bytes < 8000,
|
||||
"encoder output too large: {total_encoded_bytes} bytes/sec"
|
||||
);
|
||||
}
|
||||
|
||||
// ─── Phase 1 — set_profile updates DRED duration on tier switch ─────────
|
||||
|
||||
#[test]
|
||||
fn profile_switch_refreshes_dred_duration() {
|
||||
// Start on GOOD (Opus 24k, DRED 20 frames), switch to DEGRADED
|
||||
// (Opus 6k, DRED 50 frames). The encoder should accept both profile
|
||||
// changes without error. We can't directly observe the DRED duration
|
||||
// inside libopus, but apply_protection_mode returns Ok for both.
|
||||
let mut enc = OpusEncoder::new(QualityProfile::GOOD).unwrap();
|
||||
assert_eq!(enc.codec_id, CodecId::Opus24k);
|
||||
|
||||
enc.set_profile(QualityProfile::DEGRADED).unwrap();
|
||||
assert_eq!(enc.codec_id, CodecId::Opus6k);
|
||||
|
||||
enc.set_profile(QualityProfile::STUDIO_64K).unwrap();
|
||||
assert_eq!(enc.codec_id, CodecId::Opus64k);
|
||||
}
|
||||
|
||||
// ─── Phase 1 — Trait set_inband_fec is a no-op in DRED mode ─────────────
|
||||
|
||||
#[test]
|
||||
fn set_inband_fec_noop_in_dred_mode() {
|
||||
if std::env::var(LEGACY_FEC_ENV).is_ok() {
|
||||
return;
|
||||
}
|
||||
let mut enc = OpusEncoder::new(QualityProfile::GOOD).unwrap();
|
||||
// Should not error, should not re-enable inband FEC internally.
|
||||
enc.set_inband_fec(true);
|
||||
// We can't directly query libopus's inband FEC state through opusic-c,
|
||||
// but the call must not panic and the encoder must still work.
|
||||
let pcm_in = vec![0i16; 960];
|
||||
let mut encoded = vec![0u8; 512];
|
||||
let n = enc.encode(&pcm_in, &mut encoded).unwrap();
|
||||
assert!(n > 0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -129,8 +129,7 @@ impl Downsampler48to8 {
|
||||
|
||||
// Update history: keep the last (FIR_TAPS - 1) samples from work.
|
||||
if work.len() >= hist_len {
|
||||
self.history
|
||||
.copy_from_slice(&work[work.len() - hist_len..]);
|
||||
self.history.copy_from_slice(&work[work.len() - hist_len..]);
|
||||
} else {
|
||||
// Input was shorter than history — shift.
|
||||
let shift = hist_len - work.len();
|
||||
@@ -209,8 +208,7 @@ impl Upsampler8to48 {
|
||||
|
||||
// Update history.
|
||||
if work.len() >= hist_len {
|
||||
self.history
|
||||
.copy_from_slice(&work[work.len() - hist_len..]);
|
||||
self.history.copy_from_slice(&work[work.len() - hist_len..]);
|
||||
} else {
|
||||
let shift = hist_len - work.len();
|
||||
self.history.copy_within(shift.., 0);
|
||||
|
||||
@@ -151,7 +151,10 @@ mod tests {
|
||||
for _ in 0..4 {
|
||||
det.is_silent(&silence);
|
||||
}
|
||||
assert!(det.is_silent(&silence), "should be suppressing after hangover");
|
||||
assert!(
|
||||
det.is_silent(&silence),
|
||||
"should be suppressing after hangover"
|
||||
);
|
||||
|
||||
// Speech arrives — should immediately stop suppressing.
|
||||
assert!(!det.is_silent(&speech));
|
||||
@@ -165,10 +168,16 @@ mod tests {
|
||||
cn.generate(&mut pcm);
|
||||
|
||||
// At least some samples should be non-zero.
|
||||
assert!(pcm.iter().any(|&s| s != 0), "CN output should not be all zeros");
|
||||
assert!(
|
||||
pcm.iter().any(|&s| s != 0),
|
||||
"CN output should not be all zeros"
|
||||
);
|
||||
|
||||
// All samples should be within [-50, 50].
|
||||
assert!(pcm.iter().all(|&s| s.abs() <= 50), "CN samples out of range");
|
||||
assert!(
|
||||
pcm.iter().all(|&s| s.abs() <= 50),
|
||||
"CN samples out of range"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -179,11 +188,17 @@ mod tests {
|
||||
// Constant value: RMS of [v, v, v, ...] = |v|.
|
||||
let pcm = vec![100i16; 100];
|
||||
let rms = SilenceDetector::rms(&pcm);
|
||||
assert!((rms - 100.0).abs() < 0.01, "RMS of constant 100 should be 100, got {rms}");
|
||||
assert!(
|
||||
(rms - 100.0).abs() < 0.01,
|
||||
"RMS of constant 100 should be 100, got {rms}"
|
||||
);
|
||||
|
||||
// Known pattern: [3, 4] → sqrt((9+16)/2) = sqrt(12.5) ≈ 3.5355
|
||||
let rms2 = SilenceDetector::rms(&[3, 4]);
|
||||
assert!((rms2 - 3.5355).abs() < 0.01, "RMS of [3,4] should be ~3.5355, got {rms2}");
|
||||
assert!(
|
||||
(rms2 - 3.5355).abs() < 0.01,
|
||||
"RMS of [3,4] should be ~3.5355, got {rms2}"
|
||||
);
|
||||
|
||||
// Empty buffer → 0.
|
||||
assert_eq!(SilenceDetector::rms(&[]), 0.0);
|
||||
|
||||
@@ -1,21 +1,20 @@
|
||||
//! Sliding window replay protection.
|
||||
//!
|
||||
//! Tracks seen sequence numbers using a bitmap. Window size is 1024 packets.
|
||||
//! Sequence numbers that are too old (more than WINDOW_SIZE behind the highest
|
||||
//! seen) are rejected.
|
||||
//! Tracks seen sequence numbers using a bitmap. Window size is configurable
|
||||
//! at construction time. Sequence numbers that are too old (more than
|
||||
//! `window_size` behind the highest seen) are rejected.
|
||||
|
||||
use wzp_proto::CryptoError;
|
||||
|
||||
/// Window size in packets.
|
||||
const WINDOW_SIZE: u16 = 1024;
|
||||
|
||||
/// Sliding window anti-replay detector.
|
||||
///
|
||||
/// Uses a bitmap to track which sequence numbers have been seen within
|
||||
/// the current window. Handles u16 wrapping correctly.
|
||||
/// the current window. Handles `u32` wrapping correctly.
|
||||
pub struct AntiReplayWindow {
|
||||
/// Window size in packets.
|
||||
window_size: u32,
|
||||
/// Highest sequence number seen so far.
|
||||
highest: u16,
|
||||
highest: u32,
|
||||
/// Bitmap of seen packets. Bit i corresponds to (highest - i).
|
||||
bitmap: Vec<u64>,
|
||||
/// Whether any packet has been received yet.
|
||||
@@ -23,21 +22,26 @@ pub struct AntiReplayWindow {
|
||||
}
|
||||
|
||||
impl AntiReplayWindow {
|
||||
/// Number of u64 words needed for the bitmap.
|
||||
const BITMAP_WORDS: usize = (WINDOW_SIZE as usize + 63) / 64;
|
||||
|
||||
/// Create a new anti-replay window.
|
||||
/// Create a new anti-replay window with the default size of 1024 packets.
|
||||
pub fn new() -> Self {
|
||||
Self::with_window(1024)
|
||||
}
|
||||
|
||||
/// Create a new anti-replay window with a custom size.
|
||||
pub fn with_window(size: usize) -> Self {
|
||||
let window_size = size as u32;
|
||||
let bitmap_words = (size + 63) / 64;
|
||||
Self {
|
||||
window_size,
|
||||
highest: 0,
|
||||
bitmap: vec![0u64; Self::BITMAP_WORDS],
|
||||
bitmap: vec![0u64; bitmap_words],
|
||||
initialized: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a sequence number is valid (not a replay, not too old).
|
||||
/// If valid, marks it as seen.
|
||||
pub fn check_and_update(&mut self, seq: u16) -> Result<(), CryptoError> {
|
||||
pub fn check_and_update(&mut self, seq: u32) -> Result<(), CryptoError> {
|
||||
if !self.initialized {
|
||||
self.initialized = true;
|
||||
self.highest = seq;
|
||||
@@ -52,17 +56,17 @@ impl AntiReplayWindow {
|
||||
return Err(CryptoError::ReplayDetected { seq });
|
||||
}
|
||||
|
||||
if diff < 0x8000 {
|
||||
// seq is ahead of highest (wrapping-aware: diff in [1, 0x7FFF])
|
||||
if diff < 0x8000_0000 {
|
||||
// seq is ahead of highest (wrapping-aware: diff in [1, 0x7FFF_FFFF])
|
||||
let shift = diff as usize;
|
||||
self.advance_window(shift);
|
||||
self.highest = seq;
|
||||
self.set_bit(0);
|
||||
Ok(())
|
||||
} else {
|
||||
// seq is behind highest (wrapping-aware: diff in [0x8000, 0xFFFF])
|
||||
// seq is behind highest (wrapping-aware: diff in [0x8000_0000, 0xFFFF_FFFF])
|
||||
let behind = self.highest.wrapping_sub(seq) as usize;
|
||||
if behind >= WINDOW_SIZE as usize {
|
||||
if behind >= self.window_size as usize {
|
||||
return Err(CryptoError::ReplayDetected { seq });
|
||||
}
|
||||
if self.get_bit(behind) {
|
||||
@@ -75,7 +79,8 @@ impl AntiReplayWindow {
|
||||
|
||||
/// Advance the window by `shift` positions (shift left = new bits at position 0).
|
||||
fn advance_window(&mut self, shift: usize) {
|
||||
if shift >= WINDOW_SIZE as usize {
|
||||
let window_size = self.window_size as usize;
|
||||
if shift >= window_size {
|
||||
for word in &mut self.bitmap {
|
||||
*word = 0;
|
||||
}
|
||||
@@ -156,7 +161,11 @@ mod tests {
|
||||
fn sequential_accepted() {
|
||||
let mut w = AntiReplayWindow::new();
|
||||
for i in 0..200 {
|
||||
assert!(w.check_and_update(i).is_ok(), "seq {} should be accepted", i);
|
||||
assert!(
|
||||
w.check_and_update(i).is_ok(),
|
||||
"seq {} should be accepted",
|
||||
i
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -183,11 +192,11 @@ mod tests {
|
||||
#[test]
|
||||
fn wrapping_works() {
|
||||
let mut w = AntiReplayWindow::new();
|
||||
assert!(w.check_and_update(65530).is_ok());
|
||||
assert!(w.check_and_update(65535).is_ok());
|
||||
assert!(w.check_and_update(0xFFFF_FFF0).is_ok());
|
||||
assert!(w.check_and_update(0xFFFF_FFFF).is_ok());
|
||||
assert!(w.check_and_update(0).is_ok()); // wrapped
|
||||
assert!(w.check_and_update(1).is_ok());
|
||||
assert!(w.check_and_update(65535).is_err()); // duplicate
|
||||
assert!(w.check_and_update(0xFFFF_FFFF).is_err()); // duplicate
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -201,4 +210,53 @@ mod tests {
|
||||
// Now 0 is 1024 behind 1024, which is at the boundary limit
|
||||
assert!(w.check_and_update(0).is_err()); // already seen or too old
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn custom_window_size() {
|
||||
let mut w = AntiReplayWindow::with_window(64);
|
||||
for i in 0..64 {
|
||||
assert!(w.check_and_update(i).is_ok());
|
||||
}
|
||||
// seq 0 is now exactly at the boundary (64 behind 64)
|
||||
assert!(w.check_and_update(0).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn video_burst_200_with_one_reorder() {
|
||||
let mut w = AntiReplayWindow::with_window(1024);
|
||||
// Simulate a 200-packet burst
|
||||
for i in 0..200 {
|
||||
assert!(
|
||||
w.check_and_update(i).is_ok(),
|
||||
"seq {} should be accepted",
|
||||
i
|
||||
);
|
||||
}
|
||||
// One packet reordered (arrives late)
|
||||
assert!(w.check_and_update(50).is_err(), "seq 50 is a duplicate");
|
||||
// But a packet just behind the window should still be ok
|
||||
assert!(w.check_and_update(199).is_err(), "seq 199 is a duplicate");
|
||||
// Continue the burst
|
||||
for i in 200..400 {
|
||||
assert!(
|
||||
w.check_and_update(i).is_ok(),
|
||||
"seq {} should be accepted",
|
||||
i
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn u32_high_range_works() {
|
||||
let mut w = AntiReplayWindow::with_window(64);
|
||||
let base = 1000u32;
|
||||
assert!(w.check_and_update(base).is_ok());
|
||||
assert!(w.check_and_update(base + 1).is_ok());
|
||||
// 65 behind highest (base+1) is outside the 64-packet window
|
||||
assert!(w.check_and_update(base.wrapping_sub(64)).is_err());
|
||||
// 63 behind is inside
|
||||
assert!(w.check_and_update(base.wrapping_sub(62)).is_ok());
|
||||
// base itself is now a duplicate
|
||||
assert!(w.check_and_update(base).is_err());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,8 +9,8 @@ use ed25519_dalek::{Signer, SigningKey, Verifier, VerifyingKey};
|
||||
use hkdf::Hkdf;
|
||||
use rand::rngs::OsRng;
|
||||
use sha2::{Digest, Sha256};
|
||||
use x25519_dalek::{PublicKey as X25519PublicKey, StaticSecret};
|
||||
use wzp_proto::{CryptoError, CryptoSession, KeyExchange};
|
||||
use x25519_dalek::{PublicKey as X25519PublicKey, StaticSecret};
|
||||
|
||||
use crate::session::ChaChaSession;
|
||||
|
||||
@@ -18,10 +18,14 @@ use crate::session::ChaChaSession;
|
||||
pub struct WarzoneKeyExchange {
|
||||
/// Ed25519 signing key (identity).
|
||||
signing_key: SigningKey,
|
||||
/// X25519 static secret (derived from seed, used for identity encryption).
|
||||
/// X25519 static secret derived from identity seed. Reserved for future
|
||||
/// use in static-key federation authentication (not used in current
|
||||
/// ephemeral-only handshake protocol).
|
||||
#[allow(dead_code)]
|
||||
x25519_static_secret: StaticSecret,
|
||||
/// X25519 static public key.
|
||||
/// X25519 static public key derived from identity seed. Reserved for
|
||||
/// future use in static-key federation authentication (not used in
|
||||
/// current ephemeral-only handshake protocol).
|
||||
#[allow(dead_code)]
|
||||
x25519_static_public: X25519PublicKey,
|
||||
/// Ephemeral X25519 secret for the current call (set by generate_ephemeral).
|
||||
@@ -91,12 +95,11 @@ impl KeyExchange for WarzoneKeyExchange {
|
||||
&self,
|
||||
peer_ephemeral_pub: &[u8; 32],
|
||||
) -> Result<Box<dyn CryptoSession>, CryptoError> {
|
||||
let secret = self
|
||||
.ephemeral_secret
|
||||
.as_ref()
|
||||
.ok_or_else(|| {
|
||||
CryptoError::Internal("no ephemeral key generated; call generate_ephemeral first".into())
|
||||
})?;
|
||||
let secret = self.ephemeral_secret.as_ref().ok_or_else(|| {
|
||||
CryptoError::Internal(
|
||||
"no ephemeral key generated; call generate_ephemeral first".into(),
|
||||
)
|
||||
})?;
|
||||
|
||||
let peer_public = X25519PublicKey::from(*peer_ephemeral_pub);
|
||||
// Use diffie_hellman with a clone of the StaticSecret
|
||||
@@ -110,7 +113,18 @@ impl KeyExchange for WarzoneKeyExchange {
|
||||
hk.expand(b"warzone-session-key", &mut session_key)
|
||||
.expect("HKDF expand for session key should not fail");
|
||||
|
||||
Ok(Box::new(ChaChaSession::new(session_key)))
|
||||
// Derive SAS (Short Authentication String) from shared secret only.
|
||||
// The shared secret is identical on both sides (X25519 DH property).
|
||||
// A MITM would produce a different shared secret → different SAS.
|
||||
// We use a dedicated HKDF label so SAS is independent of the session key.
|
||||
let mut sas_key = [0u8; 4];
|
||||
hk.expand(b"warzone-sas-code", &mut sas_key)
|
||||
.expect("HKDF expand for SAS should not fail");
|
||||
let sas_code = u32::from_be_bytes(sas_key) % 10000;
|
||||
|
||||
let mut session = ChaChaSession::new(session_key);
|
||||
session.set_sas(sas_code);
|
||||
Ok(Box::new(session))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -195,20 +209,79 @@ mod tests {
|
||||
let mut alice_session = alice.derive_session(&bob_eph_pub).unwrap();
|
||||
let mut bob_session = bob.derive_session(&alice_eph_pub).unwrap();
|
||||
|
||||
// Verify they can communicate: Alice encrypts, Bob decrypts
|
||||
let header = b"call-header";
|
||||
// Verify they can communicate: Alice encrypts, Bob decrypts.
|
||||
// Use a valid v2 MediaHeader — encrypt/decrypt now derive the nonce from
|
||||
// header.seq and will reject raw byte slices shorter than WIRE_SIZE.
|
||||
use wzp_proto::{CodecId, MediaHeader, MediaType};
|
||||
let header = MediaHeader {
|
||||
version: 2,
|
||||
flags: 0,
|
||||
media_type: MediaType::Audio,
|
||||
codec_id: CodecId::Opus24k,
|
||||
stream_id: 0,
|
||||
fec_ratio: 0,
|
||||
seq: 0,
|
||||
timestamp: 0,
|
||||
fec_block: 0,
|
||||
};
|
||||
let mut header_bytes = Vec::new();
|
||||
header.write_to(&mut header_bytes);
|
||||
|
||||
let plaintext = b"hello from alice";
|
||||
|
||||
let mut ciphertext = Vec::new();
|
||||
alice_session
|
||||
.encrypt(header, plaintext, &mut ciphertext)
|
||||
.encrypt(&header_bytes, plaintext, &mut ciphertext)
|
||||
.unwrap();
|
||||
|
||||
let mut decrypted = Vec::new();
|
||||
bob_session
|
||||
.decrypt(header, &ciphertext, &mut decrypted)
|
||||
.decrypt(&header_bytes, &ciphertext, &mut decrypted)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(&decrypted, plaintext);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sas_codes_match_between_peers() {
|
||||
let mut alice = WarzoneKeyExchange::from_identity_seed(&[0xAA; 32]);
|
||||
let mut bob = WarzoneKeyExchange::from_identity_seed(&[0xBB; 32]);
|
||||
|
||||
let alice_eph_pub = alice.generate_ephemeral();
|
||||
let bob_eph_pub = bob.generate_ephemeral();
|
||||
|
||||
let alice_session = alice.derive_session(&bob_eph_pub).unwrap();
|
||||
let bob_session = bob.derive_session(&alice_eph_pub).unwrap();
|
||||
|
||||
let alice_sas = alice_session.sas_code();
|
||||
let bob_sas = bob_session.sas_code();
|
||||
|
||||
assert!(alice_sas.is_some(), "Alice should have SAS");
|
||||
assert!(bob_sas.is_some(), "Bob should have SAS");
|
||||
assert_eq!(alice_sas, bob_sas, "SAS codes must match between peers");
|
||||
assert!(alice_sas.unwrap() < 10000, "SAS should be 4 digits");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sas_differs_for_different_peers() {
|
||||
let mut alice = WarzoneKeyExchange::from_identity_seed(&[0xAA; 32]);
|
||||
let mut bob = WarzoneKeyExchange::from_identity_seed(&[0xBB; 32]);
|
||||
let mut eve = WarzoneKeyExchange::from_identity_seed(&[0xEE; 32]);
|
||||
|
||||
let alice_eph = alice.generate_ephemeral();
|
||||
let bob_eph = bob.generate_ephemeral();
|
||||
let eve_eph = eve.generate_ephemeral();
|
||||
|
||||
let alice_bob_session = alice.derive_session(&bob_eph).unwrap();
|
||||
|
||||
// Eve does separate handshake with Bob (MITM scenario)
|
||||
let eve_bob_session = eve.derive_session(&bob_eph).unwrap();
|
||||
|
||||
// SAS codes should differ — Eve's session has different shared secret
|
||||
assert_ne!(
|
||||
alice_bob_session.sas_code(),
|
||||
eve_bob_session.sas_code(),
|
||||
"MITM session should produce different SAS"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -79,7 +79,9 @@ impl Seed {
|
||||
///
|
||||
/// Mirrors: `warzone-protocol::mnemonic::mnemonic_to_seed`
|
||||
pub fn from_mnemonic(words: &str) -> Result<Self, String> {
|
||||
let mnemonic: bip39::Mnemonic = words.parse().map_err(|e| format!("invalid mnemonic: {e}"))?;
|
||||
let mnemonic: bip39::Mnemonic = words
|
||||
.parse()
|
||||
.map_err(|e| format!("invalid mnemonic: {e}"))?;
|
||||
let entropy = mnemonic.to_entropy();
|
||||
if entropy.len() != 32 {
|
||||
return Err(format!("expected 32 bytes entropy, got {}", entropy.len()));
|
||||
|
||||
@@ -16,8 +16,8 @@ pub mod session;
|
||||
|
||||
pub use anti_replay::AntiReplayWindow;
|
||||
pub use handshake::WarzoneKeyExchange;
|
||||
pub use identity::{hash_room_name, Fingerprint, IdentityKeyPair, PublicIdentity, Seed};
|
||||
pub use nonce::{build_nonce, Direction};
|
||||
pub use identity::{Fingerprint, IdentityKeyPair, PublicIdentity, Seed, hash_room_name};
|
||||
pub use nonce::{Direction, build_nonce};
|
||||
pub use rekey::RekeyManager;
|
||||
pub use session::ChaChaSession;
|
||||
|
||||
|
||||
@@ -36,6 +36,10 @@ impl RekeyManager {
|
||||
///
|
||||
/// The old key is zeroized after the new key is derived.
|
||||
/// Returns the new 32-byte symmetric key.
|
||||
///
|
||||
/// NOTE: Rekeying changes **only** the symmetric key material. Sequence
|
||||
/// numbers and timestamps in the media framing layer (e.g. `MediaHeader`)
|
||||
/// are untouched — they continue monotonically across the rekey boundary.
|
||||
pub fn perform_rekey(
|
||||
&mut self,
|
||||
new_peer_pub: &[u8; 32],
|
||||
|
||||
@@ -3,12 +3,15 @@
|
||||
//! Implements the `CryptoSession` trait for per-call media encryption.
|
||||
//! Nonces are derived deterministically from session_id + sequence counter + direction.
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use chacha20poly1305::aead::Aead;
|
||||
use chacha20poly1305::{ChaCha20Poly1305, KeyInit, Nonce};
|
||||
use x25519_dalek::{PublicKey, StaticSecret};
|
||||
use rand::rngs::OsRng;
|
||||
use wzp_proto::{CryptoError, CryptoSession};
|
||||
use wzp_proto::{CryptoError, CryptoSession, MediaHeader, MediaType};
|
||||
use x25519_dalek::{PublicKey, StaticSecret};
|
||||
|
||||
use crate::anti_replay::AntiReplayWindow;
|
||||
use crate::nonce::{self, Direction};
|
||||
use crate::rekey::RekeyManager;
|
||||
|
||||
@@ -26,6 +29,12 @@ pub struct ChaChaSession {
|
||||
rekey_mgr: RekeyManager,
|
||||
/// Pending ephemeral secret for rekey (stored until peer responds).
|
||||
pending_rekey_secret: Option<StaticSecret>,
|
||||
/// Short Authentication String (4-digit code for verbal verification).
|
||||
sas_code: Option<u32>,
|
||||
/// Per-stream anti-replay windows, keyed by (stream_id, media_type).
|
||||
anti_replay: HashMap<(u8, MediaType), AntiReplayWindow>,
|
||||
/// Last timestamp seen in encrypt() — used to assert monotonicity across rekeys.
|
||||
last_encrypt_timestamp: Option<u32>,
|
||||
}
|
||||
|
||||
impl ChaChaSession {
|
||||
@@ -46,9 +55,17 @@ impl ChaChaSession {
|
||||
recv_seq: 0,
|
||||
rekey_mgr: RekeyManager::new(shared_secret),
|
||||
pending_rekey_secret: None,
|
||||
sas_code: None,
|
||||
anti_replay: HashMap::new(),
|
||||
last_encrypt_timestamp: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the SAS code (called by key exchange after derivation).
|
||||
pub fn set_sas(&mut self, code: u32) {
|
||||
self.sas_code = Some(code);
|
||||
}
|
||||
|
||||
/// Install a new key (after rekeying).
|
||||
fn install_key(&mut self, new_key: [u8; 32]) {
|
||||
use sha2::Digest;
|
||||
@@ -59,6 +76,27 @@ impl ChaChaSession {
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a v2 `MediaHeader` from raw bytes.
|
||||
/// Returns `None` if the buffer is too short or not a valid v2 header.
|
||||
fn parse_header(header_bytes: &[u8]) -> Option<MediaHeader> {
|
||||
if header_bytes.len() < MediaHeader::WIRE_SIZE {
|
||||
return None;
|
||||
}
|
||||
let mut cursor = std::io::Cursor::new(header_bytes);
|
||||
MediaHeader::read_from(&mut cursor)
|
||||
}
|
||||
|
||||
/// Return the default anti-replay window size for a given media type.
|
||||
fn default_window_for_media_type(media_type: MediaType) -> AntiReplayWindow {
|
||||
let size = match media_type {
|
||||
MediaType::Audio => 64,
|
||||
MediaType::Video => 1024,
|
||||
MediaType::Data => 256,
|
||||
MediaType::Control => 32,
|
||||
};
|
||||
AntiReplayWindow::with_window(size)
|
||||
}
|
||||
|
||||
impl CryptoSession for ChaChaSession {
|
||||
fn encrypt(
|
||||
&mut self,
|
||||
@@ -66,10 +104,14 @@ impl CryptoSession for ChaChaSession {
|
||||
plaintext: &[u8],
|
||||
out: &mut Vec<u8>,
|
||||
) -> Result<(), CryptoError> {
|
||||
let nonce_bytes = nonce::build_nonce(&self.session_id, self.send_seq, Direction::Send);
|
||||
// Derive nonce from the wire-level seq in the header, not from an
|
||||
// internal counter. This ensures the receiver can reconstruct the
|
||||
// same nonce using the header it receives, regardless of delivery order.
|
||||
let header = parse_header(header_bytes)
|
||||
.ok_or_else(|| CryptoError::Internal("header too short to derive nonce".into()))?;
|
||||
let nonce_bytes = nonce::build_nonce(&self.session_id, header.seq, Direction::Send);
|
||||
let nonce = Nonce::from_slice(&nonce_bytes);
|
||||
|
||||
// Encrypt with AAD
|
||||
use chacha20poly1305::aead::Payload;
|
||||
let payload = Payload {
|
||||
msg: plaintext,
|
||||
@@ -82,7 +124,19 @@ impl CryptoSession for ChaChaSession {
|
||||
.map_err(|_| CryptoError::Internal("encryption failed".into()))?;
|
||||
|
||||
out.extend_from_slice(&ciphertext);
|
||||
self.send_seq = self.send_seq.wrapping_add(1);
|
||||
self.send_seq = self.send_seq.wrapping_add(1); // packet counter for rekey trigger only
|
||||
|
||||
// M5: assert timestamp_ms is non-decreasing across calls (including post-rekey).
|
||||
// Timestamps are u32 and wrap at 2^32 ms (~49 days); allow wrapping.
|
||||
debug_assert!(
|
||||
self.last_encrypt_timestamp
|
||||
.map_or(true, |last| header.timestamp.wrapping_sub(last) < u32::MAX / 2),
|
||||
"encrypt: timestamp must not decrease (last={:?}, now={})",
|
||||
self.last_encrypt_timestamp,
|
||||
header.timestamp,
|
||||
);
|
||||
self.last_encrypt_timestamp = Some(header.timestamp);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -92,9 +146,14 @@ impl CryptoSession for ChaChaSession {
|
||||
ciphertext: &[u8],
|
||||
out: &mut Vec<u8>,
|
||||
) -> Result<(), CryptoError> {
|
||||
// Use Direction::Send to match the sender's nonce construction.
|
||||
// The recv_seq counter tracks which packet from the peer we're decrypting.
|
||||
let nonce_bytes = nonce::build_nonce(&self.session_id, self.recv_seq, Direction::Send);
|
||||
// Parse header before decryption — needed for nonce derivation.
|
||||
// Using header.seq (not recv_seq) means the nonce is always derived
|
||||
// from the same wire field as the sender, surviving out-of-order delivery.
|
||||
// A recv_seq counter diverges from the sender's send_seq on any reorder,
|
||||
// causing every subsequent decryption to fail for the rest of the session.
|
||||
let header = parse_header(header_bytes)
|
||||
.ok_or_else(|| CryptoError::Internal("header too short to derive nonce".into()))?;
|
||||
let nonce_bytes = nonce::build_nonce(&self.session_id, header.seq, Direction::Send);
|
||||
let nonce = Nonce::from_slice(&nonce_bytes);
|
||||
|
||||
use chacha20poly1305::aead::Payload;
|
||||
@@ -108,8 +167,21 @@ impl CryptoSession for ChaChaSession {
|
||||
.decrypt(nonce, payload)
|
||||
.map_err(|_| CryptoError::DecryptionFailed)?;
|
||||
|
||||
let plaintext_len = plaintext.len();
|
||||
out.extend_from_slice(&plaintext);
|
||||
self.recv_seq = self.recv_seq.wrapping_add(1);
|
||||
self.recv_seq = self.recv_seq.wrapping_add(1); // packet counter for rekey trigger only
|
||||
|
||||
// Anti-replay check: header already parsed above.
|
||||
let window = self
|
||||
.anti_replay
|
||||
.entry((header.stream_id, header.media_type))
|
||||
.or_insert_with(|| default_window_for_media_type(header.media_type));
|
||||
if let Err(e) = window.check_and_update(header.seq) {
|
||||
// Roll back the plaintext we just appended.
|
||||
out.truncate(out.len() - plaintext_len);
|
||||
return Err(e);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -127,38 +199,64 @@ impl CryptoSession for ChaChaSession {
|
||||
.ok_or_else(|| CryptoError::RekeyFailed("no pending rekey".into()))?;
|
||||
|
||||
let total_packets = self.send_seq as u64 + self.recv_seq as u64;
|
||||
let new_key = self.rekey_mgr.perform_rekey(peer_ephemeral_pub, secret, total_packets);
|
||||
let new_key = self
|
||||
.rekey_mgr
|
||||
.perform_rekey(peer_ephemeral_pub, secret, total_packets);
|
||||
self.install_key(new_key);
|
||||
|
||||
// Reset sequence counters after rekey for nonce uniqueness
|
||||
// Reset sequence counters after rekey for nonce uniqueness.
|
||||
// last_encrypt_timestamp is intentionally NOT reset — spec requires
|
||||
// timestamp_ms to be monotonic across rekeys.
|
||||
self.send_seq = 0;
|
||||
self.recv_seq = 0;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn sas_code(&self) -> Option<u32> {
|
||||
self.sas_code
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use wzp_proto::{CodecId, MediaType};
|
||||
|
||||
fn make_session_pair() -> (ChaChaSession, ChaChaSession) {
|
||||
let key = [0x42u8; 32];
|
||||
(ChaChaSession::new(key), ChaChaSession::new(key))
|
||||
}
|
||||
|
||||
/// Build a minimal valid v2 MediaHeader serialised to bytes.
|
||||
fn make_header_bytes(seq: u32) -> Vec<u8> {
|
||||
let header = MediaHeader {
|
||||
version: 2,
|
||||
flags: 0,
|
||||
media_type: MediaType::Audio,
|
||||
codec_id: CodecId::Opus24k,
|
||||
stream_id: 0,
|
||||
fec_ratio: 0,
|
||||
seq,
|
||||
timestamp: seq.wrapping_mul(20),
|
||||
fec_block: 0,
|
||||
};
|
||||
let mut bytes = Vec::new();
|
||||
header.write_to(&mut bytes);
|
||||
bytes
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encrypt_decrypt_roundtrip() {
|
||||
let (mut alice, mut bob) = make_session_pair();
|
||||
let header = b"test-header";
|
||||
let header = make_header_bytes(0);
|
||||
let plaintext = b"hello warzone";
|
||||
|
||||
let mut ciphertext = Vec::new();
|
||||
alice.encrypt(header, plaintext, &mut ciphertext).unwrap();
|
||||
alice.encrypt(&header, plaintext, &mut ciphertext).unwrap();
|
||||
|
||||
// Bob decrypts (his recv matches Alice's send)
|
||||
let mut decrypted = Vec::new();
|
||||
bob.decrypt(header, &ciphertext, &mut decrypted).unwrap();
|
||||
bob.decrypt(&header, &ciphertext, &mut decrypted).unwrap();
|
||||
|
||||
assert_eq!(&decrypted, plaintext);
|
||||
}
|
||||
@@ -166,14 +264,18 @@ mod tests {
|
||||
#[test]
|
||||
fn decrypt_wrong_aad_fails() {
|
||||
let (mut alice, mut bob) = make_session_pair();
|
||||
let header = b"correct-header";
|
||||
let correct_header = make_header_bytes(0);
|
||||
// Different seq → different nonce AND different AAD bytes: decryption must fail.
|
||||
let wrong_header = make_header_bytes(1);
|
||||
let plaintext = b"secret data";
|
||||
|
||||
let mut ciphertext = Vec::new();
|
||||
alice.encrypt(header, plaintext, &mut ciphertext).unwrap();
|
||||
alice
|
||||
.encrypt(&correct_header, plaintext, &mut ciphertext)
|
||||
.unwrap();
|
||||
|
||||
let mut decrypted = Vec::new();
|
||||
let result = bob.decrypt(b"wrong-header", &ciphertext, &mut decrypted);
|
||||
let result = bob.decrypt(&wrong_header, &ciphertext, &mut decrypted);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
@@ -182,29 +284,29 @@ mod tests {
|
||||
let mut alice = ChaChaSession::new([0xAA; 32]);
|
||||
let mut eve = ChaChaSession::new([0xBB; 32]);
|
||||
|
||||
let header = b"hdr";
|
||||
let header = make_header_bytes(0);
|
||||
let plaintext = b"secret";
|
||||
|
||||
let mut ciphertext = Vec::new();
|
||||
alice.encrypt(header, plaintext, &mut ciphertext).unwrap();
|
||||
alice.encrypt(&header, plaintext, &mut ciphertext).unwrap();
|
||||
|
||||
let mut decrypted = Vec::new();
|
||||
let result = eve.decrypt(header, &ciphertext, &mut decrypted);
|
||||
let result = eve.decrypt(&header, &ciphertext, &mut decrypted);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn multiple_packets_roundtrip() {
|
||||
let (mut alice, mut bob) = make_session_pair();
|
||||
let header = b"hdr";
|
||||
|
||||
for i in 0..100 {
|
||||
for i in 0..100u32 {
|
||||
let header = make_header_bytes(i);
|
||||
let msg = format!("message {}", i);
|
||||
let mut ct = Vec::new();
|
||||
alice.encrypt(header, msg.as_bytes(), &mut ct).unwrap();
|
||||
alice.encrypt(&header, msg.as_bytes(), &mut ct).unwrap();
|
||||
|
||||
let mut pt = Vec::new();
|
||||
bob.decrypt(header, &ct, &mut pt).unwrap();
|
||||
bob.decrypt(&header, &ct, &mut pt).unwrap();
|
||||
assert_eq!(pt, msg.as_bytes());
|
||||
}
|
||||
}
|
||||
@@ -223,4 +325,140 @@ mod tests {
|
||||
// Session is now rekeyed - counters reset
|
||||
assert_eq!(alice.send_seq, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decrypt_survives_out_of_order_delivery() {
|
||||
// Regression test for nonce derivation using recv_seq instead of
|
||||
// MediaHeader.seq. If nonces are tied to a local counter, any reorder
|
||||
// causes the counter to diverge from the sender's seq and every
|
||||
// subsequent packet fails decryption permanently.
|
||||
use wzp_proto::{CodecId, MediaType};
|
||||
|
||||
let key = [0x55u8; 32];
|
||||
let mut alice = ChaChaSession::new(key);
|
||||
let mut bob = ChaChaSession::new(key);
|
||||
|
||||
let plaintext = b"audio payload";
|
||||
|
||||
// Encrypt 5 packets in order (seqs 10, 11, 12, 13, 14).
|
||||
let seqs = [10u32, 11, 12, 13, 14];
|
||||
let mut ciphertexts: Vec<(Vec<u8>, Vec<u8>)> = Vec::new();
|
||||
for &seq in &seqs {
|
||||
let header = MediaHeader {
|
||||
version: 2,
|
||||
flags: 0,
|
||||
media_type: MediaType::Audio,
|
||||
codec_id: CodecId::Opus24k,
|
||||
stream_id: 0,
|
||||
fec_ratio: 0,
|
||||
seq,
|
||||
timestamp: seq * 20,
|
||||
fec_block: 0,
|
||||
};
|
||||
let mut header_bytes = Vec::new();
|
||||
header.write_to(&mut header_bytes);
|
||||
let mut ct = Vec::new();
|
||||
alice.encrypt(&header_bytes, plaintext, &mut ct).unwrap();
|
||||
ciphertexts.push((header_bytes, ct));
|
||||
}
|
||||
|
||||
// Bob receives them out of order: 0, 2, 1, 4, 3
|
||||
let delivery_order = [0usize, 2, 1, 4, 3];
|
||||
for &idx in &delivery_order {
|
||||
let (ref hdr, ref ct) = ciphertexts[idx];
|
||||
let mut pt = Vec::new();
|
||||
let result = bob.decrypt(hdr, ct, &mut pt);
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
"out-of-order packet (original idx={idx}, seq={}) must decrypt successfully",
|
||||
seqs[idx]
|
||||
);
|
||||
assert_eq!(&pt, plaintext);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn per_stream_anti_replay_rejects_duplicate() {
|
||||
use wzp_proto::{CodecId, MediaType};
|
||||
|
||||
let (mut alice, mut bob) = make_session_pair();
|
||||
let header = MediaHeader {
|
||||
version: 2,
|
||||
flags: 0,
|
||||
media_type: MediaType::Audio,
|
||||
codec_id: CodecId::Opus24k,
|
||||
stream_id: 0,
|
||||
fec_ratio: 10,
|
||||
seq: 42,
|
||||
timestamp: 1000,
|
||||
fec_block: 0,
|
||||
};
|
||||
let mut header_bytes = Vec::new();
|
||||
header.write_to(&mut header_bytes);
|
||||
|
||||
let plaintext = b"audio frame";
|
||||
|
||||
// First packet decrypts successfully
|
||||
let mut ct = Vec::new();
|
||||
alice.encrypt(&header_bytes, plaintext, &mut ct).unwrap();
|
||||
let mut pt = Vec::new();
|
||||
bob.decrypt(&header_bytes, &ct, &mut pt).unwrap();
|
||||
assert_eq!(&pt, plaintext);
|
||||
|
||||
// Exact duplicate is rejected by anti-replay
|
||||
let mut pt2 = Vec::new();
|
||||
let result = bob.decrypt(&header_bytes, &ct, &mut pt2);
|
||||
assert!(
|
||||
result.is_err(),
|
||||
"duplicate packet with same seq must be rejected"
|
||||
);
|
||||
assert!(pt2.is_empty(), "plaintext must be rolled back on replay");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn per_stream_anti_replay_video_burst_200_with_reorder() {
|
||||
use wzp_proto::{CodecId, MediaType};
|
||||
|
||||
let (mut alice, mut bob) = make_session_pair();
|
||||
let header = MediaHeader {
|
||||
version: 2,
|
||||
flags: 0,
|
||||
media_type: MediaType::Video,
|
||||
codec_id: CodecId::Opus24k,
|
||||
stream_id: 1,
|
||||
fec_ratio: 10,
|
||||
seq: 0,
|
||||
timestamp: 0,
|
||||
fec_block: 0,
|
||||
};
|
||||
|
||||
let plaintext = b"video frame";
|
||||
|
||||
// Send 200 packets in order
|
||||
for i in 0..200 {
|
||||
let mut h = header;
|
||||
h.seq = i;
|
||||
let mut header_bytes = Vec::new();
|
||||
h.write_to(&mut header_bytes);
|
||||
|
||||
let mut ct = Vec::new();
|
||||
alice.encrypt(&header_bytes, plaintext, &mut ct).unwrap();
|
||||
|
||||
let mut pt = Vec::new();
|
||||
bob.decrypt(&header_bytes, &ct, &mut pt).unwrap();
|
||||
}
|
||||
|
||||
// Re-send packet 50 — should be rejected as replay
|
||||
let mut h = header;
|
||||
h.seq = 50;
|
||||
let mut header_bytes = Vec::new();
|
||||
h.write_to(&mut header_bytes);
|
||||
|
||||
let mut ct = Vec::new();
|
||||
alice.encrypt(&header_bytes, plaintext, &mut ct).unwrap();
|
||||
|
||||
let mut pt = Vec::new();
|
||||
let result = bob.decrypt(&header_bytes, &ct, &mut pt);
|
||||
assert!(result.is_err(), "reordered duplicate must be rejected");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
//! 3. Auth: WZP auth module request/response matches FC's /v1/auth/validate contract
|
||||
//! 4. Mnemonic: BIP39 interop between both implementations
|
||||
|
||||
use wzp_proto::KeyExchange;
|
||||
use wzp_proto::{KeyExchange, default_signal_version};
|
||||
|
||||
// ─── Identity Compatibility (WZP-FC-8) ──────────────────────────────────────
|
||||
|
||||
@@ -52,7 +52,10 @@ fn wzp_identity_module_matches_featherchat() {
|
||||
assert_eq!(wzp_pub.signing.as_bytes(), fc_pub.signing.as_bytes());
|
||||
assert_eq!(wzp_pub.encryption.as_bytes(), fc_pub.encryption.as_bytes());
|
||||
assert_eq!(wzp_pub.fingerprint.0, fc_pub.fingerprint.0);
|
||||
assert_eq!(wzp_pub.fingerprint.to_string(), fc_pub.fingerprint.to_string());
|
||||
assert_eq!(
|
||||
wzp_pub.fingerprint.to_string(),
|
||||
fc_pub.fingerprint.to_string()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -111,10 +114,14 @@ fn mnemonic_strings_identical() {
|
||||
fn wzp_signal_serializes_into_fc_callsignal_payload() {
|
||||
// WZP creates a CallOffer SignalMessage
|
||||
let offer = wzp_proto::SignalMessage::CallOffer {
|
||||
version: default_signal_version(),
|
||||
identity_pub: [1u8; 32],
|
||||
ephemeral_pub: [2u8; 32],
|
||||
signature: vec![3u8; 64],
|
||||
supported_profiles: vec![wzp_proto::QualityProfile::GOOD],
|
||||
alias: None,
|
||||
protocol_version: 2,
|
||||
supported_versions: vec![2],
|
||||
};
|
||||
|
||||
// Encode as featherChat CallSignal payload
|
||||
@@ -147,16 +154,25 @@ fn wzp_signal_serializes_into_fc_callsignal_payload() {
|
||||
// And deserializes back
|
||||
let decoded: warzone_protocol::message::WireMessage = bincode::deserialize(&encoded).unwrap();
|
||||
if let warzone_protocol::message::WireMessage::CallSignal {
|
||||
id, payload: p, signal_type, ..
|
||||
id,
|
||||
payload: p,
|
||||
signal_type,
|
||||
..
|
||||
} = decoded
|
||||
{
|
||||
assert_eq!(id, "call-123");
|
||||
assert!(matches!(signal_type, warzone_protocol::message::CallSignalType::Offer));
|
||||
assert!(matches!(
|
||||
signal_type,
|
||||
warzone_protocol::message::CallSignalType::Offer
|
||||
));
|
||||
|
||||
// Decode the WZP payload back
|
||||
let wzp_payload = wzp_client::featherchat::decode_call_payload(&p).unwrap();
|
||||
assert_eq!(wzp_payload.relay_addr.unwrap(), "relay.example.com:4433");
|
||||
assert!(matches!(wzp_payload.signal, wzp_proto::SignalMessage::CallOffer { .. }));
|
||||
assert!(matches!(
|
||||
wzp_payload.signal,
|
||||
wzp_proto::SignalMessage::CallOffer { .. }
|
||||
));
|
||||
} else {
|
||||
panic!("expected CallSignal");
|
||||
}
|
||||
@@ -165,6 +181,7 @@ fn wzp_signal_serializes_into_fc_callsignal_payload() {
|
||||
#[test]
|
||||
fn wzp_answer_round_trips_through_fc_callsignal() {
|
||||
let answer = wzp_proto::SignalMessage::CallAnswer {
|
||||
version: default_signal_version(),
|
||||
identity_pub: [10u8; 32],
|
||||
ephemeral_pub: [20u8; 32],
|
||||
signature: vec![30u8; 64],
|
||||
@@ -197,12 +214,17 @@ fn wzp_answer_round_trips_through_fc_callsignal() {
|
||||
#[test]
|
||||
fn wzp_hangup_round_trips_through_fc_callsignal() {
|
||||
let hangup = wzp_proto::SignalMessage::Hangup {
|
||||
version: default_signal_version(),
|
||||
reason: wzp_proto::HangupReason::Normal,
|
||||
call_id: None,
|
||||
};
|
||||
|
||||
let payload = wzp_client::featherchat::encode_call_payload(&hangup, None, None);
|
||||
let signal_type = wzp_client::featherchat::signal_to_call_type(&hangup);
|
||||
assert!(matches!(signal_type, wzp_client::featherchat::CallSignalType::Hangup));
|
||||
assert!(matches!(
|
||||
signal_type,
|
||||
wzp_client::featherchat::CallSignalType::Hangup
|
||||
));
|
||||
|
||||
let fc_msg = warzone_protocol::message::WireMessage::CallSignal {
|
||||
id: "call-789".to_string(),
|
||||
@@ -217,7 +239,10 @@ fn wzp_hangup_round_trips_through_fc_callsignal() {
|
||||
|
||||
if let warzone_protocol::message::WireMessage::CallSignal { payload, .. } = decoded {
|
||||
let wzp = wzp_client::featherchat::decode_call_payload(&payload).unwrap();
|
||||
assert!(matches!(wzp.signal, wzp_proto::SignalMessage::Hangup { .. }));
|
||||
assert!(matches!(
|
||||
wzp.signal,
|
||||
wzp_proto::SignalMessage::Hangup { .. }
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -250,8 +275,7 @@ fn auth_validate_response_matches_wzp_expectations() {
|
||||
"eth_address": null
|
||||
});
|
||||
|
||||
let wzp_resp: wzp_relay::auth::ValidateResponse =
|
||||
serde_json::from_value(fc_response).unwrap();
|
||||
let wzp_resp: wzp_relay::auth::ValidateResponse = serde_json::from_value(fc_response).unwrap();
|
||||
assert!(wzp_resp.valid);
|
||||
assert_eq!(
|
||||
wzp_resp.fingerprint.unwrap(),
|
||||
@@ -263,8 +287,7 @@ fn auth_validate_response_matches_wzp_expectations() {
|
||||
#[test]
|
||||
fn auth_invalid_response_matches() {
|
||||
let fc_response = serde_json::json!({ "valid": false });
|
||||
let wzp_resp: wzp_relay::auth::ValidateResponse =
|
||||
serde_json::from_value(fc_response).unwrap();
|
||||
let wzp_resp: wzp_relay::auth::ValidateResponse = serde_json::from_value(fc_response).unwrap();
|
||||
assert!(!wzp_resp.valid);
|
||||
assert!(wzp_resp.fingerprint.is_none());
|
||||
}
|
||||
@@ -273,19 +296,27 @@ fn auth_invalid_response_matches() {
|
||||
|
||||
#[test]
|
||||
fn all_signal_types_map_correctly() {
|
||||
use wzp_client::featherchat::{signal_to_call_type, CallSignalType};
|
||||
use wzp_client::featherchat::signal_to_call_type;
|
||||
|
||||
let cases: Vec<(wzp_proto::SignalMessage, &str)> = vec![
|
||||
(
|
||||
wzp_proto::SignalMessage::CallOffer {
|
||||
identity_pub: [0; 32], ephemeral_pub: [0; 32],
|
||||
signature: vec![], supported_profiles: vec![],
|
||||
version: default_signal_version(),
|
||||
identity_pub: [0; 32],
|
||||
ephemeral_pub: [0; 32],
|
||||
signature: vec![],
|
||||
supported_profiles: vec![],
|
||||
alias: None,
|
||||
protocol_version: 2,
|
||||
supported_versions: vec![2],
|
||||
},
|
||||
"Offer",
|
||||
),
|
||||
(
|
||||
wzp_proto::SignalMessage::CallAnswer {
|
||||
identity_pub: [0; 32], ephemeral_pub: [0; 32],
|
||||
version: default_signal_version(),
|
||||
identity_pub: [0; 32],
|
||||
ephemeral_pub: [0; 32],
|
||||
signature: vec![],
|
||||
chosen_profile: wzp_proto::QualityProfile::GOOD,
|
||||
},
|
||||
@@ -293,13 +324,16 @@ fn all_signal_types_map_correctly() {
|
||||
),
|
||||
(
|
||||
wzp_proto::SignalMessage::IceCandidate {
|
||||
version: default_signal_version(),
|
||||
candidate: "candidate:1".to_string(),
|
||||
},
|
||||
"IceCandidate",
|
||||
),
|
||||
(
|
||||
wzp_proto::SignalMessage::Hangup {
|
||||
version: default_signal_version(),
|
||||
reason: wzp_proto::HangupReason::Normal,
|
||||
call_id: None,
|
||||
},
|
||||
"Hangup",
|
||||
),
|
||||
@@ -308,7 +342,10 @@ fn all_signal_types_map_correctly() {
|
||||
for (signal, expected_name) in cases {
|
||||
let ct = signal_to_call_type(&signal);
|
||||
let name = format!("{ct:?}");
|
||||
assert_eq!(name, expected_name, "signal type mapping for {expected_name}");
|
||||
assert_eq!(
|
||||
name, expected_name,
|
||||
"signal type mapping for {expected_name}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -422,8 +459,7 @@ fn auth_response_with_eth_address() {
|
||||
"alias": "vitalik",
|
||||
"eth_address": "0x1234567890abcdef1234567890abcdef12345678"
|
||||
});
|
||||
let resp: wzp_relay::auth::ValidateResponse =
|
||||
serde_json::from_value(with_eth).unwrap();
|
||||
let resp: wzp_relay::auth::ValidateResponse = serde_json::from_value(with_eth).unwrap();
|
||||
assert!(resp.valid);
|
||||
assert_eq!(
|
||||
resp.fingerprint.unwrap(),
|
||||
@@ -438,8 +474,7 @@ fn auth_response_with_eth_address() {
|
||||
"alias": "anon",
|
||||
"eth_address": null
|
||||
});
|
||||
let resp2: wzp_relay::auth::ValidateResponse =
|
||||
serde_json::from_value(with_null_eth).unwrap();
|
||||
let resp2: wzp_relay::auth::ValidateResponse = serde_json::from_value(with_null_eth).unwrap();
|
||||
assert!(resp2.valid);
|
||||
assert_eq!(
|
||||
resp2.fingerprint.unwrap(),
|
||||
@@ -450,15 +485,15 @@ fn auth_response_with_eth_address() {
|
||||
let without_eth = serde_json::json!({
|
||||
"valid": false
|
||||
});
|
||||
let resp3: wzp_relay::auth::ValidateResponse =
|
||||
serde_json::from_value(without_eth).unwrap();
|
||||
let resp3: wzp_relay::auth::ValidateResponse = serde_json::from_value(without_eth).unwrap();
|
||||
assert!(!resp3.valid);
|
||||
}
|
||||
|
||||
/// WZP-S-7: SignalMessage::AuthToken { token } exists and round-trips via serde.
|
||||
/// WZP-S-7: SignalMessage::AuthToken { version: default_signal_version(), token } exists and round-trips via serde.
|
||||
#[test]
|
||||
fn wzp_proto_has_auth_token_variant() {
|
||||
let msg = wzp_proto::SignalMessage::AuthToken {
|
||||
version: default_signal_version(),
|
||||
token: "fc-bearer-token-xyz".to_string(),
|
||||
};
|
||||
|
||||
@@ -469,7 +504,7 @@ fn wzp_proto_has_auth_token_variant() {
|
||||
|
||||
// Deserialize back
|
||||
let decoded: wzp_proto::SignalMessage = serde_json::from_str(&json).unwrap();
|
||||
if let wzp_proto::SignalMessage::AuthToken { token } = decoded {
|
||||
if let wzp_proto::SignalMessage::AuthToken { token, .. } = decoded {
|
||||
assert_eq!(token, "fc-bearer-token-xyz");
|
||||
} else {
|
||||
panic!("expected AuthToken variant, got: {decoded:?}");
|
||||
@@ -492,7 +527,11 @@ fn all_fc_call_signal_types_representable() {
|
||||
(CallSignalType::Busy, "Busy"),
|
||||
];
|
||||
|
||||
assert_eq!(variants.len(), 7, "featherChat defines exactly 7 call signal types");
|
||||
assert_eq!(
|
||||
variants.len(),
|
||||
7,
|
||||
"featherChat defines exactly 7 call signal types"
|
||||
);
|
||||
|
||||
for (variant, expected_name) in &variants {
|
||||
let name = format!("{variant:?}");
|
||||
@@ -546,10 +585,7 @@ fn hash_room_name_used_as_sni_is_valid() {
|
||||
#[test]
|
||||
fn wzp_proto_cargo_toml_is_standalone() {
|
||||
// Try both paths (run from workspace root or from crate directory)
|
||||
let candidates = [
|
||||
"crates/wzp-proto/Cargo.toml",
|
||||
"../wzp-proto/Cargo.toml",
|
||||
];
|
||||
let candidates = ["crates/wzp-proto/Cargo.toml", "../wzp-proto/Cargo.toml"];
|
||||
|
||||
let contents = candidates
|
||||
.iter()
|
||||
|
||||
@@ -13,11 +13,17 @@ pub struct AdaptiveFec {
|
||||
pub repair_ratio: f32,
|
||||
/// Symbol size in bytes.
|
||||
pub symbol_size: u16,
|
||||
/// Repair ratio to use when the block contains a keyframe.
|
||||
/// Default 0.5 (50% overhead) — keyframes are critical and worth
|
||||
/// the extra bandwidth.
|
||||
pub keyframe_repair_ratio: f32,
|
||||
}
|
||||
|
||||
impl AdaptiveFec {
|
||||
/// Default symbol size for adaptive configuration.
|
||||
const DEFAULT_SYMBOL_SIZE: u16 = 256;
|
||||
/// Default keyframe repair ratio (PRD-video-v1 T4.5).
|
||||
const DEFAULT_KEYFRAME_REPAIR_RATIO: f32 = 0.5;
|
||||
|
||||
/// Create an adaptive FEC configuration from a quality profile.
|
||||
///
|
||||
@@ -30,12 +36,15 @@ impl AdaptiveFec {
|
||||
frames_per_block: profile.frames_per_block as usize,
|
||||
repair_ratio: profile.fec_ratio,
|
||||
symbol_size: Self::DEFAULT_SYMBOL_SIZE,
|
||||
keyframe_repair_ratio: Self::DEFAULT_KEYFRAME_REPAIR_RATIO,
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a configured FEC encoder from this adaptive configuration.
|
||||
pub fn build_encoder(&self) -> RaptorQFecEncoder {
|
||||
RaptorQFecEncoder::new(self.frames_per_block, self.symbol_size)
|
||||
let mut enc = RaptorQFecEncoder::new(self.frames_per_block, self.symbol_size);
|
||||
enc.set_keyframe_ratio(self.keyframe_repair_ratio);
|
||||
enc
|
||||
}
|
||||
|
||||
/// Get the repair ratio for use with `FecEncoder::generate_repair()`.
|
||||
@@ -59,6 +68,7 @@ mod tests {
|
||||
let cfg = AdaptiveFec::from_profile(&QualityProfile::GOOD);
|
||||
assert_eq!(cfg.frames_per_block, 5);
|
||||
assert!((cfg.repair_ratio - 0.2).abs() < f32::EPSILON);
|
||||
assert!((cfg.keyframe_repair_ratio - 0.5).abs() < f32::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -1,14 +1,18 @@
|
||||
//! RaptorQ FEC decoder — reassembles source blocks from received source and repair symbols.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::time::Instant;
|
||||
|
||||
use raptorq::{EncodingPacket, ObjectTransmissionInformation, PayloadId, SourceBlockDecoder};
|
||||
use wzp_proto::error::FecError;
|
||||
use wzp_proto::FecDecoder;
|
||||
use wzp_proto::error::FecError;
|
||||
|
||||
/// Length prefix size (u16 little-endian), must match encoder.
|
||||
const LEN_PREFIX: usize = 2;
|
||||
|
||||
/// Decoded blocks older than this are eligible for reuse by a new sender.
|
||||
const BLOCK_STALE_SECS: u64 = 2;
|
||||
|
||||
/// State for one in-flight block being decoded.
|
||||
struct BlockState {
|
||||
/// Number of source symbols expected.
|
||||
@@ -21,6 +25,8 @@ struct BlockState {
|
||||
decoded: bool,
|
||||
/// Cached decoded result.
|
||||
result: Option<Vec<Vec<u8>>>,
|
||||
/// When this block was last decoded (for staleness check).
|
||||
decoded_at: Option<Instant>,
|
||||
}
|
||||
|
||||
/// RaptorQ-based FEC decoder that handles multiple concurrent blocks.
|
||||
@@ -58,6 +64,7 @@ impl RaptorQFecDecoder {
|
||||
symbol_size: self.symbol_size,
|
||||
decoded: false,
|
||||
result: None,
|
||||
decoded_at: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -66,7 +73,7 @@ impl FecDecoder for RaptorQFecDecoder {
|
||||
fn add_symbol(
|
||||
&mut self,
|
||||
block_id: u8,
|
||||
symbol_index: u8,
|
||||
symbol_index: u16,
|
||||
_is_repair: bool,
|
||||
data: &[u8],
|
||||
) -> Result<(), FecError> {
|
||||
@@ -74,8 +81,20 @@ impl FecDecoder for RaptorQFecDecoder {
|
||||
let block = self.get_or_create_block(block_id);
|
||||
|
||||
if block.decoded {
|
||||
// Already decoded, ignore additional symbols.
|
||||
return Ok(());
|
||||
// If the block was decoded recently, skip (normal duplicate).
|
||||
// If it's stale (>2s), a new sender is reusing this block_id — reset it.
|
||||
if let Some(at) = block.decoded_at {
|
||||
if at.elapsed().as_secs() >= BLOCK_STALE_SECS {
|
||||
block.decoded = false;
|
||||
block.result = None;
|
||||
block.decoded_at = None;
|
||||
block.packets.clear();
|
||||
} else {
|
||||
return Ok(());
|
||||
}
|
||||
} else {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
// Data should already be at symbol_size (length-prefixed and padded by the encoder).
|
||||
@@ -121,10 +140,7 @@ impl FecDecoder for RaptorQFecDecoder {
|
||||
frames.push(Vec::new());
|
||||
continue;
|
||||
}
|
||||
let payload_len = u16::from_le_bytes([
|
||||
data[offset],
|
||||
data[offset + 1],
|
||||
]) as usize;
|
||||
let payload_len = u16::from_le_bytes([data[offset], data[offset + 1]]) as usize;
|
||||
let payload_start = offset + LEN_PREFIX;
|
||||
let payload_end = (payload_start + payload_len).min(data.len());
|
||||
frames.push(data[payload_start..payload_end].to_vec());
|
||||
@@ -132,6 +148,7 @@ impl FecDecoder for RaptorQFecDecoder {
|
||||
|
||||
let block = self.blocks.get_mut(&block_id).unwrap();
|
||||
block.decoded = true;
|
||||
block.decoded_at = Some(Instant::now());
|
||||
block.result = Some(frames.clone());
|
||||
Ok(Some(frames))
|
||||
}
|
||||
@@ -178,9 +195,7 @@ mod tests {
|
||||
|
||||
// Feed all source symbols (using the length-prefixed padded data).
|
||||
for (i, pkt) in source_pkts.iter().enumerate() {
|
||||
decoder
|
||||
.add_symbol(0, i as u8, false, pkt.data())
|
||||
.unwrap();
|
||||
decoder.add_symbol(0, i as u16, false, pkt.data()).unwrap();
|
||||
}
|
||||
|
||||
let result = decoder.try_decode(0).unwrap();
|
||||
@@ -213,7 +228,11 @@ mod tests {
|
||||
let config = ObjectTransmissionInformation::new(block_len, SYMBOL_SIZE, 1, 1, 1);
|
||||
let mut dec = SourceBlockDecoder::new(0, &config, block_len);
|
||||
let decoded = dec.decode(all);
|
||||
assert!(decoded.is_some(), "Should recover with {:.0}% loss", drop_fraction * 100.0);
|
||||
assert!(
|
||||
decoded.is_some(),
|
||||
"Should recover with {:.0}% loss",
|
||||
drop_fraction * 100.0
|
||||
);
|
||||
|
||||
let data = decoded.unwrap();
|
||||
let ss = SYMBOL_SIZE as usize;
|
||||
@@ -225,13 +244,19 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_with_30pct_loss() { run_loss_test(FRAMES_PER_BLOCK, 0.5, 0.3); }
|
||||
fn decode_with_30pct_loss() {
|
||||
run_loss_test(FRAMES_PER_BLOCK, 0.5, 0.3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_with_50pct_loss() { run_loss_test(FRAMES_PER_BLOCK, 1.0, 0.5); }
|
||||
fn decode_with_50pct_loss() {
|
||||
run_loss_test(FRAMES_PER_BLOCK, 1.0, 0.5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_with_70pct_source_loss_heavy_repair() { run_loss_test(8, 2.0, 0.5); }
|
||||
fn decode_with_70pct_source_loss_heavy_repair() {
|
||||
run_loss_test(8, 2.0, 0.5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn expire_removes_old_blocks() {
|
||||
@@ -268,10 +293,10 @@ mod tests {
|
||||
// Interleave symbols from block 0 and block 1
|
||||
for i in 0..FRAMES_PER_BLOCK {
|
||||
decoder
|
||||
.add_symbol(0, i as u8, false, pkts_a[i].data())
|
||||
.add_symbol(0, i as u16, false, pkts_a[i].data())
|
||||
.unwrap();
|
||||
decoder
|
||||
.add_symbol(1, i as u8, false, pkts_b[i].data())
|
||||
.add_symbol(1, i as u16, false, pkts_b[i].data())
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
//! RaptorQ FEC encoder — accumulates source symbols into blocks and generates repair symbols.
|
||||
|
||||
use raptorq::{EncodingPacket, ObjectTransmissionInformation, PayloadId, SourceBlockEncoder};
|
||||
use wzp_proto::error::FecError;
|
||||
use wzp_proto::FecEncoder;
|
||||
use wzp_proto::error::FecError;
|
||||
|
||||
/// Maximum symbol size in bytes. Audio frames are typically < 200 bytes,
|
||||
/// but we pad to a uniform size within a block.
|
||||
@@ -23,6 +23,11 @@ pub struct RaptorQFecEncoder {
|
||||
source_symbols: Vec<Vec<u8>>,
|
||||
/// Symbol size used for encoding (all symbols padded to this size).
|
||||
symbol_size: u16,
|
||||
/// True if at least one source symbol in the current block is a keyframe.
|
||||
has_keyframe: bool,
|
||||
/// Repair ratio to use when the block contains a keyframe.
|
||||
/// If zero, the nominal ratio passed to [`generate_repair`] is used.
|
||||
keyframe_ratio: f32,
|
||||
}
|
||||
|
||||
impl RaptorQFecEncoder {
|
||||
@@ -36,9 +41,26 @@ impl RaptorQFecEncoder {
|
||||
frames_per_block,
|
||||
source_symbols: Vec::with_capacity(frames_per_block),
|
||||
symbol_size,
|
||||
has_keyframe: false,
|
||||
keyframe_ratio: 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the repair ratio to use for blocks that contain at least one
|
||||
/// keyframe source symbol.
|
||||
///
|
||||
/// When `keyframe_ratio > 0.0` and [`has_keyframe`](Self::has_keyframe)
|
||||
/// is true, [`generate_repair`](FecEncoder::generate_repair) uses this
|
||||
/// ratio instead of the nominal ratio passed by the caller.
|
||||
pub fn set_keyframe_ratio(&mut self, ratio: f32) {
|
||||
self.keyframe_ratio = ratio.max(0.0);
|
||||
}
|
||||
|
||||
/// Returns true if the current block contains a keyframe source symbol.
|
||||
pub fn has_keyframe(&self) -> bool {
|
||||
self.has_keyframe
|
||||
}
|
||||
|
||||
/// Create with default symbol size (256 bytes).
|
||||
pub fn with_defaults(frames_per_block: usize) -> Self {
|
||||
Self::new(frames_per_block, DEFAULT_MAX_SYMBOL_SIZE)
|
||||
@@ -54,8 +76,7 @@ impl RaptorQFecEncoder {
|
||||
let payload_len = sym.len().min(max_payload);
|
||||
let offset = i * ss;
|
||||
// Write 2-byte little-endian length prefix.
|
||||
data[offset..offset + LEN_PREFIX]
|
||||
.copy_from_slice(&(payload_len as u16).to_le_bytes());
|
||||
data[offset..offset + LEN_PREFIX].copy_from_slice(&(payload_len as u16).to_le_bytes());
|
||||
// Write payload after prefix.
|
||||
data[offset + LEN_PREFIX..offset + LEN_PREFIX + payload_len]
|
||||
.copy_from_slice(&sym[..payload_len]);
|
||||
@@ -75,17 +96,36 @@ impl FecEncoder for RaptorQFecEncoder {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn generate_repair(&mut self, ratio: f32) -> Result<Vec<(u8, Vec<u8>)>, FecError> {
|
||||
fn add_source_symbol_with_keyframe(
|
||||
&mut self,
|
||||
data: &[u8],
|
||||
is_keyframe: bool,
|
||||
) -> Result<(), FecError> {
|
||||
self.add_source_symbol(data)?;
|
||||
if is_keyframe {
|
||||
self.has_keyframe = true;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn generate_repair(&mut self, ratio: f32) -> Result<Vec<(u16, Vec<u8>)>, FecError> {
|
||||
if self.source_symbols.is_empty() {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
|
||||
let effective_ratio = if self.has_keyframe && self.keyframe_ratio > 0.0 {
|
||||
self.keyframe_ratio
|
||||
} else {
|
||||
ratio
|
||||
};
|
||||
|
||||
let block_data = self.build_block_data();
|
||||
let config = ObjectTransmissionInformation::with_defaults(block_data.len() as u64, self.symbol_size);
|
||||
let config =
|
||||
ObjectTransmissionInformation::with_defaults(block_data.len() as u64, self.symbol_size);
|
||||
let encoder = SourceBlockEncoder::new(self.block_id, &config, &block_data);
|
||||
|
||||
let num_source = self.source_symbols.len() as u32;
|
||||
let num_repair = ((num_source as f32) * ratio).ceil() as u32;
|
||||
let num_repair = ((num_source as f32) * effective_ratio).ceil() as u32;
|
||||
if num_repair == 0 {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
@@ -93,11 +133,11 @@ impl FecEncoder for RaptorQFecEncoder {
|
||||
// Generate repair packets starting from offset 0 (ESIs begin at num_source).
|
||||
let repair_packets: Vec<EncodingPacket> = encoder.repair_packets(0, num_repair);
|
||||
|
||||
let result: Vec<(u8, Vec<u8>)> = repair_packets
|
||||
let result: Vec<(u16, Vec<u8>)> = repair_packets
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(i, pkt): (usize, EncodingPacket)| {
|
||||
let idx = (num_source as u8).wrapping_add(i as u8);
|
||||
let idx = (num_source as u16).wrapping_add(i as u16);
|
||||
(idx, pkt.data().to_vec())
|
||||
})
|
||||
.collect();
|
||||
@@ -109,6 +149,7 @@ impl FecEncoder for RaptorQFecEncoder {
|
||||
let completed = self.block_id;
|
||||
self.block_id = self.block_id.wrapping_add(1);
|
||||
self.source_symbols.clear();
|
||||
self.has_keyframe = false;
|
||||
Ok(completed)
|
||||
}
|
||||
|
||||
@@ -130,8 +171,7 @@ fn build_prefixed_block_data(symbols: &[Vec<u8>], symbol_size: u16) -> Vec<u8> {
|
||||
let max_payload = ss - LEN_PREFIX;
|
||||
let payload_len = sym.len().min(max_payload);
|
||||
let offset = i * ss;
|
||||
data[offset..offset + LEN_PREFIX]
|
||||
.copy_from_slice(&(payload_len as u16).to_le_bytes());
|
||||
data[offset..offset + LEN_PREFIX].copy_from_slice(&(payload_len as u16).to_le_bytes());
|
||||
data[offset + LEN_PREFIX..offset + LEN_PREFIX + payload_len]
|
||||
.copy_from_slice(&sym[..payload_len]);
|
||||
}
|
||||
@@ -211,4 +251,54 @@ mod tests {
|
||||
// After 256 blocks, wraps back to 0
|
||||
assert_eq!(enc.current_block_id(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn keyframe_boost_uses_higher_ratio() {
|
||||
// Non-keyframe block with nominal ratio 0.2 → ceil(5 * 0.2) = 1 repair.
|
||||
let mut enc_normal = RaptorQFecEncoder::with_defaults(5);
|
||||
enc_normal.set_keyframe_ratio(0.8);
|
||||
for i in 0..5 {
|
||||
enc_normal
|
||||
.add_source_symbol_with_keyframe(&[i as u8; 100], false)
|
||||
.unwrap();
|
||||
}
|
||||
let normal_repair = enc_normal.generate_repair(0.2).unwrap();
|
||||
assert_eq!(normal_repair.len(), 1);
|
||||
|
||||
// Keyframe block with same nominal ratio but boost to 0.8 → ceil(5 * 0.8) = 4 repairs.
|
||||
let mut enc_key = RaptorQFecEncoder::with_defaults(5);
|
||||
enc_key.set_keyframe_ratio(0.8);
|
||||
for i in 0..5 {
|
||||
enc_key
|
||||
.add_source_symbol_with_keyframe(&[i as u8; 100], i == 2)
|
||||
.unwrap();
|
||||
}
|
||||
let keyframe_repair = enc_key.generate_repair(0.2).unwrap();
|
||||
assert_eq!(keyframe_repair.len(), 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_keyframe_block_uses_nominal_ratio() {
|
||||
let mut enc = RaptorQFecEncoder::with_defaults(5);
|
||||
enc.set_keyframe_ratio(0.8);
|
||||
|
||||
for i in 0..5 {
|
||||
enc.add_source_symbol_with_keyframe(&[i as u8; 100], false)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let repair = enc.generate_repair(0.2).unwrap();
|
||||
assert_eq!(repair.len(), 1); // ceil(5 * 0.2) = 1
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn finalize_clears_keyframe_flag() {
|
||||
let mut enc = RaptorQFecEncoder::with_defaults(2);
|
||||
enc.add_source_symbol_with_keyframe(&[0u8; 10], true)
|
||||
.unwrap();
|
||||
assert!(enc.has_keyframe());
|
||||
|
||||
enc.finalize_block().unwrap();
|
||||
assert!(!enc.has_keyframe());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -146,7 +146,10 @@ mod tests {
|
||||
|
||||
// Each block should lose exactly 2 (6 losses / 3 blocks)
|
||||
for &loss in &losses_per_block {
|
||||
assert_eq!(loss, 2, "Each block should lose at most 2 symbols from a burst of 6");
|
||||
assert_eq!(
|
||||
loss, 2,
|
||||
"Each block should lose at most 2 symbols from a burst of 6"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,7 +16,9 @@ pub mod encoder;
|
||||
pub mod interleave;
|
||||
|
||||
pub use adaptive::AdaptiveFec;
|
||||
pub use block_manager::{DecoderBlockManager, DecoderBlockState, EncoderBlockManager, EncoderBlockState};
|
||||
pub use block_manager::{
|
||||
DecoderBlockManager, DecoderBlockState, EncoderBlockManager, EncoderBlockState,
|
||||
};
|
||||
pub use decoder::RaptorQFecDecoder;
|
||||
pub use encoder::RaptorQFecEncoder;
|
||||
pub use interleave::Interleaver;
|
||||
@@ -24,9 +26,7 @@ pub use interleave::Interleaver;
|
||||
pub use wzp_proto::{FecDecoder, FecEncoder, QualityProfile};
|
||||
|
||||
/// Create an encoder/decoder pair configured for the given quality profile.
|
||||
pub fn create_fec_pair(
|
||||
profile: &QualityProfile,
|
||||
) -> (RaptorQFecEncoder, RaptorQFecDecoder) {
|
||||
pub fn create_fec_pair(profile: &QualityProfile) -> (RaptorQFecEncoder, RaptorQFecDecoder) {
|
||||
let cfg = AdaptiveFec::from_profile(profile);
|
||||
let encoder = cfg.build_encoder();
|
||||
let decoder = RaptorQFecDecoder::new(cfg.frames_per_block, cfg.symbol_size);
|
||||
|
||||
29
crates/wzp-native/Cargo.toml
Normal file
29
crates/wzp-native/Cargo.toml
Normal file
@@ -0,0 +1,29 @@
|
||||
[package]
|
||||
name = "wzp-native"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
description = "WarzonePhone native audio library — standalone Android cdylib that eventually owns all C++ (Oboe bridge) and exposes a pure-C FFI. Built with cargo-ndk, loaded at runtime by the Tauri desktop cdylib via libloading."
|
||||
|
||||
# Crate-type is DELIBERATELY only cdylib (no rlib, no staticlib). This crate
|
||||
# is built with `cargo ndk -t arm64-v8a build --release -p wzp-native` as a
|
||||
# standalone .so, which is the same path the legacy wzp-android crate uses
|
||||
# successfully on the same phone / same NDK. Keeping the crate-type single
|
||||
# avoids the rust-lang/rust#104707 symbol leak that bit us when Tauri's
|
||||
# desktop crate had ["staticlib", "cdylib", "rlib"] and any C++ static
|
||||
# archive pulled bionic's internal pthread_create into the final .so.
|
||||
[lib]
|
||||
name = "wzp_native"
|
||||
crate-type = ["cdylib"]
|
||||
|
||||
[build-dependencies]
|
||||
# cc is SAFE to use here because this crate is a single-cdylib: no
|
||||
# staticlib in crate-type → no rust-lang/rust#104707 symbol leak. The
|
||||
# legacy wzp-android crate uses the same setup and works.
|
||||
cc = "1"
|
||||
|
||||
[dependencies]
|
||||
# Phase 2: Oboe C++ audio bridge. Still no Rust deps — we do the whole
|
||||
# audio pipeline via extern "C" into the bundled C++ and expose our own
|
||||
# narrow extern "C" API for wzp-desktop to dlopen via libloading.
|
||||
# Phase 3 can add wzp-proto/wzp-codec if we want to share codec logic
|
||||
# instead of calling back into wzp-desktop via callbacks.
|
||||
134
crates/wzp-native/build.rs
Normal file
134
crates/wzp-native/build.rs
Normal file
@@ -0,0 +1,134 @@
|
||||
//! wzp-native build.rs — Oboe C++ bridge compile on Android.
|
||||
//!
|
||||
//! Near-verbatim copy of crates/wzp-android/build.rs (which is known to
|
||||
//! work). The crucial distinction: this crate is a single-cdylib (no
|
||||
//! staticlib, no rlib in crate-type) so rust-lang/rust#104707 doesn't
|
||||
//! apply — bionic's internal pthread_create / __init_tcb symbols stay
|
||||
//! UND and resolve against libc.so at runtime, as they should.
|
||||
//!
|
||||
//! On non-Android hosts we compile `cpp/oboe_stub.cpp` (empty stubs) so
|
||||
//! `cargo check --target <host>` still works for IDEs and CI.
|
||||
|
||||
use std::path::PathBuf;
|
||||
|
||||
fn main() {
|
||||
let target = std::env::var("TARGET").unwrap_or_default();
|
||||
|
||||
if target.contains("android") {
|
||||
// getauxval_fix: override compiler-rt's broken static getauxval
|
||||
// stub that SIGSEGVs in shared libraries.
|
||||
cc::Build::new()
|
||||
.file("cpp/getauxval_fix.c")
|
||||
.compile("wzp_native_getauxval_fix");
|
||||
|
||||
let oboe_dir = fetch_oboe();
|
||||
match oboe_dir {
|
||||
Some(oboe_path) => {
|
||||
println!(
|
||||
"cargo:warning=wzp-native: building with Oboe from {:?}",
|
||||
oboe_path
|
||||
);
|
||||
let mut build = cc::Build::new();
|
||||
build
|
||||
.cpp(true)
|
||||
.std("c++17")
|
||||
// Shared libc++ — matches legacy wzp-android setup.
|
||||
.cpp_link_stdlib(Some("c++_shared"))
|
||||
.include("cpp")
|
||||
.include(oboe_path.join("include"))
|
||||
.include(oboe_path.join("src"))
|
||||
.define("WZP_HAS_OBOE", None)
|
||||
.file("cpp/oboe_bridge.cpp");
|
||||
add_cpp_files_recursive(&mut build, &oboe_path.join("src"));
|
||||
build.compile("wzp_native_oboe_bridge");
|
||||
}
|
||||
None => {
|
||||
println!("cargo:warning=wzp-native: Oboe not found, building stub");
|
||||
cc::Build::new()
|
||||
.cpp(true)
|
||||
.std("c++17")
|
||||
.cpp_link_stdlib(Some("c++_shared"))
|
||||
.file("cpp/oboe_stub.cpp")
|
||||
.include("cpp")
|
||||
.compile("wzp_native_oboe_bridge");
|
||||
}
|
||||
}
|
||||
|
||||
// Oboe needs log + OpenSLES backends at runtime.
|
||||
println!("cargo:rustc-link-lib=log");
|
||||
println!("cargo:rustc-link-lib=OpenSLES");
|
||||
|
||||
// Re-run if any cpp file changes
|
||||
println!("cargo:rerun-if-changed=cpp/oboe_bridge.cpp");
|
||||
println!("cargo:rerun-if-changed=cpp/oboe_bridge.h");
|
||||
println!("cargo:rerun-if-changed=cpp/oboe_stub.cpp");
|
||||
println!("cargo:rerun-if-changed=cpp/getauxval_fix.c");
|
||||
} else {
|
||||
// Non-Android hosts: compile the empty stub so lib.rs's extern
|
||||
// declarations resolve when someone runs `cargo check` on macOS
|
||||
// or Linux without an NDK.
|
||||
cc::Build::new()
|
||||
.cpp(true)
|
||||
.std("c++17")
|
||||
.file("cpp/oboe_stub.cpp")
|
||||
.include("cpp")
|
||||
.compile("wzp_native_oboe_bridge");
|
||||
println!("cargo:rerun-if-changed=cpp/oboe_stub.cpp");
|
||||
}
|
||||
}
|
||||
|
||||
/// Recursively add all `.cpp` files from a directory to a cc::Build.
|
||||
fn add_cpp_files_recursive(build: &mut cc::Build, dir: &std::path::Path) {
|
||||
if !dir.is_dir() {
|
||||
return;
|
||||
}
|
||||
for entry in std::fs::read_dir(dir).unwrap() {
|
||||
let entry = entry.unwrap();
|
||||
let path = entry.path();
|
||||
if path.is_dir() {
|
||||
add_cpp_files_recursive(build, &path);
|
||||
} else if path.extension().map_or(false, |e| e == "cpp") {
|
||||
build.file(&path);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Fetch or find Oboe headers + sources (v1.8.1). Same logic as the
|
||||
/// legacy wzp-android crate's build.rs.
|
||||
fn fetch_oboe() -> Option<PathBuf> {
|
||||
let out_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap());
|
||||
let oboe_dir = out_dir.join("oboe");
|
||||
|
||||
if oboe_dir
|
||||
.join("include")
|
||||
.join("oboe")
|
||||
.join("Oboe.h")
|
||||
.exists()
|
||||
{
|
||||
return Some(oboe_dir);
|
||||
}
|
||||
|
||||
let status = std::process::Command::new("git")
|
||||
.args([
|
||||
"clone",
|
||||
"--depth=1",
|
||||
"--branch=1.8.1",
|
||||
"https://github.com/google/oboe.git",
|
||||
oboe_dir.to_str().unwrap(),
|
||||
])
|
||||
.status();
|
||||
|
||||
match status {
|
||||
Ok(s)
|
||||
if s.success()
|
||||
&& oboe_dir
|
||||
.join("include")
|
||||
.join("oboe")
|
||||
.join("Oboe.h")
|
||||
.exists() =>
|
||||
{
|
||||
Some(oboe_dir)
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
21
crates/wzp-native/cpp/getauxval_fix.c
Normal file
21
crates/wzp-native/cpp/getauxval_fix.c
Normal file
@@ -0,0 +1,21 @@
|
||||
// Override the broken static getauxval from compiler-rt/CRT.
|
||||
// The static version reads from __libc_auxv which is NULL in shared libs
|
||||
// loaded via dlopen, causing SIGSEGV in init_have_lse_atomics at load time.
|
||||
// This version calls the real bionic getauxval via dlsym.
|
||||
#ifdef __ANDROID__
|
||||
#include <dlfcn.h>
|
||||
#include <stdint.h>
|
||||
|
||||
typedef unsigned long (*getauxval_fn)(unsigned long);
|
||||
|
||||
unsigned long getauxval(unsigned long type) {
|
||||
static getauxval_fn real_getauxval = (getauxval_fn)0;
|
||||
if (!real_getauxval) {
|
||||
real_getauxval = (getauxval_fn)dlsym((void*)-1L /* RTLD_DEFAULT */, "getauxval");
|
||||
if (!real_getauxval) {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
return real_getauxval(type);
|
||||
}
|
||||
#endif
|
||||
491
crates/wzp-native/cpp/oboe_bridge.cpp
Normal file
491
crates/wzp-native/cpp/oboe_bridge.cpp
Normal file
@@ -0,0 +1,491 @@
|
||||
// Full Oboe implementation for Android
|
||||
// This file is compiled only when targeting Android
|
||||
|
||||
#include "oboe_bridge.h"
|
||||
|
||||
#ifdef __ANDROID__
|
||||
#include <oboe/Oboe.h>
|
||||
#include <android/log.h>
|
||||
#include <cstring>
|
||||
#include <atomic>
|
||||
#include <chrono>
|
||||
#include <thread>
|
||||
|
||||
#define LOG_TAG "wzp-oboe"
|
||||
#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, __VA_ARGS__)
|
||||
#define LOGW(...) __android_log_print(ANDROID_LOG_WARN, LOG_TAG, __VA_ARGS__)
|
||||
#define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Ring buffer helpers (SPSC, lock-free)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
static inline int32_t ring_available_read(const wzp_atomic_int* write_idx,
|
||||
const wzp_atomic_int* read_idx,
|
||||
int32_t capacity) {
|
||||
int32_t w = std::atomic_load_explicit(write_idx, std::memory_order_acquire);
|
||||
int32_t r = std::atomic_load_explicit(read_idx, std::memory_order_relaxed);
|
||||
int32_t avail = w - r;
|
||||
if (avail < 0) avail += capacity;
|
||||
return avail;
|
||||
}
|
||||
|
||||
static inline int32_t ring_available_write(const wzp_atomic_int* write_idx,
|
||||
const wzp_atomic_int* read_idx,
|
||||
int32_t capacity) {
|
||||
return capacity - 1 - ring_available_read(write_idx, read_idx, capacity);
|
||||
}
|
||||
|
||||
static inline void ring_write(int16_t* buf, int32_t capacity,
|
||||
wzp_atomic_int* write_idx, const wzp_atomic_int* read_idx,
|
||||
const int16_t* src, int32_t count) {
|
||||
int32_t w = std::atomic_load_explicit(write_idx, std::memory_order_relaxed);
|
||||
for (int32_t i = 0; i < count; i++) {
|
||||
buf[w] = src[i];
|
||||
w++;
|
||||
if (w >= capacity) w = 0;
|
||||
}
|
||||
std::atomic_store_explicit(write_idx, w, std::memory_order_release);
|
||||
}
|
||||
|
||||
static inline void ring_read(int16_t* buf, int32_t capacity,
|
||||
const wzp_atomic_int* write_idx, wzp_atomic_int* read_idx,
|
||||
int16_t* dst, int32_t count) {
|
||||
int32_t r = std::atomic_load_explicit(read_idx, std::memory_order_relaxed);
|
||||
for (int32_t i = 0; i < count; i++) {
|
||||
dst[i] = buf[r];
|
||||
r++;
|
||||
if (r >= capacity) r = 0;
|
||||
}
|
||||
std::atomic_store_explicit(read_idx, r, std::memory_order_release);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Global state
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
static std::shared_ptr<oboe::AudioStream> g_capture_stream;
|
||||
static std::shared_ptr<oboe::AudioStream> g_playout_stream;
|
||||
// Value copy — the WzpOboeRings the Rust side passes us lives on the caller's
|
||||
// stack frame and goes away as soon as wzp_oboe_start returns. The raw
|
||||
// int16/atomic pointers INSIDE the struct point into the Rust-owned, leaked-
|
||||
// for-the-lifetime-of-the-process AudioBackend singleton, so copying the
|
||||
// struct by value is safe and keeps the inner pointers valid indefinitely.
|
||||
// g_rings_valid guards the audio-callback-side read; clearing it in stop()
|
||||
// signals "no backend" to the callbacks which then return silence + Stop.
|
||||
static WzpOboeRings g_rings{};
|
||||
static std::atomic<bool> g_rings_valid{false};
|
||||
static std::atomic<bool> g_running{false};
|
||||
static std::atomic<float> g_capture_latency_ms{0.0f};
|
||||
static std::atomic<float> g_playout_latency_ms{0.0f};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Capture callback
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
class CaptureCallback : public oboe::AudioStreamDataCallback {
|
||||
public:
|
||||
uint64_t calls = 0;
|
||||
uint64_t total_frames = 0;
|
||||
uint64_t total_written = 0;
|
||||
uint64_t ring_full_drops = 0;
|
||||
|
||||
oboe::DataCallbackResult onAudioReady(
|
||||
oboe::AudioStream* stream,
|
||||
void* audioData,
|
||||
int32_t numFrames) override {
|
||||
if (!g_running.load(std::memory_order_relaxed) ||
|
||||
!g_rings_valid.load(std::memory_order_acquire)) {
|
||||
return oboe::DataCallbackResult::Stop;
|
||||
}
|
||||
|
||||
const int16_t* src = static_cast<const int16_t*>(audioData);
|
||||
int32_t avail = ring_available_write(g_rings.capture_write_idx,
|
||||
g_rings.capture_read_idx,
|
||||
g_rings.capture_capacity);
|
||||
int32_t to_write = (numFrames < avail) ? numFrames : avail;
|
||||
if (to_write > 0) {
|
||||
ring_write(g_rings.capture_buf, g_rings.capture_capacity,
|
||||
g_rings.capture_write_idx, g_rings.capture_read_idx,
|
||||
src, to_write);
|
||||
}
|
||||
total_frames += numFrames;
|
||||
total_written += to_write;
|
||||
if (to_write < numFrames) {
|
||||
ring_full_drops += (numFrames - to_write);
|
||||
}
|
||||
|
||||
// Sample-range probe on the FIRST callback to prove we get real audio
|
||||
if (calls == 0 && numFrames > 0) {
|
||||
int16_t lo = src[0], hi = src[0];
|
||||
int32_t sumsq = 0;
|
||||
for (int32_t i = 0; i < numFrames; i++) {
|
||||
if (src[i] < lo) lo = src[i];
|
||||
if (src[i] > hi) hi = src[i];
|
||||
sumsq += (int32_t)src[i] * (int32_t)src[i];
|
||||
}
|
||||
int32_t rms = (int32_t) (numFrames > 0 ? (int32_t)__builtin_sqrt((double)sumsq / (double)numFrames) : 0);
|
||||
LOGI("capture cb#0: numFrames=%d sample_range=[%d..%d] rms=%d to_write=%d",
|
||||
numFrames, lo, hi, rms, to_write);
|
||||
}
|
||||
// Heartbeat every 50 callbacks (~1s at 20ms/burst)
|
||||
calls++;
|
||||
if ((calls % 50) == 0) {
|
||||
LOGI("capture heartbeat: calls=%llu numFrames=%d ring_avail_write=%d to_write=%d full_drops=%llu total_written=%llu",
|
||||
(unsigned long long)calls, numFrames, avail, to_write,
|
||||
(unsigned long long)ring_full_drops, (unsigned long long)total_written);
|
||||
}
|
||||
|
||||
// Update latency estimate
|
||||
auto result = stream->calculateLatencyMillis();
|
||||
if (result) {
|
||||
g_capture_latency_ms.store(static_cast<float>(result.value()),
|
||||
std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
return oboe::DataCallbackResult::Continue;
|
||||
}
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Playout callback
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
class PlayoutCallback : public oboe::AudioStreamDataCallback {
|
||||
public:
|
||||
uint64_t calls = 0;
|
||||
uint64_t total_frames = 0;
|
||||
uint64_t total_played_real = 0;
|
||||
uint64_t underrun_frames = 0;
|
||||
uint64_t nonempty_calls = 0;
|
||||
|
||||
oboe::DataCallbackResult onAudioReady(
|
||||
oboe::AudioStream* stream,
|
||||
void* audioData,
|
||||
int32_t numFrames) override {
|
||||
if (!g_running.load(std::memory_order_relaxed) ||
|
||||
!g_rings_valid.load(std::memory_order_acquire)) {
|
||||
memset(audioData, 0, numFrames * sizeof(int16_t));
|
||||
return oboe::DataCallbackResult::Stop;
|
||||
}
|
||||
|
||||
int16_t* dst = static_cast<int16_t*>(audioData);
|
||||
int32_t avail = ring_available_read(g_rings.playout_write_idx,
|
||||
g_rings.playout_read_idx,
|
||||
g_rings.playout_capacity);
|
||||
int32_t to_read = (numFrames < avail) ? numFrames : avail;
|
||||
|
||||
if (to_read > 0) {
|
||||
ring_read(g_rings.playout_buf, g_rings.playout_capacity,
|
||||
g_rings.playout_write_idx, g_rings.playout_read_idx,
|
||||
dst, to_read);
|
||||
nonempty_calls++;
|
||||
}
|
||||
// Fill remainder with silence on underrun
|
||||
if (to_read < numFrames) {
|
||||
memset(dst + to_read, 0, (numFrames - to_read) * sizeof(int16_t));
|
||||
underrun_frames += (numFrames - to_read);
|
||||
}
|
||||
total_frames += numFrames;
|
||||
total_played_real += to_read;
|
||||
|
||||
// First callback: log requested config + prove we're being called
|
||||
if (calls == 0) {
|
||||
LOGI("playout cb#0: numFrames=%d ring_avail_read=%d to_read=%d",
|
||||
numFrames, avail, to_read);
|
||||
}
|
||||
// On the first callback that actually has data, log the sample range
|
||||
// so we can tell if the samples coming out of the ring look like real
|
||||
// audio vs constant-zeroes vs garbage.
|
||||
if (to_read > 0 && nonempty_calls == 1) {
|
||||
int16_t lo = dst[0], hi = dst[0];
|
||||
int32_t sumsq = 0;
|
||||
for (int32_t i = 0; i < to_read; i++) {
|
||||
if (dst[i] < lo) lo = dst[i];
|
||||
if (dst[i] > hi) hi = dst[i];
|
||||
sumsq += (int32_t)dst[i] * (int32_t)dst[i];
|
||||
}
|
||||
int32_t rms = (to_read > 0) ? (int32_t)__builtin_sqrt((double)sumsq / (double)to_read) : 0;
|
||||
LOGI("playout FIRST nonempty read: to_read=%d sample_range=[%d..%d] rms=%d",
|
||||
to_read, lo, hi, rms);
|
||||
}
|
||||
// Heartbeat every 50 callbacks (~1s at 20ms/burst)
|
||||
calls++;
|
||||
if ((calls % 50) == 0) {
|
||||
int state = (int)stream->getState();
|
||||
auto xrunRes = stream->getXRunCount();
|
||||
int xruns = xrunRes ? xrunRes.value() : -1;
|
||||
LOGI("playout heartbeat: calls=%llu nonempty=%llu numFrames=%d ring_avail_read=%d to_read=%d underrun_frames=%llu total_played_real=%llu state=%d xruns=%d",
|
||||
(unsigned long long)calls, (unsigned long long)nonempty_calls,
|
||||
numFrames, avail, to_read,
|
||||
(unsigned long long)underrun_frames, (unsigned long long)total_played_real,
|
||||
state, xruns);
|
||||
}
|
||||
|
||||
// Update latency estimate
|
||||
auto result = stream->calculateLatencyMillis();
|
||||
if (result) {
|
||||
g_playout_latency_ms.store(static_cast<float>(result.value()),
|
||||
std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
return oboe::DataCallbackResult::Continue;
|
||||
}
|
||||
};
|
||||
|
||||
static CaptureCallback g_capture_cb;
|
||||
static PlayoutCallback g_playout_cb;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Public C API
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
int wzp_oboe_start(const WzpOboeConfig* config, const WzpOboeRings* rings) {
|
||||
if (g_running.load(std::memory_order_relaxed)) {
|
||||
LOGW("wzp_oboe_start: already running");
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Deep-copy the rings struct into static storage BEFORE we publish it to
|
||||
// the audio callbacks — `rings` points at the caller's stack frame and
|
||||
// goes away as soon as this function returns.
|
||||
g_rings = *rings;
|
||||
g_rings_valid.store(true, std::memory_order_release);
|
||||
|
||||
// Build capture stream
|
||||
oboe::AudioStreamBuilder captureBuilder;
|
||||
captureBuilder.setDirection(oboe::Direction::Input)
|
||||
->setPerformanceMode(oboe::PerformanceMode::LowLatency)
|
||||
->setSharingMode(oboe::SharingMode::Shared)
|
||||
->setFormat(oboe::AudioFormat::I16)
|
||||
->setChannelCount(config->channel_count)
|
||||
->setSampleRateConversionQuality(oboe::SampleRateConversionQuality::Best)
|
||||
->setDataCallback(&g_capture_cb);
|
||||
|
||||
if (config->bt_active) {
|
||||
// BT SCO mode: do NOT set sample rate or input preset.
|
||||
// Requesting 48kHz against a BT SCO device fails with
|
||||
// "getInputProfile could not find profile". Letting the system
|
||||
// choose the native rate (8/16kHz) and relying on Oboe's
|
||||
// resampler (SampleRateConversionQuality::Best) to bridge
|
||||
// to our 48kHz ring buffer is the only path that works.
|
||||
// InputPreset::VoiceCommunication can also prevent BT SCO
|
||||
// routing on some devices — skip it for BT.
|
||||
LOGI("capture: BT mode — no sample rate or input preset set");
|
||||
} else {
|
||||
captureBuilder.setSampleRate(config->sample_rate)
|
||||
->setFramesPerDataCallback(config->frames_per_burst)
|
||||
->setInputPreset(oboe::InputPreset::VoiceCommunication);
|
||||
}
|
||||
|
||||
oboe::Result result = captureBuilder.openStream(g_capture_stream);
|
||||
if (result != oboe::Result::OK) {
|
||||
LOGE("Failed to open capture stream: %s", oboe::convertToText(result));
|
||||
return -2;
|
||||
}
|
||||
LOGI("capture stream opened: actualSR=%d actualCh=%d actualFormat=%d actualFramesPerBurst=%d actualFramesPerDataCallback=%d bufferCapacityInFrames=%d sharing=%d perfMode=%d",
|
||||
g_capture_stream->getSampleRate(),
|
||||
g_capture_stream->getChannelCount(),
|
||||
(int)g_capture_stream->getFormat(),
|
||||
g_capture_stream->getFramesPerBurst(),
|
||||
g_capture_stream->getFramesPerDataCallback(),
|
||||
g_capture_stream->getBufferCapacityInFrames(),
|
||||
(int)g_capture_stream->getSharingMode(),
|
||||
(int)g_capture_stream->getPerformanceMode());
|
||||
|
||||
// Build playout stream.
|
||||
//
|
||||
// Regression triangulation between builds:
|
||||
// 96be740 (Usage::Media, default API): playout callback DID drain
|
||||
// the ring at steady 50Hz (playout heartbeat: calls=1100,
|
||||
// total_played_real=1055040). Audio not audible because OS routing
|
||||
// sent it to a silent output.
|
||||
//
|
||||
// 8c36fb5 (Usage::VoiceCommunication + setAudioApi(AAudio) +
|
||||
// ContentType::Speech): playout callback fired cb#0 once then
|
||||
// stopped draining the ring entirely. written_samples stuck at
|
||||
// ring capacity (7679) across all subsequent heartbeats, so Oboe
|
||||
// accepted zero samples after startup. Still inaudible.
|
||||
//
|
||||
// Hypothesis: forcing setAudioApi(AAudio) + VoiceCommunication on
|
||||
// Pixel 6 / Android 15 opens a stream that succeeds at cb#0 but
|
||||
// then detaches from the real audio driver. Reverting to the
|
||||
// config that at least drove callbacks correctly, plus the
|
||||
// Kotlin-side MODE_IN_COMMUNICATION + setSpeakerphoneOn(true)
|
||||
// handled in MainActivity.kt to route audio to the loud speaker.
|
||||
// Usage::VoiceCommunication is the correct Oboe usage for a VoIP app
|
||||
// — it respects Android's in-call audio routing and lets
|
||||
// AudioManager.setSpeakerphoneOn/setBluetoothScoOn actually switch
|
||||
// between earpiece, loudspeaker, and Bluetooth headset. Combined with
|
||||
// MODE_IN_COMMUNICATION set from MainActivity.kt and
|
||||
// speakerphoneOn=false by default, this produces handset/earpiece as
|
||||
// the default output.
|
||||
//
|
||||
// IMPORTANT: do NOT add setAudioApi(AAudio) here. Build 8c36fb5 proved
|
||||
// forcing AAudio with Usage::VoiceCommunication makes the playout
|
||||
// callback stop draining the ring after cb#0, even though the stream
|
||||
// opens successfully. Letting Oboe pick the API (which will be AAudio
|
||||
// on API ≥ 27 but via a different codepath) kept callbacks firing in
|
||||
// every other build.
|
||||
oboe::AudioStreamBuilder playoutBuilder;
|
||||
playoutBuilder.setDirection(oboe::Direction::Output)
|
||||
->setPerformanceMode(oboe::PerformanceMode::LowLatency)
|
||||
->setSharingMode(oboe::SharingMode::Shared)
|
||||
->setFormat(oboe::AudioFormat::I16)
|
||||
->setChannelCount(config->channel_count)
|
||||
->setSampleRateConversionQuality(oboe::SampleRateConversionQuality::Best)
|
||||
->setDataCallback(&g_playout_cb);
|
||||
|
||||
if (config->bt_active) {
|
||||
LOGI("playout: BT mode — no sample rate set, using Usage::Media");
|
||||
// Usage::Media instead of VoiceCommunication for BT output
|
||||
// to avoid conflicts with the communication device routing.
|
||||
playoutBuilder.setUsage(oboe::Usage::Media);
|
||||
} else {
|
||||
playoutBuilder.setSampleRate(config->sample_rate)
|
||||
->setFramesPerDataCallback(config->frames_per_burst)
|
||||
->setUsage(oboe::Usage::VoiceCommunication);
|
||||
}
|
||||
|
||||
result = playoutBuilder.openStream(g_playout_stream);
|
||||
if (result != oboe::Result::OK) {
|
||||
LOGE("Failed to open playout stream: %s", oboe::convertToText(result));
|
||||
g_capture_stream->close();
|
||||
g_capture_stream.reset();
|
||||
return -3;
|
||||
}
|
||||
LOGI("playout stream opened: actualSR=%d actualCh=%d actualFormat=%d actualFramesPerBurst=%d actualFramesPerDataCallback=%d bufferCapacityInFrames=%d sharing=%d perfMode=%d",
|
||||
g_playout_stream->getSampleRate(),
|
||||
g_playout_stream->getChannelCount(),
|
||||
(int)g_playout_stream->getFormat(),
|
||||
g_playout_stream->getFramesPerBurst(),
|
||||
g_playout_stream->getFramesPerDataCallback(),
|
||||
g_playout_stream->getBufferCapacityInFrames(),
|
||||
(int)g_playout_stream->getSharingMode(),
|
||||
(int)g_playout_stream->getPerformanceMode());
|
||||
|
||||
g_running.store(true, std::memory_order_release);
|
||||
|
||||
// Start both streams
|
||||
result = g_capture_stream->requestStart();
|
||||
if (result != oboe::Result::OK) {
|
||||
LOGE("Failed to start capture: %s", oboe::convertToText(result));
|
||||
g_running.store(false, std::memory_order_release);
|
||||
g_capture_stream->close();
|
||||
g_playout_stream->close();
|
||||
g_capture_stream.reset();
|
||||
g_playout_stream.reset();
|
||||
return -4;
|
||||
}
|
||||
|
||||
result = g_playout_stream->requestStart();
|
||||
if (result != oboe::Result::OK) {
|
||||
LOGE("Failed to start playout: %s", oboe::convertToText(result));
|
||||
g_running.store(false, std::memory_order_release);
|
||||
g_capture_stream->requestStop();
|
||||
g_capture_stream->close();
|
||||
g_playout_stream->close();
|
||||
g_capture_stream.reset();
|
||||
g_playout_stream.reset();
|
||||
return -5;
|
||||
}
|
||||
|
||||
// Log initial stream states right after requestStart() returns.
|
||||
// On well-behaved HALs both will already be Started; on others
|
||||
// (Nothing A059) they may still be in Starting state.
|
||||
LOGI("requestStart returned: capture_state=%d playout_state=%d",
|
||||
(int)g_capture_stream->getState(),
|
||||
(int)g_playout_stream->getState());
|
||||
|
||||
// Poll until both streams report Started state, up to 2s timeout.
|
||||
// Some Android HALs (Nothing A059) delay transitioning from Starting
|
||||
// to Started; proceeding before the transition completes causes the
|
||||
// first capture/playout callbacks to be dropped silently.
|
||||
{
|
||||
auto deadline = std::chrono::steady_clock::now() + std::chrono::milliseconds(2000);
|
||||
int poll_count = 0;
|
||||
bool streams_started = false;
|
||||
while (std::chrono::steady_clock::now() < deadline) {
|
||||
auto cap_state = g_capture_stream->getState();
|
||||
auto play_state = g_playout_stream->getState();
|
||||
if (cap_state == oboe::StreamState::Started &&
|
||||
play_state == oboe::StreamState::Started) {
|
||||
LOGI("both streams Started after %d polls", poll_count);
|
||||
streams_started = true;
|
||||
break;
|
||||
}
|
||||
poll_count++;
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(10));
|
||||
}
|
||||
// Log final state even on timeout (helps diagnose HAL quirks)
|
||||
LOGI("stream states after poll: capture=%d playout=%d (polls=%d)",
|
||||
(int)g_capture_stream->getState(),
|
||||
(int)g_playout_stream->getState(),
|
||||
poll_count);
|
||||
if (!streams_started) {
|
||||
LOGE("Timed out waiting for Oboe streams to reach Started state");
|
||||
g_running.store(false, std::memory_order_release);
|
||||
g_rings_valid.store(false, std::memory_order_release);
|
||||
g_capture_stream->requestStop();
|
||||
g_playout_stream->requestStop();
|
||||
g_capture_stream->close();
|
||||
g_playout_stream->close();
|
||||
g_capture_stream.reset();
|
||||
g_playout_stream.reset();
|
||||
return -6;
|
||||
}
|
||||
}
|
||||
|
||||
LOGI("Oboe started: sr=%d burst=%d ch=%d",
|
||||
config->sample_rate, config->frames_per_burst, config->channel_count);
|
||||
return 0;
|
||||
}
|
||||
|
||||
void wzp_oboe_stop(void) {
|
||||
g_running.store(false, std::memory_order_release);
|
||||
// Tell the audio callbacks to stop touching g_rings BEFORE we tear down
|
||||
// the streams, so any in-flight callback returns Stop instead of reading
|
||||
// stale pointers.
|
||||
g_rings_valid.store(false, std::memory_order_release);
|
||||
|
||||
if (g_capture_stream) {
|
||||
g_capture_stream->requestStop();
|
||||
g_capture_stream->close();
|
||||
g_capture_stream.reset();
|
||||
}
|
||||
if (g_playout_stream) {
|
||||
g_playout_stream->requestStop();
|
||||
g_playout_stream->close();
|
||||
g_playout_stream.reset();
|
||||
}
|
||||
|
||||
LOGI("Oboe stopped");
|
||||
}
|
||||
|
||||
float wzp_oboe_capture_latency_ms(void) {
|
||||
return g_capture_latency_ms.load(std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
float wzp_oboe_playout_latency_ms(void) {
|
||||
return g_playout_latency_ms.load(std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
int wzp_oboe_is_running(void) {
|
||||
return g_running.load(std::memory_order_relaxed) ? 1 : 0;
|
||||
}
|
||||
|
||||
#else
|
||||
// Non-Android fallback — should not be reached; oboe_stub.cpp is used instead.
|
||||
// Provide empty implementations just in case.
|
||||
|
||||
int wzp_oboe_start(const WzpOboeConfig* config, const WzpOboeRings* rings) {
|
||||
(void)config; (void)rings;
|
||||
return -99;
|
||||
}
|
||||
|
||||
void wzp_oboe_stop(void) {}
|
||||
float wzp_oboe_capture_latency_ms(void) { return 0.0f; }
|
||||
float wzp_oboe_playout_latency_ms(void) { return 0.0f; }
|
||||
int wzp_oboe_is_running(void) { return 0; }
|
||||
|
||||
#endif // __ANDROID__
|
||||
44
crates/wzp-native/cpp/oboe_bridge.h
Normal file
44
crates/wzp-native/cpp/oboe_bridge.h
Normal file
@@ -0,0 +1,44 @@
|
||||
#ifndef WZP_OBOE_BRIDGE_H
|
||||
#define WZP_OBOE_BRIDGE_H
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
#include <atomic>
|
||||
typedef std::atomic<int32_t> wzp_atomic_int;
|
||||
extern "C" {
|
||||
#else
|
||||
#include <stdatomic.h>
|
||||
typedef atomic_int wzp_atomic_int;
|
||||
#endif
|
||||
|
||||
typedef struct {
|
||||
int32_t sample_rate;
|
||||
int32_t frames_per_burst;
|
||||
int32_t channel_count;
|
||||
int32_t bt_active; /* nonzero = BT SCO mode: skip sample rate + input preset */
|
||||
} WzpOboeConfig;
|
||||
|
||||
typedef struct {
|
||||
int16_t* capture_buf;
|
||||
int32_t capture_capacity;
|
||||
wzp_atomic_int* capture_write_idx;
|
||||
wzp_atomic_int* capture_read_idx;
|
||||
|
||||
int16_t* playout_buf;
|
||||
int32_t playout_capacity;
|
||||
wzp_atomic_int* playout_write_idx;
|
||||
wzp_atomic_int* playout_read_idx;
|
||||
} WzpOboeRings;
|
||||
|
||||
int wzp_oboe_start(const WzpOboeConfig* config, const WzpOboeRings* rings);
|
||||
void wzp_oboe_stop(void);
|
||||
float wzp_oboe_capture_latency_ms(void);
|
||||
float wzp_oboe_playout_latency_ms(void);
|
||||
int wzp_oboe_is_running(void);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // WZP_OBOE_BRIDGE_H
|
||||
27
crates/wzp-native/cpp/oboe_stub.cpp
Normal file
27
crates/wzp-native/cpp/oboe_stub.cpp
Normal file
@@ -0,0 +1,27 @@
|
||||
// Stub implementation for non-Android host builds (testing, cargo check, etc.)
|
||||
|
||||
#include "oboe_bridge.h"
|
||||
#include <stdio.h>
|
||||
|
||||
int wzp_oboe_start(const WzpOboeConfig* config, const WzpOboeRings* rings) {
|
||||
(void)config;
|
||||
(void)rings;
|
||||
fprintf(stderr, "wzp_oboe_start: stub (not on Android)\n");
|
||||
return 0;
|
||||
}
|
||||
|
||||
void wzp_oboe_stop(void) {
|
||||
fprintf(stderr, "wzp_oboe_stop: stub (not on Android)\n");
|
||||
}
|
||||
|
||||
float wzp_oboe_capture_latency_ms(void) {
|
||||
return 0.0f;
|
||||
}
|
||||
|
||||
float wzp_oboe_playout_latency_ms(void) {
|
||||
return 0.0f;
|
||||
}
|
||||
|
||||
int wzp_oboe_is_running(void) {
|
||||
return 0;
|
||||
}
|
||||
501
crates/wzp-native/src/lib.rs
Normal file
501
crates/wzp-native/src/lib.rs
Normal file
@@ -0,0 +1,501 @@
|
||||
//! wzp-native — standalone Android cdylib for all the C++ audio code.
|
||||
//!
|
||||
//! Built with `cargo ndk`, NOT `cargo tauri android build`. Loaded at
|
||||
//! runtime by the Tauri desktop cdylib (`wzp-desktop`) via libloading.
|
||||
//! See `docs/incident-tauri-android-init-tcb.md` for why the split exists.
|
||||
//!
|
||||
//! Phase 2: real Oboe audio backend.
|
||||
//!
|
||||
//! Architecture: Oboe runs capture + playout streams on its own high-
|
||||
//! priority AAudio callback threads inside the C++ bridge. Two SPSC ring
|
||||
//! buffers (capture and playout) are shared between the C++ callbacks
|
||||
//! and the Rust side via atomic indices — no locks on the hot path.
|
||||
//! `wzp-desktop` drains the capture ring into its Opus encoder and fills
|
||||
//! the playout ring with decoded PCM.
|
||||
|
||||
use std::sync::atomic::{AtomicI32, Ordering};
|
||||
|
||||
// ─── Phase 1 smoke-test exports (kept for sanity checks) ─────────────────
|
||||
|
||||
/// Returns 42. Used by wzp-desktop's setup() to verify dlopen + dlsym
|
||||
/// work before any audio code runs.
|
||||
#[unsafe(no_mangle)]
|
||||
pub extern "C" fn wzp_native_version() -> i32 {
|
||||
42
|
||||
}
|
||||
|
||||
/// Writes a NUL-terminated string into `out` (capped at `cap`) and
|
||||
/// returns bytes written excluding the NUL.
|
||||
///
|
||||
/// # Safety
|
||||
/// `out` must be a valid pointer to at least `cap` contiguous bytes of
|
||||
/// writable memory. Passing a null pointer or zero capacity is safe
|
||||
/// (returns 0), but a dangling non-null pointer is undefined behaviour.
|
||||
#[unsafe(no_mangle)]
|
||||
pub unsafe extern "C" fn wzp_native_hello(out: *mut u8, cap: usize) -> usize {
|
||||
const MSG: &[u8] = b"hello from wzp-native\0";
|
||||
if out.is_null() || cap == 0 {
|
||||
return 0;
|
||||
}
|
||||
let n = MSG.len().min(cap);
|
||||
unsafe {
|
||||
core::ptr::copy_nonoverlapping(MSG.as_ptr(), out, n);
|
||||
*out.add(n - 1) = 0;
|
||||
}
|
||||
n - 1
|
||||
}
|
||||
|
||||
// ─── C++ Oboe bridge FFI ─────────────────────────────────────────────────
|
||||
|
||||
#[repr(C)]
|
||||
struct WzpOboeConfig {
|
||||
sample_rate: i32,
|
||||
frames_per_burst: i32,
|
||||
channel_count: i32,
|
||||
/// When nonzero, capture stream skips setSampleRate and setInputPreset
|
||||
/// so the system can route to BT SCO at its native rate (8/16kHz).
|
||||
/// Oboe's SampleRateConversionQuality::Best resamples to 48kHz.
|
||||
bt_active: i32,
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
struct WzpOboeRings {
|
||||
capture_buf: *mut i16,
|
||||
capture_capacity: i32,
|
||||
capture_write_idx: *mut AtomicI32,
|
||||
capture_read_idx: *mut AtomicI32,
|
||||
playout_buf: *mut i16,
|
||||
playout_capacity: i32,
|
||||
playout_write_idx: *mut AtomicI32,
|
||||
playout_read_idx: *mut AtomicI32,
|
||||
}
|
||||
|
||||
// SAFETY: atomics synchronise producer/consumer; raw pointers are owned
|
||||
// by the AudioBackend singleton below whose lifetime covers all calls.
|
||||
unsafe impl Send for WzpOboeRings {}
|
||||
unsafe impl Sync for WzpOboeRings {}
|
||||
|
||||
unsafe extern "C" {
|
||||
fn wzp_oboe_start(config: *const WzpOboeConfig, rings: *const WzpOboeRings) -> i32;
|
||||
fn wzp_oboe_stop();
|
||||
fn wzp_oboe_capture_latency_ms() -> f32;
|
||||
fn wzp_oboe_playout_latency_ms() -> f32;
|
||||
fn wzp_oboe_is_running() -> i32;
|
||||
}
|
||||
|
||||
// ─── SPSC ring buffer (shared with C++ via AtomicI32) ────────────────────
|
||||
|
||||
/// 20 ms @ 48 kHz mono = 960 samples.
|
||||
const FRAME_SAMPLES: usize = 960;
|
||||
/// ~160 ms headroom at 48 kHz.
|
||||
const RING_CAPACITY: usize = 7680;
|
||||
|
||||
struct RingBuffer {
|
||||
buf: Vec<i16>,
|
||||
capacity: usize,
|
||||
write_idx: AtomicI32,
|
||||
read_idx: AtomicI32,
|
||||
}
|
||||
|
||||
// SAFETY: SPSC with atomic read/write cursors; producer and consumer
|
||||
// are always on different threads.
|
||||
unsafe impl Send for RingBuffer {}
|
||||
unsafe impl Sync for RingBuffer {}
|
||||
|
||||
impl RingBuffer {
|
||||
fn new(capacity: usize) -> Self {
|
||||
Self {
|
||||
buf: vec![0i16; capacity],
|
||||
capacity,
|
||||
write_idx: AtomicI32::new(0),
|
||||
read_idx: AtomicI32::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
fn available_read(&self) -> usize {
|
||||
let w = self.write_idx.load(Ordering::Acquire);
|
||||
let r = self.read_idx.load(Ordering::Relaxed);
|
||||
let avail = w - r;
|
||||
if avail < 0 {
|
||||
(avail + self.capacity as i32) as usize
|
||||
} else {
|
||||
avail as usize
|
||||
}
|
||||
}
|
||||
|
||||
fn available_write(&self) -> usize {
|
||||
self.capacity - 1 - self.available_read()
|
||||
}
|
||||
|
||||
fn write(&self, data: &[i16]) -> usize {
|
||||
let count = data.len().min(self.available_write());
|
||||
if count == 0 {
|
||||
return 0;
|
||||
}
|
||||
let mut w = self.write_idx.load(Ordering::Relaxed) as usize;
|
||||
let cap = self.capacity;
|
||||
let buf_ptr = self.buf.as_ptr() as *mut i16;
|
||||
for sample in &data[..count] {
|
||||
unsafe {
|
||||
*buf_ptr.add(w) = *sample;
|
||||
}
|
||||
w += 1;
|
||||
if w >= cap {
|
||||
w = 0;
|
||||
}
|
||||
}
|
||||
self.write_idx.store(w as i32, Ordering::Release);
|
||||
count
|
||||
}
|
||||
|
||||
fn read(&self, out: &mut [i16]) -> usize {
|
||||
let count = out.len().min(self.available_read());
|
||||
if count == 0 {
|
||||
return 0;
|
||||
}
|
||||
let mut r = self.read_idx.load(Ordering::Relaxed) as usize;
|
||||
let cap = self.capacity;
|
||||
let buf_ptr = self.buf.as_ptr();
|
||||
for slot in &mut out[..count] {
|
||||
unsafe {
|
||||
*slot = *buf_ptr.add(r);
|
||||
}
|
||||
r += 1;
|
||||
if r >= cap {
|
||||
r = 0;
|
||||
}
|
||||
}
|
||||
self.read_idx.store(r as i32, Ordering::Release);
|
||||
count
|
||||
}
|
||||
|
||||
fn buf_ptr(&self) -> *mut i16 {
|
||||
self.buf.as_ptr() as *mut i16
|
||||
}
|
||||
fn write_idx_ptr(&self) -> *mut AtomicI32 {
|
||||
&self.write_idx as *const AtomicI32 as *mut AtomicI32
|
||||
}
|
||||
fn read_idx_ptr(&self) -> *mut AtomicI32 {
|
||||
&self.read_idx as *const AtomicI32 as *mut AtomicI32
|
||||
}
|
||||
}
|
||||
|
||||
// ─── AudioBackend singleton ──────────────────────────────────────────────
|
||||
//
|
||||
// There is one global AudioBackend instance because Oboe's C++ side
|
||||
// holds its own singleton of the streams. The `Box::leak`'d statics own
|
||||
// the ring buffers for the lifetime of the process — dropping them while
|
||||
// Oboe is still running would cause use-after-free in the audio callback.
|
||||
|
||||
use std::sync::OnceLock;
|
||||
|
||||
struct AudioBackend {
|
||||
capture: RingBuffer,
|
||||
playout: RingBuffer,
|
||||
started: std::sync::Mutex<bool>,
|
||||
/// Per-write logging throttle counter for wzp_native_audio_write_playout.
|
||||
playout_write_log_count: std::sync::atomic::AtomicU64,
|
||||
/// Fix A (task #35): the playout ring's read_idx at the last
|
||||
/// check. If audio_write_playout observes read_idx hasn't
|
||||
/// advanced after N writes, the Oboe playout callback has
|
||||
/// stopped firing → restart the streams.
|
||||
playout_last_read_idx: std::sync::atomic::AtomicI32,
|
||||
/// Number of writes since the last read_idx advance.
|
||||
playout_stall_writes: std::sync::atomic::AtomicU32,
|
||||
}
|
||||
|
||||
static BACKEND: OnceLock<&'static AudioBackend> = OnceLock::new();
|
||||
|
||||
fn backend() -> &'static AudioBackend {
|
||||
BACKEND.get_or_init(|| {
|
||||
Box::leak(Box::new(AudioBackend {
|
||||
capture: RingBuffer::new(RING_CAPACITY),
|
||||
playout: RingBuffer::new(RING_CAPACITY),
|
||||
started: std::sync::Mutex::new(false),
|
||||
playout_write_log_count: std::sync::atomic::AtomicU64::new(0),
|
||||
playout_last_read_idx: std::sync::atomic::AtomicI32::new(0),
|
||||
playout_stall_writes: std::sync::atomic::AtomicU32::new(0),
|
||||
}))
|
||||
})
|
||||
}
|
||||
|
||||
// ─── C FFI for wzp-desktop ───────────────────────────────────────────────
|
||||
|
||||
/// Start the Oboe audio streams. Returns 0 on success, non-zero on error.
|
||||
/// Idempotent — calling while already running is a no-op that returns 0.
|
||||
#[unsafe(no_mangle)]
|
||||
pub extern "C" fn wzp_native_audio_start() -> i32 {
|
||||
audio_start_inner(false)
|
||||
}
|
||||
|
||||
/// Start Oboe in Bluetooth SCO mode — skips sample rate and input preset
|
||||
/// on capture so the system can route to the BT SCO device natively.
|
||||
#[unsafe(no_mangle)]
|
||||
pub extern "C" fn wzp_native_audio_start_bt() -> i32 {
|
||||
audio_start_inner(true)
|
||||
}
|
||||
|
||||
fn audio_start_inner(bt: bool) -> i32 {
|
||||
let b = backend();
|
||||
let mut started = match b.started.lock() {
|
||||
Ok(g) => g,
|
||||
Err(_) => return -1,
|
||||
};
|
||||
if *started {
|
||||
return 0;
|
||||
}
|
||||
|
||||
let config = WzpOboeConfig {
|
||||
sample_rate: 48_000,
|
||||
frames_per_burst: FRAME_SAMPLES as i32,
|
||||
channel_count: 1,
|
||||
bt_active: if bt { 1 } else { 0 },
|
||||
};
|
||||
let rings = WzpOboeRings {
|
||||
capture_buf: b.capture.buf_ptr(),
|
||||
capture_capacity: b.capture.capacity as i32,
|
||||
capture_write_idx: b.capture.write_idx_ptr(),
|
||||
capture_read_idx: b.capture.read_idx_ptr(),
|
||||
playout_buf: b.playout.buf_ptr(),
|
||||
playout_capacity: b.playout.capacity as i32,
|
||||
playout_write_idx: b.playout.write_idx_ptr(),
|
||||
playout_read_idx: b.playout.read_idx_ptr(),
|
||||
};
|
||||
let ret = unsafe { wzp_oboe_start(&config, &rings) };
|
||||
if ret != 0 {
|
||||
return ret;
|
||||
}
|
||||
*started = true;
|
||||
0
|
||||
}
|
||||
|
||||
/// Stop Oboe. Idempotent. Safe to call from any thread.
|
||||
#[unsafe(no_mangle)]
|
||||
pub extern "C" fn wzp_native_audio_stop() {
|
||||
let b = backend();
|
||||
if let Ok(mut started) = b.started.lock() {
|
||||
if *started {
|
||||
unsafe { wzp_oboe_stop() };
|
||||
*started = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Number of capture samples available to read without blocking.
|
||||
#[unsafe(no_mangle)]
|
||||
pub extern "C" fn wzp_native_audio_capture_available() -> usize {
|
||||
backend().capture.available_read()
|
||||
}
|
||||
|
||||
/// Read captured PCM samples from the capture ring. Returns the number
|
||||
/// of `i16` samples actually copied into `out` (may be less than
|
||||
/// `out_len` if the ring is empty).
|
||||
///
|
||||
/// # Safety
|
||||
/// `out` must be a valid pointer to `out_len` contiguous `i16` values.
|
||||
/// The caller must ensure no other thread writes to the same buffer
|
||||
/// concurrently. Passing a null pointer or zero length is safe (returns 0).
|
||||
#[unsafe(no_mangle)]
|
||||
pub unsafe extern "C" fn wzp_native_audio_read_capture(out: *mut i16, out_len: usize) -> usize {
|
||||
if out.is_null() || out_len == 0 {
|
||||
return 0;
|
||||
}
|
||||
let slice = unsafe { std::slice::from_raw_parts_mut(out, out_len) };
|
||||
backend().capture.read(slice)
|
||||
}
|
||||
|
||||
/// Write PCM samples into the playout ring. Returns the number of
|
||||
/// samples actually enqueued (may be less than `in_len` if the ring
|
||||
/// is nearly full — in practice the caller should pace to 20 ms
|
||||
/// frames and spin briefly if the ring is full).
|
||||
///
|
||||
/// # Safety
|
||||
/// `input` must be a valid pointer to `in_len` contiguous `i16` values
|
||||
/// that remain valid for the duration of the call. Passing a null pointer
|
||||
/// or zero length is safe (returns 0). The caller must not free or mutate
|
||||
/// the buffer while this function is executing.
|
||||
#[unsafe(no_mangle)]
|
||||
pub unsafe extern "C" fn wzp_native_audio_write_playout(input: *const i16, in_len: usize) -> usize {
|
||||
if input.is_null() || in_len == 0 {
|
||||
return 0;
|
||||
}
|
||||
let slice = unsafe { std::slice::from_raw_parts(input, in_len) };
|
||||
let b = backend();
|
||||
|
||||
// Fix A (task #35): detect playout callback stall. If the
|
||||
// playout ring's read_idx hasn't advanced in 50+ writes
|
||||
// (~1 second at 50 writes/sec), the Oboe playout callback
|
||||
// has stopped firing → restart the streams. This is the
|
||||
// self-healing behavior that makes rejoin work: teardown +
|
||||
// rebuild clears whatever HAL state locked up the callback.
|
||||
let current_read_idx = b
|
||||
.playout
|
||||
.read_idx
|
||||
.load(std::sync::atomic::Ordering::Relaxed);
|
||||
let last_read_idx = b
|
||||
.playout_last_read_idx
|
||||
.load(std::sync::atomic::Ordering::Relaxed);
|
||||
if current_read_idx == last_read_idx {
|
||||
let stall = b
|
||||
.playout_stall_writes
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
if stall >= 50 {
|
||||
// Callback hasn't drained anything in ~1 second.
|
||||
// Force a stream restart.
|
||||
unsafe {
|
||||
android_log(
|
||||
"playout STALL detected (50 writes, read_idx unchanged) — restarting Oboe streams",
|
||||
);
|
||||
}
|
||||
b.playout_stall_writes
|
||||
.store(0, std::sync::atomic::Ordering::Relaxed);
|
||||
// Release the started lock, stop, re-start.
|
||||
// This is the same logic as the Rust-side
|
||||
// audio_stop() + audio_start() but done inline
|
||||
// because we can't call the extern "C" fns
|
||||
// recursively. Just call the C++ side directly.
|
||||
{
|
||||
if let Ok(mut started) = b.started.lock() {
|
||||
if *started {
|
||||
unsafe { wzp_oboe_stop() };
|
||||
*started = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Clear the rings so the restart doesn't read stale data
|
||||
b.playout
|
||||
.write_idx
|
||||
.store(0, std::sync::atomic::Ordering::Relaxed);
|
||||
b.playout
|
||||
.read_idx
|
||||
.store(0, std::sync::atomic::Ordering::Relaxed);
|
||||
b.capture
|
||||
.write_idx
|
||||
.store(0, std::sync::atomic::Ordering::Relaxed);
|
||||
b.capture
|
||||
.read_idx
|
||||
.store(0, std::sync::atomic::Ordering::Relaxed);
|
||||
// Re-start (stall detector — always non-BT mode)
|
||||
let config = WzpOboeConfig {
|
||||
sample_rate: 48_000,
|
||||
frames_per_burst: FRAME_SAMPLES as i32,
|
||||
channel_count: 1,
|
||||
bt_active: 0,
|
||||
};
|
||||
let rings = WzpOboeRings {
|
||||
capture_buf: b.capture.buf_ptr(),
|
||||
capture_capacity: b.capture.capacity as i32,
|
||||
capture_write_idx: b.capture.write_idx_ptr(),
|
||||
capture_read_idx: b.capture.read_idx_ptr(),
|
||||
playout_buf: b.playout.buf_ptr(),
|
||||
playout_capacity: b.playout.capacity as i32,
|
||||
playout_write_idx: b.playout.write_idx_ptr(),
|
||||
playout_read_idx: b.playout.read_idx_ptr(),
|
||||
};
|
||||
let ret = unsafe { wzp_oboe_start(&config, &rings) };
|
||||
if ret == 0 {
|
||||
if let Ok(mut started) = b.started.lock() {
|
||||
*started = true;
|
||||
}
|
||||
unsafe {
|
||||
android_log("playout restart OK — Oboe streams rebuilt");
|
||||
}
|
||||
} else {
|
||||
unsafe {
|
||||
android_log(&format!("playout restart FAILED: {ret}"));
|
||||
}
|
||||
}
|
||||
b.playout_last_read_idx
|
||||
.store(0, std::sync::atomic::Ordering::Relaxed);
|
||||
return 0; // caller will retry on next frame
|
||||
}
|
||||
} else {
|
||||
// read_idx advanced — callback is alive, reset counter
|
||||
b.playout_stall_writes
|
||||
.store(0, std::sync::atomic::Ordering::Relaxed);
|
||||
b.playout_last_read_idx
|
||||
.store(current_read_idx, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
|
||||
let before_w = b
|
||||
.playout
|
||||
.write_idx
|
||||
.load(std::sync::atomic::Ordering::Relaxed);
|
||||
let before_r = b
|
||||
.playout
|
||||
.read_idx
|
||||
.load(std::sync::atomic::Ordering::Relaxed);
|
||||
let written = b.playout.write(slice);
|
||||
// First few writes: log ring state + sample range so we can compare what
|
||||
// engine.rs hands us to what the C++ playout callback reads.
|
||||
let first_writes = b
|
||||
.playout_write_log_count
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
if first_writes < 3 || first_writes % 50 == 0 {
|
||||
let (mut lo, mut hi, mut sumsq) = (i16::MAX, i16::MIN, 0i64);
|
||||
for &s in slice.iter() {
|
||||
if s < lo {
|
||||
lo = s;
|
||||
}
|
||||
if s > hi {
|
||||
hi = s;
|
||||
}
|
||||
sumsq += (s as i64) * (s as i64);
|
||||
}
|
||||
let rms = (sumsq as f64 / slice.len() as f64).sqrt() as i32;
|
||||
let avail_w_after = b.playout.available_write();
|
||||
let avail_r_after = b.playout.available_read();
|
||||
let msg = format!(
|
||||
"playout WRITE #{first_writes}: in_len={} written={} range=[{lo}..{hi}] rms={rms} before_w={before_w} before_r={before_r} avail_read_after={avail_r_after} avail_write_after={avail_w_after}",
|
||||
slice.len(),
|
||||
written
|
||||
);
|
||||
unsafe {
|
||||
android_log(msg.as_str());
|
||||
}
|
||||
}
|
||||
written
|
||||
}
|
||||
|
||||
// Minimal android logcat shim so we can print from the cdylib without pulling
|
||||
// in android_logger crate (which would add another dep that has to build with
|
||||
// cargo-ndk). Uses libc's __android_log_print via extern linkage.
|
||||
#[cfg(target_os = "android")]
|
||||
unsafe extern "C" {
|
||||
fn __android_log_write(prio: i32, tag: *const u8, text: *const u8) -> i32;
|
||||
}
|
||||
|
||||
#[cfg(target_os = "android")]
|
||||
unsafe fn android_log(msg: &str) {
|
||||
// ANDROID_LOG_INFO = 4. Tag and text must be NUL-terminated.
|
||||
let tag = b"wzp-native\0";
|
||||
let mut buf = Vec::with_capacity(msg.len() + 1);
|
||||
buf.extend_from_slice(msg.as_bytes());
|
||||
buf.push(0);
|
||||
unsafe {
|
||||
__android_log_write(4, tag.as_ptr(), buf.as_ptr());
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "android"))]
|
||||
#[allow(dead_code)]
|
||||
unsafe fn android_log(_msg: &str) {}
|
||||
|
||||
/// Current capture latency reported by Oboe, in milliseconds. Returns
|
||||
/// NaN / 0.0 if the stream isn't running.
|
||||
#[unsafe(no_mangle)]
|
||||
pub extern "C" fn wzp_native_audio_capture_latency_ms() -> f32 {
|
||||
unsafe { wzp_oboe_capture_latency_ms() }
|
||||
}
|
||||
|
||||
/// Current playout latency reported by Oboe, in milliseconds.
|
||||
#[unsafe(no_mangle)]
|
||||
pub extern "C" fn wzp_native_audio_playout_latency_ms() -> f32 {
|
||||
unsafe { wzp_oboe_playout_latency_ms() }
|
||||
}
|
||||
|
||||
/// Non-zero if both Oboe streams are currently running.
|
||||
#[unsafe(no_mangle)]
|
||||
pub extern "C" fn wzp_native_audio_is_running() -> i32 {
|
||||
unsafe { wzp_oboe_is_running() }
|
||||
}
|
||||
@@ -20,3 +20,4 @@ tracing = "0.1"
|
||||
[dev-dependencies]
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
serde_json = "1"
|
||||
bincode = "1"
|
||||
|
||||
@@ -7,10 +7,11 @@
|
||||
//! Control (GCC).
|
||||
|
||||
use std::collections::VecDeque;
|
||||
use std::time::Instant;
|
||||
use std::sync::atomic::{AtomicU64, Ordering::Relaxed};
|
||||
use std::time::{Instant, SystemTime, UNIX_EPOCH};
|
||||
|
||||
use crate::packet::QualityReport;
|
||||
use crate::QualityProfile;
|
||||
use crate::packet::QualityReport;
|
||||
|
||||
/// Network congestion state derived from delay and loss signals.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
@@ -158,6 +159,16 @@ pub struct BandwidthEstimator {
|
||||
loss_detector: LossBasedDetector,
|
||||
/// Last update timestamp.
|
||||
last_update: Option<Instant>,
|
||||
|
||||
// ── Transport-feedback BWE (T2.2) ──
|
||||
/// Congestion-window-derived bandwidth estimate in bits per second.
|
||||
cwnd_bps: AtomicU64,
|
||||
/// Peer REMB (Receiver Estimated Maximum Bitrate) in bits per second.
|
||||
peer_remb_bps: AtomicU64,
|
||||
/// EWMA-smoothed bandwidth estimate in bits per second.
|
||||
smoothed_bps: AtomicU64,
|
||||
/// Last time `smoothed_bps` was updated (UNIX epoch millis).
|
||||
last_smoothed_ms: AtomicU64,
|
||||
}
|
||||
|
||||
/// Multiplicative decrease factor applied on congestion (15% reduction).
|
||||
@@ -179,6 +190,10 @@ impl BandwidthEstimator {
|
||||
delay_detector: DelayBasedDetector::new(),
|
||||
loss_detector: LossBasedDetector::new(),
|
||||
last_update: None,
|
||||
cwnd_bps: AtomicU64::new(0),
|
||||
peer_remb_bps: AtomicU64::new(u64::MAX),
|
||||
smoothed_bps: AtomicU64::new(0),
|
||||
last_smoothed_ms: AtomicU64::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -250,6 +265,64 @@ impl BandwidthEstimator {
|
||||
QualityProfile::CATASTROPHIC
|
||||
}
|
||||
}
|
||||
|
||||
// ── Transport-feedback BWE (T2.2) ──
|
||||
|
||||
/// Update from QUIC path stats.
|
||||
///
|
||||
/// Computes `cwnd_bps = cwnd_bytes * 8 / rtt_s` and feeds it into the
|
||||
/// smoothed estimate.
|
||||
pub fn update_from_path(&self, cwnd_bytes: u64, _bytes_in_flight: u64, rtt_ms: u32) {
|
||||
let rtt_s = rtt_ms.max(1) as f64 / 1000.0;
|
||||
let cwnd_bps = ((cwnd_bytes * 8) as f64 / rtt_s) as u64;
|
||||
self.cwnd_bps.store(cwnd_bps, Relaxed);
|
||||
self.update_smoothed(cwnd_bps);
|
||||
}
|
||||
|
||||
/// Update from a peer's `TransportFeedback` REMB value.
|
||||
pub fn update_from_peer(&self, fb_remb_bps: u32) {
|
||||
let remb = fb_remb_bps as u64;
|
||||
self.peer_remb_bps.store(remb, Relaxed);
|
||||
self.update_smoothed(remb);
|
||||
}
|
||||
|
||||
/// Target sending bitrate in bits per second.
|
||||
///
|
||||
/// Returns 90% of the minimum between the congestion-window estimate
|
||||
/// and the peer REMB estimate.
|
||||
pub fn target_send_bps(&self) -> u64 {
|
||||
let cwnd = self.cwnd_bps.load(Relaxed);
|
||||
let remb = self.peer_remb_bps.load(Relaxed);
|
||||
let m = cwnd.min(remb);
|
||||
(m as f64 * 0.9) as u64
|
||||
}
|
||||
|
||||
/// EWMA-smoothed bandwidth estimate in bits per second.
|
||||
pub fn smoothed_bps(&self) -> u64 {
|
||||
self.smoothed_bps.load(Relaxed)
|
||||
}
|
||||
|
||||
/// Apply EWMA smoothing with a 2-second half-life.
|
||||
fn update_smoothed(&self, new_bps: u64) {
|
||||
let now_ms = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_millis() as u64;
|
||||
let last_ms = self.last_smoothed_ms.load(Relaxed);
|
||||
let dt_ms = now_ms.saturating_sub(last_ms);
|
||||
|
||||
let current = self.smoothed_bps.load(Relaxed);
|
||||
let updated = if current == 0 || dt_ms == 0 {
|
||||
new_bps
|
||||
} else {
|
||||
let alpha = 1.0 - 0.5_f64.powf(dt_ms as f64 / 2000.0);
|
||||
let s = current as f64 * (1.0 - alpha) + new_bps as f64 * alpha;
|
||||
s as u64
|
||||
};
|
||||
|
||||
self.smoothed_bps.store(updated, Relaxed);
|
||||
self.last_smoothed_ms.store(now_ms, Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -396,10 +469,7 @@ mod tests {
|
||||
|
||||
// Below 8 => CATASTROPHIC
|
||||
let bwe_cat = BandwidthEstimator::new(7.9, 2.0, 100.0);
|
||||
assert_eq!(
|
||||
bwe_cat.recommended_profile(),
|
||||
QualityProfile::CATASTROPHIC
|
||||
);
|
||||
assert_eq!(bwe_cat.recommended_profile(), QualityProfile::CATASTROPHIC);
|
||||
|
||||
// High bandwidth
|
||||
let bwe_high = BandwidthEstimator::new(80.0, 2.0, 100.0);
|
||||
@@ -413,7 +483,7 @@ mod tests {
|
||||
// Build a QualityReport with moderate loss and RTT.
|
||||
let report = QualityReport {
|
||||
loss_pct: (10.0_f32 / 100.0 * 255.0) as u8, // ~10% loss
|
||||
rtt_4ms: 25, // 100ms RTT
|
||||
rtt_4ms: 25, // 100ms RTT
|
||||
jitter_ms: 10,
|
||||
bitrate_cap_kbps: 200,
|
||||
};
|
||||
@@ -451,4 +521,46 @@ mod tests {
|
||||
}
|
||||
assert!(det.is_congested());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn target_send_bps_uses_min_of_cwnd_and_remb() {
|
||||
let bwe = BandwidthEstimator::new(50.0, 2.0, 100.0);
|
||||
// cwnd_bps = 100_000, remb = 200_000 → min = 100_000 → 90%
|
||||
bwe.update_from_path(1250, 0, 100); // 1250*8 / 0.1 = 100_000
|
||||
bwe.update_from_peer(200_000);
|
||||
assert_eq!(bwe.target_send_bps(), 90_000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn target_send_bps_with_zero_cwnd_uses_remb() {
|
||||
let bwe = BandwidthEstimator::new(50.0, 2.0, 100.0);
|
||||
// Default cwnd is 0, remb is u64::MAX (default).
|
||||
// 0.min(u64::MAX) = 0 → 90% = 0
|
||||
assert_eq!(bwe.target_send_bps(), 0);
|
||||
|
||||
bwe.update_from_peer(100_000);
|
||||
// cwnd still 0
|
||||
assert_eq!(bwe.target_send_bps(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn smoothed_bps_ewma_converges() {
|
||||
let bwe = BandwidthEstimator::new(50.0, 2.0, 100.0);
|
||||
bwe.update_from_path(1250, 0, 100); // 100_000 bps
|
||||
let s1 = bwe.smoothed_bps();
|
||||
assert_eq!(s1, 100_000);
|
||||
|
||||
// Immediately update with same value — dt ≈ 0, so should stay at 100_000
|
||||
bwe.update_from_path(1250, 0, 100);
|
||||
let s2 = bwe.smoothed_bps();
|
||||
assert_eq!(s2, 100_000);
|
||||
|
||||
// Sleep a bit so dt is non-zero, then update with a much higher value.
|
||||
std::thread::sleep(std::time::Duration::from_millis(100));
|
||||
bwe.update_from_path(12500, 0, 100); // 1_000_000 bps
|
||||
let s3 = bwe.smoothed_bps();
|
||||
assert!(s3 > 100_000, "smoothed should increase toward 1M: {s3}");
|
||||
// With 100ms dt, alpha ≈ 0.03, so smoothed should be ~100k * 0.97 + 1M * 0.03 ≈ 127k
|
||||
assert!(s3 < 500_000, "smoothed should not jump too far: {s3}");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,8 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Identifies the audio codec and bitrate configuration.
|
||||
///
|
||||
/// Encoded as 4 bits in the media packet header.
|
||||
/// Encoded as 4 bits in the v1 media packet header, and as a full 8-bit
|
||||
/// value in the v2 [`MediaHeaderV2`](crate::MediaHeaderV2).
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
#[repr(u8)]
|
||||
pub enum CodecId {
|
||||
@@ -18,6 +19,22 @@ pub enum CodecId {
|
||||
Codec2_1200 = 4,
|
||||
/// Comfort noise descriptor (silence suppression)
|
||||
ComfortNoise = 5,
|
||||
/// Opus at 32kbps (studio low)
|
||||
Opus32k = 6,
|
||||
/// Opus at 48kbps (studio)
|
||||
Opus48k = 7,
|
||||
/// Opus at 64kbps (studio high)
|
||||
Opus64k = 8,
|
||||
/// H.264 baseline profile (video).
|
||||
H264Baseline = 9,
|
||||
// Reserved for video codecs; implementations land in PRD-video-multicodec.
|
||||
// 10 => H264 main
|
||||
// 11 => H265 main
|
||||
// 13 => VP9
|
||||
/// AV1 main profile (video).
|
||||
Av1Main = 12,
|
||||
/// H.265 main profile (video).
|
||||
H265Main = 11,
|
||||
}
|
||||
|
||||
impl CodecId {
|
||||
@@ -27,30 +44,40 @@ impl CodecId {
|
||||
Self::Opus24k => 24_000,
|
||||
Self::Opus16k => 16_000,
|
||||
Self::Opus6k => 6_000,
|
||||
Self::Opus32k => 32_000,
|
||||
Self::Opus48k => 48_000,
|
||||
Self::Opus64k => 64_000,
|
||||
Self::Codec2_3200 => 3_200,
|
||||
Self::Codec2_1200 => 1_200,
|
||||
Self::ComfortNoise => 0,
|
||||
Self::H264Baseline | Self::H265Main | Self::Av1Main => 2_000_000,
|
||||
}
|
||||
}
|
||||
|
||||
/// Preferred frame duration in milliseconds.
|
||||
pub const fn frame_duration_ms(self) -> u8 {
|
||||
match self {
|
||||
Self::Opus24k => 20,
|
||||
Self::Opus16k => 20,
|
||||
Self::Opus24k | Self::Opus16k | Self::Opus32k | Self::Opus48k | Self::Opus64k => 20,
|
||||
Self::Opus6k => 40,
|
||||
Self::Codec2_3200 => 20,
|
||||
Self::Codec2_1200 => 40,
|
||||
Self::ComfortNoise => 20,
|
||||
Self::H264Baseline | Self::H265Main | Self::Av1Main => 33,
|
||||
}
|
||||
}
|
||||
|
||||
/// Sample rate expected by this codec.
|
||||
pub const fn sample_rate_hz(self) -> u32 {
|
||||
match self {
|
||||
Self::Opus24k | Self::Opus16k | Self::Opus6k => 48_000,
|
||||
Self::Opus24k
|
||||
| Self::Opus16k
|
||||
| Self::Opus6k
|
||||
| Self::Opus32k
|
||||
| Self::Opus48k
|
||||
| Self::Opus64k => 48_000,
|
||||
Self::Codec2_3200 | Self::Codec2_1200 => 8_000,
|
||||
Self::ComfortNoise => 48_000,
|
||||
Self::H264Baseline | Self::H265Main | Self::Av1Main => 48_000,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -63,6 +90,12 @@ impl CodecId {
|
||||
3 => Some(Self::Codec2_3200),
|
||||
4 => Some(Self::Codec2_1200),
|
||||
5 => Some(Self::ComfortNoise),
|
||||
6 => Some(Self::Opus32k),
|
||||
7 => Some(Self::Opus48k),
|
||||
8 => Some(Self::Opus64k),
|
||||
9 => Some(Self::H264Baseline),
|
||||
11 => Some(Self::H265Main),
|
||||
12 => Some(Self::Av1Main),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
@@ -71,6 +104,24 @@ impl CodecId {
|
||||
pub const fn to_wire(self) -> u8 {
|
||||
self as u8
|
||||
}
|
||||
|
||||
/// Returns true if this is a video codec variant.
|
||||
pub const fn is_video(self) -> bool {
|
||||
matches!(self, Self::H264Baseline | Self::H265Main | Self::Av1Main)
|
||||
}
|
||||
|
||||
/// Returns true if this is an Opus variant.
|
||||
pub const fn is_opus(self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
Self::Opus6k
|
||||
| Self::Opus16k
|
||||
| Self::Opus24k
|
||||
| Self::Opus32k
|
||||
| Self::Opus48k
|
||||
| Self::Opus64k
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Describes the complete quality configuration for a call session.
|
||||
@@ -84,6 +135,18 @@ pub struct QualityProfile {
|
||||
pub frame_duration_ms: u8,
|
||||
/// Number of source frames per FEC block.
|
||||
pub frames_per_block: u8,
|
||||
/// Bandwidth-allocation priority between audio and video.
|
||||
#[serde(default)]
|
||||
pub priority_mode: crate::PriorityMode,
|
||||
/// Target video bitrate in kbps (set by quality controller, not handshake).
|
||||
#[serde(default)]
|
||||
pub video_bitrate_kbps: Option<u32>,
|
||||
/// Target video resolution as (width, height).
|
||||
#[serde(default)]
|
||||
pub video_resolution: Option<(u16, u16)>,
|
||||
/// Target video frame rate.
|
||||
#[serde(default)]
|
||||
pub video_fps: Option<u8>,
|
||||
}
|
||||
|
||||
impl QualityProfile {
|
||||
@@ -93,6 +156,10 @@ impl QualityProfile {
|
||||
fec_ratio: 0.2,
|
||||
frame_duration_ms: 20,
|
||||
frames_per_block: 5,
|
||||
priority_mode: crate::PriorityMode::AudioFirst,
|
||||
video_bitrate_kbps: None,
|
||||
video_resolution: None,
|
||||
video_fps: None,
|
||||
};
|
||||
|
||||
/// Degraded conditions: Opus 6kbps, moderate FEC.
|
||||
@@ -101,6 +168,10 @@ impl QualityProfile {
|
||||
fec_ratio: 0.5,
|
||||
frame_duration_ms: 40,
|
||||
frames_per_block: 10,
|
||||
priority_mode: crate::PriorityMode::AudioFirst,
|
||||
video_bitrate_kbps: None,
|
||||
video_resolution: None,
|
||||
video_fps: None,
|
||||
};
|
||||
|
||||
/// Catastrophic conditions: Codec2 1.2kbps, heavy FEC.
|
||||
@@ -109,6 +180,46 @@ impl QualityProfile {
|
||||
fec_ratio: 1.0,
|
||||
frame_duration_ms: 40,
|
||||
frames_per_block: 8,
|
||||
priority_mode: crate::PriorityMode::AudioFirst,
|
||||
video_bitrate_kbps: None,
|
||||
video_resolution: None,
|
||||
video_fps: None,
|
||||
};
|
||||
|
||||
/// Studio low: Opus 32kbps, minimal FEC.
|
||||
pub const STUDIO_32K: Self = Self {
|
||||
codec: CodecId::Opus32k,
|
||||
fec_ratio: 0.1,
|
||||
frame_duration_ms: 20,
|
||||
frames_per_block: 5,
|
||||
priority_mode: crate::PriorityMode::AudioFirst,
|
||||
video_bitrate_kbps: None,
|
||||
video_resolution: None,
|
||||
video_fps: None,
|
||||
};
|
||||
|
||||
/// Studio: Opus 48kbps, minimal FEC.
|
||||
pub const STUDIO_48K: Self = Self {
|
||||
codec: CodecId::Opus48k,
|
||||
fec_ratio: 0.1,
|
||||
frame_duration_ms: 20,
|
||||
frames_per_block: 5,
|
||||
priority_mode: crate::PriorityMode::AudioFirst,
|
||||
video_bitrate_kbps: None,
|
||||
video_resolution: None,
|
||||
video_fps: None,
|
||||
};
|
||||
|
||||
/// Studio high: Opus 64kbps, minimal FEC.
|
||||
pub const STUDIO_64K: Self = Self {
|
||||
codec: CodecId::Opus64k,
|
||||
fec_ratio: 0.1,
|
||||
frame_duration_ms: 20,
|
||||
frames_per_block: 5,
|
||||
priority_mode: crate::PriorityMode::AudioFirst,
|
||||
video_bitrate_kbps: None,
|
||||
video_resolution: None,
|
||||
video_fps: None,
|
||||
};
|
||||
|
||||
/// Estimated total bandwidth in kbps including FEC overhead.
|
||||
@@ -117,3 +228,46 @@ impl QualityProfile {
|
||||
base * (1.0 + self.fec_ratio)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{CodecId, QualityProfile};
|
||||
use crate::PriorityMode;
|
||||
|
||||
#[test]
|
||||
fn codec_id_unknown_values_rejected() {
|
||||
for v in [10u8, 13].iter().copied().chain(14u8..=255) {
|
||||
assert!(CodecId::from_wire(v).is_none(), "v={v}");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn h265_main_roundtrips() {
|
||||
assert_eq!(CodecId::H265Main.to_wire(), 11);
|
||||
assert_eq!(CodecId::from_wire(11), Some(CodecId::H265Main));
|
||||
assert!(CodecId::H265Main.is_video());
|
||||
assert_eq!(CodecId::H265Main.bitrate_bps(), 2_000_000);
|
||||
assert_eq!(CodecId::H265Main.frame_duration_ms(), 33);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn av1_main_roundtrips() {
|
||||
assert_eq!(CodecId::Av1Main.to_wire(), 12);
|
||||
assert_eq!(CodecId::from_wire(12), Some(CodecId::Av1Main));
|
||||
assert!(CodecId::Av1Main.is_video());
|
||||
assert_eq!(CodecId::Av1Main.bitrate_bps(), 2_000_000);
|
||||
assert_eq!(CodecId::Av1Main.frame_duration_ms(), 33);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quality_profile_backward_compat_old_json() {
|
||||
// Old JSON emitted before T5.1 has no priority_mode or video fields.
|
||||
let old_json =
|
||||
r#"{"codec":"Opus24k","fec_ratio":0.2,"frame_duration_ms":20,"frames_per_block":5}"#;
|
||||
let parsed: QualityProfile = serde_json::from_str(old_json).unwrap();
|
||||
assert_eq!(parsed.priority_mode, PriorityMode::AudioFirst);
|
||||
assert_eq!(parsed.video_bitrate_kbps, None);
|
||||
assert_eq!(parsed.video_resolution, None);
|
||||
assert_eq!(parsed.video_fps, None);
|
||||
}
|
||||
}
|
||||
|
||||
320
crates/wzp-proto/src/dred_tuner.rs
Normal file
320
crates/wzp-proto/src/dred_tuner.rs
Normal file
@@ -0,0 +1,320 @@
|
||||
//! Continuous DRED tuning from real-time network metrics.
|
||||
//!
|
||||
//! Instead of locking DRED duration to 3 discrete quality tiers (100/200/500 ms),
|
||||
//! `DredTuner` maps live path quality metrics to a continuous DRED duration and
|
||||
//! expected-loss hint, updated every N packets. This makes DRED reactive within
|
||||
//! ~200 ms instead of waiting for 3+ consecutive bad quality reports to trigger
|
||||
//! a full tier transition.
|
||||
//!
|
||||
//! The tuner also implements pre-emptive jitter-spike detection ("sawtooth"
|
||||
//! prediction): when jitter variance spikes >30% over a 200 ms window — typical
|
||||
//! of Starlink satellite handovers — it temporarily boosts DRED to the maximum
|
||||
//! allowed for the current codec before packets actually start dropping.
|
||||
//!
|
||||
//! See also: [`crate::quality`] for discrete tier classification that drives
|
||||
//! codec switching. DredTuner operates within a tier, adjusting DRED
|
||||
//! parameters continuously based on live network metrics.
|
||||
|
||||
use crate::CodecId;
|
||||
|
||||
/// Output of a single tuning cycle.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub struct DredTuning {
|
||||
/// DRED duration in 10 ms frame units (0–104). Passed directly to
|
||||
/// `OpusEncoder::set_dred_duration()`.
|
||||
pub dred_frames: u8,
|
||||
/// Expected packet loss percentage (0–100). Passed to
|
||||
/// `OpusEncoder::set_expected_loss()`. Floored at 15% by the encoder
|
||||
/// itself, but we pass the real value so the encoder can override upward.
|
||||
pub expected_loss_pct: u8,
|
||||
}
|
||||
|
||||
/// Minimum DRED frames for any Opus codec (matches DRED_LOSS_FLOOR_PCT logic:
|
||||
/// at 15% loss, libopus 1.5 emits ~95 ms of DRED, which needs at least 10
|
||||
/// frames configured to be useful).
|
||||
const MIN_DRED_FRAMES: u8 = 5;
|
||||
|
||||
/// Maximum DRED frames libopus supports (104 × 10 ms = 1040 ms).
|
||||
const MAX_DRED_FRAMES: u8 = 104;
|
||||
|
||||
/// Jitter variance spike ratio that triggers pre-emptive DRED boost.
|
||||
const JITTER_SPIKE_RATIO: f32 = 1.3;
|
||||
|
||||
/// How many tuning cycles a jitter-spike boost persists (at 25 packets/cycle
|
||||
/// and 20 ms/packet, 10 cycles ≈ 5 seconds).
|
||||
const SPIKE_BOOST_COOLDOWN_CYCLES: u32 = 10;
|
||||
|
||||
/// Maps codec tier to its baseline DRED frames (used when network is healthy).
|
||||
fn baseline_dred_frames(codec: CodecId) -> u8 {
|
||||
match codec {
|
||||
CodecId::Opus32k | CodecId::Opus48k | CodecId::Opus64k => 10, // 100 ms
|
||||
CodecId::Opus16k | CodecId::Opus24k => 20, // 200 ms
|
||||
CodecId::Opus6k => 50, // 500 ms
|
||||
_ => 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Maps codec tier to its maximum allowed DRED frames under spike/bad conditions.
|
||||
fn max_dred_frames_for(codec: CodecId) -> u8 {
|
||||
match codec {
|
||||
// Studio: cap at 300 ms (don't waste bitrate on good links)
|
||||
CodecId::Opus32k | CodecId::Opus48k | CodecId::Opus64k => 30,
|
||||
// Normal: cap at 500 ms
|
||||
CodecId::Opus16k | CodecId::Opus24k => 50,
|
||||
// Degraded: allow full 1040 ms
|
||||
CodecId::Opus6k => MAX_DRED_FRAMES,
|
||||
_ => 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Continuous DRED tuner driven by network path metrics.
|
||||
pub struct DredTuner {
|
||||
/// Current codec (determines baseline and ceiling).
|
||||
codec: CodecId,
|
||||
/// Last computed tuning output.
|
||||
last_tuning: DredTuning,
|
||||
/// EWMA-smoothed jitter for spike detection (in ms).
|
||||
jitter_ewma: f32,
|
||||
/// Remaining cooldown cycles for a jitter-spike boost.
|
||||
spike_cooldown: u32,
|
||||
/// Whether the tuner has received at least one observation.
|
||||
initialized: bool,
|
||||
}
|
||||
|
||||
impl DredTuner {
|
||||
/// Create a new tuner for the given codec.
|
||||
pub fn new(codec: CodecId) -> Self {
|
||||
let baseline = baseline_dred_frames(codec);
|
||||
Self {
|
||||
codec,
|
||||
last_tuning: DredTuning {
|
||||
dred_frames: baseline,
|
||||
expected_loss_pct: 15, // match DRED_LOSS_FLOOR_PCT
|
||||
},
|
||||
jitter_ewma: 0.0,
|
||||
spike_cooldown: 0,
|
||||
initialized: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Update the active codec (e.g. on tier transition). Resets spike state.
|
||||
pub fn set_codec(&mut self, codec: CodecId) {
|
||||
self.codec = codec;
|
||||
self.spike_cooldown = 0;
|
||||
}
|
||||
|
||||
/// Feed network metrics and compute new DRED parameters.
|
||||
///
|
||||
/// Call this every tuning cycle (e.g. every 25 packets ≈ 500 ms at 20 ms
|
||||
/// frame duration).
|
||||
///
|
||||
/// - `loss_pct`: observed packet loss (0.0–100.0)
|
||||
/// - `rtt_ms`: smoothed round-trip time
|
||||
/// - `jitter_ms`: current jitter estimate (RTT variance)
|
||||
///
|
||||
/// Returns `Some(tuning)` if the output changed, `None` if unchanged.
|
||||
pub fn update(&mut self, loss_pct: f32, rtt_ms: u32, jitter_ms: u32) -> Option<DredTuning> {
|
||||
if !self.codec.is_opus() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let baseline = baseline_dred_frames(self.codec);
|
||||
let ceiling = max_dred_frames_for(self.codec);
|
||||
|
||||
// --- Jitter spike detection ---
|
||||
let jitter_f = jitter_ms as f32;
|
||||
if !self.initialized {
|
||||
self.jitter_ewma = jitter_f;
|
||||
self.initialized = true;
|
||||
} else {
|
||||
// Fast-up (alpha=0.3), slow-down (alpha=0.05) asymmetric EWMA
|
||||
let alpha = if jitter_f > self.jitter_ewma {
|
||||
0.3
|
||||
} else {
|
||||
0.05
|
||||
};
|
||||
self.jitter_ewma = alpha * jitter_f + (1.0 - alpha) * self.jitter_ewma;
|
||||
}
|
||||
|
||||
// Detect spike: instantaneous jitter > EWMA × 1.3
|
||||
if self.jitter_ewma > 1.0 && jitter_f > self.jitter_ewma * JITTER_SPIKE_RATIO {
|
||||
self.spike_cooldown = SPIKE_BOOST_COOLDOWN_CYCLES;
|
||||
}
|
||||
|
||||
// Decrement cooldown
|
||||
if self.spike_cooldown > 0 {
|
||||
self.spike_cooldown -= 1;
|
||||
}
|
||||
|
||||
// --- Compute DRED frames ---
|
||||
let dred_frames = if self.spike_cooldown > 0 {
|
||||
// During spike boost: jump to ceiling
|
||||
ceiling
|
||||
} else {
|
||||
// Continuous mapping: scale linearly between baseline and ceiling
|
||||
// based on loss percentage.
|
||||
// 0% loss → baseline
|
||||
// 40% loss → ceiling
|
||||
let loss_clamped = loss_pct.clamp(0.0, 40.0);
|
||||
let t = loss_clamped / 40.0;
|
||||
let raw = baseline as f32 + t * (ceiling - baseline) as f32;
|
||||
(raw as u8).clamp(MIN_DRED_FRAMES, ceiling)
|
||||
};
|
||||
|
||||
// --- Compute expected loss hint ---
|
||||
// Pass the real loss so the encoder can clamp at its own floor (15%).
|
||||
// For RTT-driven boost: high RTT suggests impending loss, so add a
|
||||
// phantom loss contribution to keep DRED emitting generously.
|
||||
let rtt_loss_phantom = if rtt_ms > 200 {
|
||||
((rtt_ms - 200) as f32 / 40.0).min(15.0)
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
let expected_loss = (loss_pct + rtt_loss_phantom).clamp(0.0, 100.0) as u8;
|
||||
|
||||
let tuning = DredTuning {
|
||||
dred_frames,
|
||||
expected_loss_pct: expected_loss,
|
||||
};
|
||||
|
||||
if tuning != self.last_tuning {
|
||||
self.last_tuning = tuning;
|
||||
Some(tuning)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the last computed tuning without updating.
|
||||
pub fn current(&self) -> DredTuning {
|
||||
self.last_tuning
|
||||
}
|
||||
|
||||
/// Whether a jitter-spike boost is currently active.
|
||||
pub fn spike_boost_active(&self) -> bool {
|
||||
self.spike_cooldown > 0
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn baseline_for_opus24k() {
|
||||
let tuner = DredTuner::new(CodecId::Opus24k);
|
||||
assert_eq!(tuner.current().dred_frames, 20); // 200 ms
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn baseline_for_opus6k() {
|
||||
let tuner = DredTuner::new(CodecId::Opus6k);
|
||||
assert_eq!(tuner.current().dred_frames, 50); // 500 ms
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn codec2_returns_none() {
|
||||
let mut tuner = DredTuner::new(CodecId::Codec2_1200);
|
||||
assert!(tuner.update(10.0, 100, 20).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn scales_with_loss() {
|
||||
let mut tuner = DredTuner::new(CodecId::Opus24k);
|
||||
|
||||
// 0% loss → baseline (20 frames)
|
||||
tuner.update(0.0, 50, 5);
|
||||
assert_eq!(tuner.current().dred_frames, 20);
|
||||
|
||||
// 20% loss → midpoint between 20 and 50 = 35
|
||||
tuner.update(20.0, 50, 5);
|
||||
assert_eq!(tuner.current().dred_frames, 35);
|
||||
|
||||
// 40%+ loss → ceiling (50 frames)
|
||||
tuner.update(40.0, 50, 5);
|
||||
assert_eq!(tuner.current().dred_frames, 50);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn jitter_spike_triggers_boost() {
|
||||
let mut tuner = DredTuner::new(CodecId::Opus24k);
|
||||
|
||||
// Establish baseline jitter
|
||||
for _ in 0..20 {
|
||||
tuner.update(0.0, 50, 10);
|
||||
}
|
||||
assert!(!tuner.spike_boost_active());
|
||||
|
||||
// Spike: jitter jumps to 50 ms (5x the EWMA of ~10)
|
||||
tuner.update(0.0, 50, 50);
|
||||
assert!(tuner.spike_boost_active());
|
||||
// Should be at ceiling (50 frames = 500 ms for Opus24k)
|
||||
assert_eq!(tuner.current().dred_frames, 50);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn spike_cooldown_decays() {
|
||||
let mut tuner = DredTuner::new(CodecId::Opus24k);
|
||||
|
||||
// Establish baseline then spike
|
||||
for _ in 0..20 {
|
||||
tuner.update(0.0, 50, 10);
|
||||
}
|
||||
tuner.update(0.0, 50, 50);
|
||||
assert!(tuner.spike_boost_active());
|
||||
|
||||
// Run through cooldown
|
||||
for _ in 0..SPIKE_BOOST_COOLDOWN_CYCLES {
|
||||
tuner.update(0.0, 50, 10);
|
||||
}
|
||||
assert!(!tuner.spike_boost_active());
|
||||
// Should return to baseline
|
||||
assert_eq!(tuner.current().dred_frames, 20);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rtt_phantom_loss() {
|
||||
let mut tuner = DredTuner::new(CodecId::Opus24k);
|
||||
|
||||
// High RTT (400ms) with 0% real loss
|
||||
tuner.update(0.0, 400, 10);
|
||||
// Phantom loss = (400-200)/40 = 5
|
||||
assert_eq!(tuner.current().expected_loss_pct, 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn set_codec_resets_spike() {
|
||||
let mut tuner = DredTuner::new(CodecId::Opus24k);
|
||||
|
||||
// Trigger spike
|
||||
for _ in 0..20 {
|
||||
tuner.update(0.0, 50, 10);
|
||||
}
|
||||
tuner.update(0.0, 50, 50);
|
||||
assert!(tuner.spike_boost_active());
|
||||
|
||||
// Switch codec — spike should reset
|
||||
tuner.set_codec(CodecId::Opus6k);
|
||||
assert!(!tuner.spike_boost_active());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn opus6k_reaches_max_1040ms() {
|
||||
let mut tuner = DredTuner::new(CodecId::Opus6k);
|
||||
|
||||
// High loss → should reach 104 frames (1040 ms)
|
||||
tuner.update(40.0, 50, 5);
|
||||
assert_eq!(tuner.current().dred_frames, MAX_DRED_FRAMES);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn returns_none_when_unchanged() {
|
||||
let mut tuner = DredTuner::new(CodecId::Opus24k);
|
||||
|
||||
// First update always returns Some (initial → computed)
|
||||
let first = tuner.update(0.0, 50, 5);
|
||||
// Same inputs → None
|
||||
let second = tuner.update(0.0, 50, 5);
|
||||
assert!(first.is_some() || second.is_none());
|
||||
}
|
||||
}
|
||||
@@ -37,7 +37,7 @@ pub enum CryptoError {
|
||||
#[error("rekey failed: {0}")]
|
||||
RekeyFailed(String),
|
||||
#[error("anti-replay: duplicate or old packet (seq={seq})")]
|
||||
ReplayDetected { seq: u16 },
|
||||
ReplayDetected { seq: u32 },
|
||||
#[error("internal crypto error: {0}")]
|
||||
Internal(String),
|
||||
}
|
||||
@@ -53,6 +53,15 @@ pub enum TransportError {
|
||||
Timeout { ms: u64 },
|
||||
#[error("io error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
/// Parsed wire bytes successfully but the payload didn't
|
||||
/// deserialize into a known `SignalMessage` variant. Usually
|
||||
/// means the peer is running a newer build with a variant we
|
||||
/// don't know yet. Callers should **log and continue** rather
|
||||
/// than tearing down the connection, so that forward-compat
|
||||
/// additions to `SignalMessage` don't silently kill old
|
||||
/// clients/relays.
|
||||
#[error("signal deserialize: {0}")]
|
||||
Deserialize(String),
|
||||
#[error("internal transport error: {0}")]
|
||||
Internal(String),
|
||||
}
|
||||
|
||||
@@ -81,9 +81,7 @@ impl AdaptivePlayoutDelay {
|
||||
let jitter = (actual_delta - expected_delta).abs();
|
||||
|
||||
// Spike detection: check before EMA update
|
||||
if self.jitter_ema > 0.0
|
||||
&& jitter > self.jitter_ema * self.spike_threshold_multiplier
|
||||
{
|
||||
if self.jitter_ema > 0.0 && jitter > self.jitter_ema * self.spike_threshold_multiplier {
|
||||
self.spike_detected_at = Some(Instant::now());
|
||||
}
|
||||
|
||||
@@ -107,10 +105,8 @@ impl AdaptivePlayoutDelay {
|
||||
self.target_delay = self.max_delay;
|
||||
} else {
|
||||
// Convert jitter estimate to target delay in packets
|
||||
let raw_target =
|
||||
(self.jitter_ema / FRAME_DURATION_MS).ceil() + self.safety_margin;
|
||||
self.target_delay =
|
||||
(raw_target as usize).clamp(self.min_delay, self.max_delay);
|
||||
let raw_target = (self.jitter_ema / FRAME_DURATION_MS).ceil() + self.safety_margin;
|
||||
self.target_delay = (raw_target as usize).clamp(self.min_delay, self.max_delay);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -162,9 +158,9 @@ impl AdaptivePlayoutDelay {
|
||||
/// Manages packet reordering, gap detection, and signals when PLC is needed.
|
||||
pub struct JitterBuffer {
|
||||
/// Packets waiting to be consumed, ordered by sequence number.
|
||||
buffer: BTreeMap<u16, MediaPacket>,
|
||||
buffer: BTreeMap<u32, MediaPacket>,
|
||||
/// Next sequence number expected for playout.
|
||||
next_playout_seq: u16,
|
||||
next_playout_seq: u32,
|
||||
/// Maximum buffer depth in number of packets.
|
||||
max_depth: usize,
|
||||
/// Target buffer depth (adaptive, based on jitter).
|
||||
@@ -204,7 +200,7 @@ pub enum PlayoutResult {
|
||||
/// A packet is available for playout.
|
||||
Packet(MediaPacket),
|
||||
/// The expected packet is missing — decoder should generate PLC.
|
||||
Missing { seq: u16 },
|
||||
Missing { seq: u32 },
|
||||
/// Buffer is empty or not yet filled to target depth.
|
||||
NotReady,
|
||||
}
|
||||
@@ -273,10 +269,30 @@ impl JitterBuffer {
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if packet is too old (already played out)
|
||||
// Check if packet is too old (already played out).
|
||||
// A backward jump of >100 seq (~2s at 50fps) indicates a new sender in a
|
||||
// federation room — reset instead of dropping.
|
||||
if self.stats.packets_played > 0 && seq_before(seq, self.next_playout_seq) {
|
||||
self.stats.packets_late += 1;
|
||||
return;
|
||||
let backward_distance = self.next_playout_seq.wrapping_sub(seq);
|
||||
tracing::warn!(
|
||||
seq,
|
||||
next = self.next_playout_seq,
|
||||
backward_distance,
|
||||
"jitter: backward seq detected"
|
||||
);
|
||||
if backward_distance > 100 {
|
||||
tracing::info!(
|
||||
seq,
|
||||
next = self.next_playout_seq,
|
||||
"jitter: RESET — new sender detected"
|
||||
);
|
||||
self.buffer.clear();
|
||||
self.next_playout_seq = seq;
|
||||
self.stats.packets_late = 0;
|
||||
} else {
|
||||
self.stats.packets_late += 1;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// If we haven't started playout yet, adjust next_playout_seq to earliest known
|
||||
@@ -412,10 +428,30 @@ impl JitterBuffer {
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if packet is too old (already played out)
|
||||
// Check if packet is too old (already played out).
|
||||
// A backward jump of >100 seq (~2s at 50fps) indicates a new sender in a
|
||||
// federation room — reset instead of dropping.
|
||||
if self.stats.packets_played > 0 && seq_before(seq, self.next_playout_seq) {
|
||||
self.stats.packets_late += 1;
|
||||
return;
|
||||
let backward_distance = self.next_playout_seq.wrapping_sub(seq);
|
||||
tracing::warn!(
|
||||
seq,
|
||||
next = self.next_playout_seq,
|
||||
backward_distance,
|
||||
"jitter: backward seq detected"
|
||||
);
|
||||
if backward_distance > 100 {
|
||||
tracing::info!(
|
||||
seq,
|
||||
next = self.next_playout_seq,
|
||||
"jitter: RESET — new sender detected"
|
||||
);
|
||||
self.buffer.clear();
|
||||
self.next_playout_seq = seq;
|
||||
self.stats.packets_late = 0;
|
||||
} else {
|
||||
self.stats.packets_late += 1;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// If we haven't started playout yet, adjust next_playout_seq to earliest known
|
||||
@@ -467,7 +503,7 @@ impl JitterBuffer {
|
||||
|
||||
/// Sequence number comparison with wrapping (RFC 1982 serial number arithmetic).
|
||||
/// Returns true if `a` comes before `b` in sequence space.
|
||||
fn seq_before(a: u16, b: u16) -> bool {
|
||||
fn seq_before(a: u32, b: u32) -> bool {
|
||||
let diff = b.wrapping_sub(a);
|
||||
diff > 0 && diff < 0x8000
|
||||
}
|
||||
@@ -475,24 +511,23 @@ fn seq_before(a: u16, b: u16) -> bool {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::CodecId;
|
||||
use crate::MediaType;
|
||||
use crate::packet::{MediaHeader, MediaPacket};
|
||||
use bytes::Bytes;
|
||||
use crate::CodecId;
|
||||
|
||||
fn make_packet(seq: u16) -> MediaPacket {
|
||||
fn make_packet(seq: u32) -> MediaPacket {
|
||||
MediaPacket {
|
||||
header: MediaHeader {
|
||||
version: 0,
|
||||
is_repair: false,
|
||||
version: 2,
|
||||
flags: 0,
|
||||
media_type: MediaType::Audio,
|
||||
codec_id: CodecId::Opus24k,
|
||||
has_quality_report: false,
|
||||
fec_ratio_encoded: 0,
|
||||
stream_id: 0,
|
||||
fec_ratio: 0,
|
||||
seq,
|
||||
timestamp: seq as u32 * 20,
|
||||
timestamp: seq * 20,
|
||||
fec_block: 0,
|
||||
fec_symbol: 0,
|
||||
reserved: 0,
|
||||
csrc_count: 0,
|
||||
},
|
||||
payload: Bytes::from(vec![0u8; 60]),
|
||||
quality_report: None,
|
||||
@@ -576,7 +611,7 @@ mod tests {
|
||||
fn seq_before_wrapping() {
|
||||
assert!(seq_before(0, 1));
|
||||
assert!(seq_before(65534, 65535));
|
||||
assert!(seq_before(65535, 0)); // wrap
|
||||
assert!(seq_before(u32::MAX, 0)); // wrap
|
||||
assert!(!seq_before(1, 0));
|
||||
assert!(!seq_before(5, 5)); // equal
|
||||
}
|
||||
@@ -778,7 +813,7 @@ mod tests {
|
||||
let mut jb = JitterBuffer::new_adaptive(3, 50);
|
||||
|
||||
// Push packets with consistent timing
|
||||
for i in 0u16..20 {
|
||||
for i in 0u32..20 {
|
||||
let pkt = make_packet(i);
|
||||
let arrival_ms = i as u64 * 20;
|
||||
jb.push_with_arrival(pkt, arrival_ms);
|
||||
|
||||
@@ -14,21 +14,28 @@
|
||||
|
||||
pub mod bandwidth;
|
||||
pub mod codec_id;
|
||||
pub mod dred_tuner;
|
||||
pub mod error;
|
||||
pub mod jitter;
|
||||
pub mod media_type;
|
||||
pub mod packet;
|
||||
pub mod priority_mode;
|
||||
pub mod quality;
|
||||
pub mod session;
|
||||
pub mod traits;
|
||||
|
||||
// Re-export key types at crate root for convenience.
|
||||
pub use codec_id::{CodecId, QualityProfile};
|
||||
pub use error::*;
|
||||
pub use packet::{
|
||||
HangupReason, MediaHeader, MediaPacket, MiniFrameContext, MiniHeader, QualityReport,
|
||||
RoomParticipant, SignalMessage, TrunkEntry, TrunkFrame, FRAME_TYPE_FULL, FRAME_TYPE_MINI,
|
||||
};
|
||||
pub use bandwidth::{BandwidthEstimator, CongestionState};
|
||||
pub use codec_id::{CodecId, QualityProfile};
|
||||
pub use dred_tuner::{DredTuner, DredTuning};
|
||||
pub use error::*;
|
||||
pub use media_type::MediaType;
|
||||
pub use packet::{
|
||||
CallAcceptMode, FRAME_TYPE_FULL, FRAME_TYPE_MINI, HangupReason, MediaHeader, MediaHeaderV2,
|
||||
MediaPacket, MiniFrameContext, MiniFrameContextV2, MiniHeader, MiniHeaderV2, PresenceUser,
|
||||
QualityReport, RoomParticipant, SignalMessage, TrunkEntry, TrunkFrame, default_signal_version,
|
||||
};
|
||||
pub use priority_mode::PriorityMode;
|
||||
pub use quality::{AdaptiveQualityController, NetworkContext, Tier};
|
||||
pub use session::{Session, SessionEvent, SessionState};
|
||||
pub use traits::*;
|
||||
|
||||
57
crates/wzp-proto/src/media_type.rs
Normal file
57
crates/wzp-proto/src/media_type.rs
Normal file
@@ -0,0 +1,57 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Media stream type carried in a v2 [`MediaHeaderV2`](crate::MediaHeaderV2).
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
#[repr(u8)]
|
||||
pub enum MediaType {
|
||||
/// Encoded speech / music (Opus, Codec2, ComfortNoise).
|
||||
Audio = 0,
|
||||
/// Encoded video access unit (H.264, H.265, AV1; PRD-video-multicodec).
|
||||
Video = 1,
|
||||
/// Opaque payload not interpreted by the relay (reserved).
|
||||
Data = 2,
|
||||
/// In-band control message carried on the media plane (reserved).
|
||||
Control = 3,
|
||||
}
|
||||
|
||||
impl MediaType {
|
||||
/// Encode to the wire byte representation (`self as u8`).
|
||||
pub const fn to_wire(self) -> u8 {
|
||||
self as u8
|
||||
}
|
||||
|
||||
/// Decode from a wire byte. Returns `None` for values outside 0..=3.
|
||||
pub const fn from_wire(v: u8) -> Option<Self> {
|
||||
match v {
|
||||
0 => Some(Self::Audio),
|
||||
1 => Some(Self::Video),
|
||||
2 => Some(Self::Data),
|
||||
3 => Some(Self::Control),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn media_type_roundtrip() {
|
||||
for mt in [
|
||||
MediaType::Audio,
|
||||
MediaType::Video,
|
||||
MediaType::Data,
|
||||
MediaType::Control,
|
||||
] {
|
||||
assert_eq!(MediaType::from_wire(mt.to_wire()), Some(mt));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn media_type_unknown_rejected() {
|
||||
for v in 4u8..=255 {
|
||||
assert!(MediaType::from_wire(v).is_none(), "v={v}");
|
||||
}
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user