T1.6: Protocol version negotiation in handshake

This commit is contained in:
Siavash Sameni
2026-05-11 15:52:18 +04:00
parent 5cdb50160a
commit 6f81487778
18 changed files with 499 additions and 84 deletions

View File

@@ -22,7 +22,8 @@ use wzp_crypto::{KeyExchange, WarzoneKeyExchange};
use wzp_fec::{RaptorQFecDecoder, RaptorQFecEncoder};
use wzp_proto::{
AdaptiveQualityController, AudioDecoder, AudioEncoder, CodecId, FecDecoder, FecEncoder,
MediaHeader, MediaPacket, MediaTransport, QualityController, QualityProfile, SignalMessage,
MediaHeader, MediaPacket, MediaTransport, MediaType, QualityController, QualityProfile,
SignalMessage,
};
use crate::audio_ring::AudioRing;
@@ -533,6 +534,8 @@ async fn run_call(
QualityProfile::CATASTROPHIC,
],
alias: alias.map(|s| s.to_string()),
protocol_version: 2,
supported_versions: vec![2],
};
transport.send_signal(&offer).await?;
info!("CallOffer sent, waiting for CallAnswer...");
@@ -603,7 +606,7 @@ async fn run_call(
stats.auto_mode = auto_profile;
}
let seq = AtomicU16::new(0);
let seq = AtomicU32::new(0);
let ts = AtomicU32::new(0);
let transport_recv = transport.clone();
@@ -729,17 +732,15 @@ async fn run_call(
let source_pkt = MediaPacket {
header: MediaHeader {
version: 0,
is_repair: false,
version: MediaHeader::VERSION,
flags: 0,
media_type: MediaType::Audio,
codec_id: current_profile.codec,
has_quality_report: false,
fec_ratio_encoded: hdr_fec_ratio,
stream_id: 0,
fec_ratio: hdr_fec_ratio,
seq: s,
timestamp: t,
fec_block: hdr_fec_block,
fec_symbol: hdr_fec_symbol,
reserved: 0,
csrc_count: 0,
fec_block: ((hdr_fec_symbol as u16) << 8) | (hdr_fec_block as u16),
},
payload: Bytes::copy_from_slice(encoded),
quality_report: None,
@@ -783,19 +784,17 @@ async fn run_call(
let rs = seq.fetch_add(1, Ordering::Relaxed);
let repair_pkt = MediaPacket {
header: MediaHeader {
version: 0,
is_repair: true,
version: MediaHeader::VERSION,
flags: MediaHeader::FLAG_REPAIR,
media_type: MediaType::Audio,
codec_id: current_profile.codec,
has_quality_report: false,
fec_ratio_encoded: MediaHeader::encode_fec_ratio(
stream_id: 0,
fec_ratio: MediaHeader::encode_fec_ratio(
current_profile.fec_ratio,
),
seq: rs,
timestamp: t,
fec_block: block_id,
fec_symbol: sym_idx,
reserved: 0,
csrc_count: 0,
fec_block: ((sym_idx as u16) << 8) | (block_id as u16),
},
payload: Bytes::from(repair_data),
quality_report: None,
@@ -883,8 +882,8 @@ async fn run_call(
let mut dred_decoder = DredDecoderHandle::new().expect("opus_dred_decoder_create failed");
let mut dred_parse_scratch = DredState::new().expect("opus_dred_alloc failed (scratch)");
let mut last_good_dred = DredState::new().expect("opus_dred_alloc failed (good state)");
let mut last_good_dred_seq: Option<u16> = None;
let mut expected_seq: Option<u16> = None;
let mut last_good_dred_seq: Option<u32> = None;
let mut expected_seq: Option<u32> = None;
let mut dred_reconstructions: u64 = 0;
let mut classical_plc_invocations: u64 = 0;
@@ -905,7 +904,7 @@ async fn run_call(
warn!(
recv_gap_ms,
seq = pkt.header.seq,
is_repair = pkt.header.is_repair,
is_repair = pkt.header.is_repair(),
"large recv gap — possible network stall"
);
}
@@ -946,9 +945,9 @@ async fn run_call(
}
}
let is_repair = pkt.header.is_repair;
let pkt_block = pkt.header.fec_block;
let pkt_symbol = pkt.header.fec_symbol;
let is_repair = pkt.header.is_repair();
let pkt_block = pkt.header.fec_block as u8;
let pkt_symbol = (pkt.header.fec_block >> 8) as u8;
let pkt_is_opus = pkt.header.codec_id.is_opus();
// Phase 2: Opus packets bypass RaptorQ entirely — DRED
@@ -1024,7 +1023,7 @@ async fn run_call(
}
// Detect and fill gap from last-expected to this packet.
const MAX_GAP_FRAMES: u16 = 16;
const MAX_GAP_FRAMES: u32 = 16;
if let Some(expected) = expected_seq {
let gap = pkt.header.seq.wrapping_sub(expected);
if gap > 0 && gap <= MAX_GAP_FRAMES {

View File

@@ -134,11 +134,11 @@ impl Pipeline {
pub fn feed_packet(&mut self, packet: MediaPacket) {
// Feed FEC symbols if present
let header = &packet.header;
if header.fec_block != 0 || header.fec_symbol != 0 {
let is_repair = header.is_repair;
if header.fec_block != 0 {
let is_repair = header.is_repair();
if let Err(e) = self.fec_decoder.add_symbol(
header.fec_block,
header.fec_symbol,
header.fec_block as u8,
(header.fec_block >> 8) as u8,
is_repair,
&packet.payload,
) {

View File

@@ -724,7 +724,7 @@ async fn run_live(transport: Arc<wzp_transport::QuinnTransport>) -> anyhow::Resu
loop {
match recv_transport.recv_media().await {
Ok(Some(pkt)) => {
let is_repair = pkt.header.is_repair;
let is_repair = pkt.header.is_repair();
decoder.ingest(pkt);
// Only decode for source packets (1 source = 1 audio frame).
// Repair packets feed the FEC decoder but don't produce audio.

View File

@@ -156,6 +156,8 @@ mod tests {
signature: vec![3u8; 64],
supported_profiles: vec![QualityProfile::GOOD],
alias: None,
protocol_version: 2,
supported_versions: vec![2],
};
let encoded = encode_call_payload(&signal, Some("relay.example.com:4433"), Some("myroom"));
@@ -174,6 +176,8 @@ mod tests {
signature: vec![],
supported_profiles: vec![],
alias: None,
protocol_version: 2,
supported_versions: vec![2],
};
assert!(matches!(signal_to_call_type(&offer), CallSignalType::Offer));

View File

@@ -4,7 +4,51 @@
//! send `CallOffer` → recv `CallAnswer` → derive shared `CryptoSession`.
use wzp_crypto::{CryptoSession, KeyExchange, WarzoneKeyExchange};
use wzp_proto::{MediaTransport, QualityProfile, SignalMessage};
use wzp_proto::{HangupReason, MediaTransport, QualityProfile, SignalMessage};
/// Errors that can occur during the client-side cryptographic handshake.
#[derive(Debug)]
pub enum HandshakeError {
ConnectionClosed,
ProtocolVersionMismatch { server_supported: Vec<u8> },
UnexpectedSignal(&'static str),
SignatureVerificationFailed,
KeyDerivation(String),
Transport(wzp_proto::TransportError),
}
impl std::fmt::Display for HandshakeError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::ConnectionClosed => write!(f, "connection closed before receiving CallAnswer"),
Self::ProtocolVersionMismatch { server_supported } => {
write!(
f,
"protocol version mismatch: server supports {server_supported:?}"
)
}
Self::UnexpectedSignal(expected) => write!(f, "expected CallAnswer, got {expected}"),
Self::SignatureVerificationFailed => write!(f, "callee signature verification failed"),
Self::KeyDerivation(msg) => write!(f, "key derivation failed: {msg}"),
Self::Transport(e) => write!(f, "transport error: {e}"),
}
}
}
impl std::error::Error for HandshakeError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::Transport(e) => Some(e),
_ => None,
}
}
}
impl From<wzp_proto::TransportError> for HandshakeError {
fn from(e: wzp_proto::TransportError) -> Self {
Self::Transport(e)
}
}
/// Perform the client (caller) side of the cryptographic handshake.
///
@@ -18,7 +62,7 @@ pub async fn perform_handshake(
transport: &dyn MediaTransport,
seed: &[u8; 32],
alias: Option<&str>,
) -> Result<Box<dyn CryptoSession>, anyhow::Error> {
) -> Result<Box<dyn CryptoSession>, HandshakeError> {
// 1. Create key exchange from identity seed
let mut kx = WarzoneKeyExchange::from_identity_seed(seed);
let identity_pub = kx.identity_public_key();
@@ -46,14 +90,20 @@ pub async fn perform_handshake(
QualityProfile::CATASTROPHIC,
],
alias: alias.map(|s| s.to_string()),
protocol_version: 2,
supported_versions: vec![2],
};
transport.send_signal(&offer).await?;
transport
.send_signal(&offer)
.await
.map_err(HandshakeError::Transport)?;
// 5. Wait for CallAnswer
let answer = transport
.recv_signal()
.await?
.ok_or_else(|| anyhow::anyhow!("connection closed before receiving CallAnswer"))?;
.await
.map_err(HandshakeError::Transport)?
.ok_or(HandshakeError::ConnectionClosed)?;
let (callee_identity_pub, callee_ephemeral_pub, callee_signature, _chosen_profile) =
match answer {
@@ -63,11 +113,14 @@ pub async fn perform_handshake(
signature,
chosen_profile,
} => (identity_pub, ephemeral_pub, signature, chosen_profile),
other => {
return Err(anyhow::anyhow!(
"expected CallAnswer, got {:?}",
std::mem::discriminant(&other)
));
SignalMessage::Hangup {
reason: HangupReason::ProtocolVersionMismatch { server_supported },
..
} => {
return Err(HandshakeError::ProtocolVersionMismatch { server_supported });
}
_ => {
return Err(HandshakeError::UnexpectedSignal("CallAnswer"));
}
};
@@ -76,11 +129,13 @@ pub async fn perform_handshake(
verify_data.extend_from_slice(&callee_ephemeral_pub);
verify_data.extend_from_slice(b"call-answer");
if !WarzoneKeyExchange::verify(&callee_identity_pub, &verify_data, &callee_signature) {
return Err(anyhow::anyhow!("callee signature verification failed"));
return Err(HandshakeError::SignatureVerificationFailed);
}
// 7. Derive session
let session = kx.derive_session(&callee_ephemeral_pub)?;
let session = kx
.derive_session(&callee_ephemeral_pub)
.map_err(|e| HandshakeError::KeyDerivation(e.to_string()))?;
Ok(session)
}

View File

@@ -156,6 +156,8 @@ async fn handshake_rejects_tampered_signature() {
signature: bad_signature,
supported_profiles: vec![wzp_proto::QualityProfile::GOOD],
alias: None,
protocol_version: 2,
supported_versions: vec![2],
};
client_transport_clone
.send_signal(&offer)
@@ -179,3 +181,41 @@ async fn handshake_rejects_tampered_signature() {
Ok(_) => panic!("relay should reject tampered signature"),
}
}
#[tokio::test]
async fn client_receives_protocol_version_mismatch() {
let (client_transport, relay_transport) = MockTransport::pair();
let client_seed = [0xAA_u8; 32];
// Spawn a fake relay that sends ProtocolVersionMismatch.
let relay_clone = Arc::clone(&relay_transport);
tokio::spawn(async move {
// Wait for the client's CallOffer.
let offer = relay_clone.recv_signal().await.unwrap().unwrap();
assert!(matches!(offer, SignalMessage::CallOffer { .. }));
// Respond with ProtocolVersionMismatch.
let mismatch = SignalMessage::Hangup {
reason: wzp_proto::HangupReason::ProtocolVersionMismatch {
server_supported: vec![3],
},
call_id: None,
};
relay_clone.send_signal(&mismatch).await.unwrap();
});
let result =
wzp_client::handshake::perform_handshake(client_transport.as_ref(), &client_seed, None)
.await;
match result {
Err(wzp_client::handshake::HandshakeError::ProtocolVersionMismatch {
server_supported,
}) => {
assert_eq!(server_supported, vec![3]);
}
Err(other) => panic!("expected ProtocolVersionMismatch, got: {other:?}"),
Ok(_) => panic!("expected handshake to fail with ProtocolVersionMismatch"),
}
}

View File

@@ -119,6 +119,8 @@ fn wzp_signal_serializes_into_fc_callsignal_payload() {
signature: vec![3u8; 64],
supported_profiles: vec![wzp_proto::QualityProfile::GOOD],
alias: None,
protocol_version: 2,
supported_versions: vec![2],
};
// Encode as featherChat CallSignal payload
@@ -301,6 +303,8 @@ fn all_signal_types_map_correctly() {
signature: vec![],
supported_profiles: vec![],
alias: None,
protocol_version: 2,
supported_versions: vec![2],
},
"Offer",
),

View File

@@ -554,6 +554,12 @@ pub enum SignalMessage {
/// Optional display name set by the caller.
#[serde(default)]
alias: Option<String>,
/// Protocol version requested by the caller (default 2 = v2 wire format).
#[serde(default = "default_proto_version")]
protocol_version: u8,
/// Protocol versions this client supports (default [2]).
#[serde(default = "default_supported_versions")]
supported_versions: Vec<u8>,
},
/// Call acceptance (analogous to Warzone's WireMessage::CallAnswer).
@@ -1097,14 +1103,29 @@ pub struct RoomParticipant {
pub relay_label: Option<String>,
}
/// Default protocol version for `CallOffer` (v2 wire format).
pub fn default_proto_version() -> u8 {
2
}
/// Default supported versions for `CallOffer` (only v2).
pub fn default_supported_versions() -> Vec<u8> {
vec![2]
}
/// Reasons for ending a call.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum HangupReason {
Normal,
Busy,
Declined,
Timeout,
Error,
/// Server does not support any of the client's requested protocol versions.
ProtocolVersionMismatch {
/// Versions the server is willing to speak.
server_supported: Vec<u8>,
},
}
#[cfg(test)]
@@ -2024,7 +2045,10 @@ mod tests {
let pkt = make_media_packet(0, 0, b"audio");
let wire = pkt.encode_compact(&mut ctx, &mut frames_since_full);
assert_eq!(wire[0], FRAME_TYPE_FULL, "must fall back to FULL when no baseline");
assert_eq!(
wire[0], FRAME_TYPE_FULL,
"must fall back to FULL when no baseline"
);
// After the fallback the baseline is established.
assert!(ctx.last_header().is_some());
}

View File

@@ -41,6 +41,7 @@ pub async fn accept_handshake(
caller_signature,
supported_profiles,
caller_alias,
protocol_version,
) = match offer {
SignalMessage::CallOffer {
identity_pub,
@@ -48,12 +49,15 @@ pub async fn accept_handshake(
signature,
supported_profiles,
alias,
protocol_version,
supported_versions: _,
} => (
identity_pub,
ephemeral_pub,
signature,
supported_profiles,
alias,
protocol_version,
),
other => {
return Err(anyhow::anyhow!(
@@ -63,6 +67,20 @@ pub async fn accept_handshake(
}
};
// 1a. Protocol version check — we only speak v2.
if protocol_version != 2 {
let mismatch = SignalMessage::Hangup {
reason: wzp_proto::HangupReason::ProtocolVersionMismatch {
server_supported: vec![2],
},
call_id: None,
};
let _ = transport.send_signal(&mismatch).await;
return Err(anyhow::anyhow!(
"protocol version mismatch: client requested {protocol_version}, server supports [2]"
));
}
// 2. Verify caller's signature over (ephemeral_pub || "call-offer")
let mut verify_data = Vec::with_capacity(32 + 10);
verify_data.extend_from_slice(&caller_ephemeral_pub);

View File

@@ -103,6 +103,79 @@ async fn handshake_succeeds() {
drop(client_transport);
}
// -----------------------------------------------------------------------
// Test 5: handshake_rejects_v1_protocol_version
// -----------------------------------------------------------------------
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn handshake_rejects_v1_protocol_version() {
let (client_transport, server_transport, _endpoints) = connected_pair().await;
let caller_seed: [u8; 32] = [0xCC; 32];
let callee_seed: [u8; 32] = [0xDD; 32];
let server_t = Arc::clone(&server_transport);
let callee_handle =
tokio::spawn(async move { accept_handshake(server_t.as_ref(), &callee_seed).await });
// Build a v1 CallOffer (protocol_version = 1).
let mut kx = WarzoneKeyExchange::from_identity_seed(&caller_seed);
let identity_pub = kx.identity_public_key();
let ephemeral_pub = kx.generate_ephemeral();
let mut sign_data = Vec::with_capacity(32 + 10);
sign_data.extend_from_slice(&ephemeral_pub);
sign_data.extend_from_slice(b"call-offer");
let signature = kx.sign(&sign_data);
let v1_offer = SignalMessage::CallOffer {
identity_pub,
ephemeral_pub,
signature,
supported_profiles: vec![wzp_proto::QualityProfile::GOOD],
alias: None,
protocol_version: 1,
supported_versions: vec![1, 2],
};
client_transport
.send_signal(&v1_offer)
.await
.expect("send v1 CallOffer");
// The callee should return an error about protocol version mismatch.
let result = callee_handle.await.expect("join callee task");
match result {
Ok(_) => panic!("accept_handshake must reject a v1 offer"),
Err(e) => {
let err_msg = e.to_string();
assert!(
err_msg.contains("protocol version mismatch"),
"error should mention protocol version mismatch, got: {err_msg}"
);
}
}
// Verify the client received a Hangup with ProtocolVersionMismatch.
let response = client_transport
.recv_signal()
.await
.expect("recv response")
.expect("response should exist");
match response {
SignalMessage::Hangup {
reason: wzp_proto::HangupReason::ProtocolVersionMismatch { server_supported },
..
} => {
assert_eq!(server_supported, vec![2]);
}
other => panic!("expected ProtocolVersionMismatch hangup, got: {other:?}"),
}
drop(server_transport);
drop(client_transport);
}
// -----------------------------------------------------------------------
// Test 2: handshake_verifies_identity
// -----------------------------------------------------------------------
@@ -276,6 +349,8 @@ async fn handshake_rejects_bad_signature() {
signature,
supported_profiles: vec![wzp_proto::QualityProfile::GOOD],
alias: None,
protocol_version: 2,
supported_versions: vec![2],
};
client_transport