//! Double Ratchet algorithm implementation. //! Follows Signal's Double Ratchet specification. use std::collections::BTreeMap; use serde::{Deserialize, Serialize}; use x25519_dalek::{PublicKey, StaticSecret}; use crate::crypto::{aead_decrypt, aead_encrypt, hkdf_derive}; use crate::errors::ProtocolError; const MAX_SKIP: u32 = 1000; /// Current serialization version for [`RatchetState`]. const RATCHET_VERSION: u8 = 1; /// Magic byte to distinguish versioned from unversioned (legacy) data. const RATCHET_MAGIC: u8 = 0xFC; /// A message produced by the ratchet. #[derive(Clone, Debug, Serialize, Deserialize)] pub struct RatchetMessage { pub header: RatchetHeader, pub ciphertext: Vec, } /// Header included with each ratchet message. #[derive(Clone, Debug, Serialize, Deserialize)] pub struct RatchetHeader { /// Current DH ratchet public key. pub dh_public: [u8; 32], /// Number of messages in the previous sending chain. pub prev_chain_length: u32, /// Message number in the current sending chain. pub message_number: u32, } /// The Double Ratchet state machine. #[derive(Serialize, Deserialize)] pub struct RatchetState { dh_self: Vec, // StaticSecret bytes (32) dh_remote: Option<[u8; 32]>, root_key: [u8; 32], chain_key_send: Option<[u8; 32]>, chain_key_recv: Option<[u8; 32]>, send_count: u32, recv_count: u32, prev_send_count: u32, skipped: BTreeMap<([u8; 32], u32), [u8; 32]>, // (dh_pub, n) -> message_key } impl RatchetState { /// Initialize as Alice (initiator). Alice knows Bob's ratchet public key. pub fn init_alice(shared_secret: [u8; 32], bob_ratchet_pub: PublicKey) -> Self { let dh_self = StaticSecret::random_from_rng(rand::rngs::OsRng); let dh_out = dh_self.diffie_hellman(&bob_ratchet_pub); let (root_key, chain_key_send) = kdf_rk(&shared_secret, dh_out.as_bytes()); RatchetState { dh_self: dh_self.to_bytes().to_vec(), dh_remote: Some(*bob_ratchet_pub.as_bytes()), root_key, chain_key_send: Some(chain_key_send), chain_key_recv: None, send_count: 0, recv_count: 0, prev_send_count: 0, skipped: BTreeMap::new(), } } /// Initialize as Bob (responder). Bob uses his signed pre-key as initial ratchet key. pub fn init_bob(shared_secret: [u8; 32], our_ratchet_secret: StaticSecret) -> Self { RatchetState { dh_self: our_ratchet_secret.to_bytes().to_vec(), dh_remote: None, root_key: shared_secret, chain_key_send: None, chain_key_recv: None, send_count: 0, recv_count: 0, prev_send_count: 0, skipped: BTreeMap::new(), } } /// Get our current DH ratchet public key. fn dh_public(&self) -> PublicKey { let mut bytes = [0u8; 32]; bytes.copy_from_slice(&self.dh_self); let secret = StaticSecret::from(bytes); PublicKey::from(&secret) } fn dh_secret(&self) -> StaticSecret { let mut bytes = [0u8; 32]; bytes.copy_from_slice(&self.dh_self); StaticSecret::from(bytes) } /// Encrypt a plaintext message. pub fn encrypt(&mut self, plaintext: &[u8]) -> Result { // If we don't have a sending chain yet (Bob's first message), do a DH ratchet step if self.chain_key_send.is_none() { if self.dh_remote.is_none() { return Err(ProtocolError::RatchetError( "no remote DH key and no sending chain".into(), )); } self.dh_ratchet_step()?; } let ck = self .chain_key_send .as_ref() .ok_or_else(|| ProtocolError::RatchetError("no sending chain".into()))?; let (new_ck, message_key) = kdf_ck(ck); self.chain_key_send = Some(new_ck); let header = RatchetHeader { dh_public: *self.dh_public().as_bytes(), prev_chain_length: self.prev_send_count, message_number: self.send_count, }; // AAD: serialized header let aad = bincode::serialize(&header) .map_err(|e| ProtocolError::SerializationError(e.to_string()))?; let ciphertext = aead_encrypt(&message_key, plaintext, &aad); self.send_count += 1; Ok(RatchetMessage { header, ciphertext }) } /// Decrypt a received ratchet message. pub fn decrypt(&mut self, message: &RatchetMessage) -> Result, ProtocolError> { // Check skipped messages first let key = (message.header.dh_public, message.header.message_number); if let Some(mk) = self.skipped.remove(&key) { let aad = bincode::serialize(&message.header) .map_err(|e| ProtocolError::SerializationError(e.to_string()))?; return aead_decrypt(&mk, &message.ciphertext, &aad); } // If the message's DH key differs from what we have, perform DH ratchet let need_ratchet = match self.dh_remote { Some(ref remote) => *remote != message.header.dh_public, None => true, }; if need_ratchet { // Skip any missed messages in the current receiving chain if self.chain_key_recv.is_some() { self.skip_messages(message.header.prev_chain_length)?; } // DH ratchet step let their_pub = PublicKey::from(message.header.dh_public); // New receiving chain let dh_recv = self.dh_secret().diffie_hellman(&their_pub); let (rk, ck_recv) = kdf_rk(&self.root_key, dh_recv.as_bytes()); self.root_key = rk; self.chain_key_recv = Some(ck_recv); self.recv_count = 0; // New sending chain self.prev_send_count = self.send_count; self.send_count = 0; let new_dh = StaticSecret::random_from_rng(rand::rngs::OsRng); let dh_send = new_dh.diffie_hellman(&their_pub); let (rk2, ck_send) = kdf_rk(&self.root_key, dh_send.as_bytes()); self.root_key = rk2; self.chain_key_send = Some(ck_send); self.dh_self = new_dh.to_bytes().to_vec(); self.dh_remote = Some(message.header.dh_public); } // Skip to the message number self.skip_messages(message.header.message_number)?; // Derive message key let ck = self .chain_key_recv .as_ref() .ok_or_else(|| ProtocolError::RatchetError("no receiving chain".into()))?; let (new_ck, message_key) = kdf_ck(ck); self.chain_key_recv = Some(new_ck); self.recv_count += 1; let aad = bincode::serialize(&message.header) .map_err(|e| ProtocolError::SerializationError(e.to_string()))?; aead_decrypt(&message_key, &message.ciphertext, &aad) } fn skip_messages(&mut self, until: u32) -> Result<(), ProtocolError> { if self.recv_count + MAX_SKIP < until { return Err(ProtocolError::MaxSkipExceeded); } if let Some(ref ck) = self.chain_key_recv.clone() { let dh_pub = self.dh_remote.unwrap_or([0u8; 32]); let mut current_ck = *ck; while self.recv_count < until { let (new_ck, mk) = kdf_ck(¤t_ck); self.skipped.insert((dh_pub, self.recv_count), mk); current_ck = new_ck; self.recv_count += 1; } self.chain_key_recv = Some(current_ck); } Ok(()) } /// Serialize with version prefix: `[MAGIC][VERSION][bincode data]`. /// /// Use [`deserialize_versioned`](Self::deserialize_versioned) to restore. pub fn serialize_versioned(&self) -> Result, String> { let data = bincode::serialize(self) .map_err(|e| format!("serialize: {}", e))?; let mut out = Vec::with_capacity(2 + data.len()); out.push(RATCHET_MAGIC); out.push(RATCHET_VERSION); out.extend_from_slice(&data); Ok(out) } /// Deserialize with version awareness. Handles: /// - Versioned format: `[0xFC][version][bincode]` /// - Legacy format: raw bincode (no prefix) pub fn deserialize_versioned(data: &[u8]) -> Result { if data.len() >= 2 && data[0] == RATCHET_MAGIC { let version = data[1]; match version { 1 => bincode::deserialize(&data[2..]) .map_err(|e| format!("v1 deserialize: {}", e)), _ => Err(format!("unknown ratchet version: {}", version)), } } else { // Legacy: try raw bincode (pre-versioning data) bincode::deserialize(data) .map_err(|e| format!("legacy deserialize: {}", e)) } } fn dh_ratchet_step(&mut self) -> Result<(), ProtocolError> { let their_pub = self .dh_remote .map(PublicKey::from) .ok_or_else(|| ProtocolError::RatchetError("no remote key for ratchet".into()))?; self.prev_send_count = self.send_count; self.send_count = 0; let new_dh = StaticSecret::random_from_rng(rand::rngs::OsRng); let dh_out = new_dh.diffie_hellman(&their_pub); let (rk, ck_send) = kdf_rk(&self.root_key, dh_out.as_bytes()); self.root_key = rk; self.chain_key_send = Some(ck_send); self.dh_self = new_dh.to_bytes().to_vec(); Ok(()) } } /// Root key KDF: derive new root key + chain key from DH output. fn kdf_rk(root_key: &[u8; 32], dh_output: &[u8]) -> ([u8; 32], [u8; 32]) { let derived = hkdf_derive(dh_output, root_key, b"warzone-ratchet-rk", 64); let mut new_rk = [0u8; 32]; let mut chain_key = [0u8; 32]; new_rk.copy_from_slice(&derived[..32]); chain_key.copy_from_slice(&derived[32..]); (new_rk, chain_key) } /// Chain key KDF: derive new chain key + message key. fn kdf_ck(chain_key: &[u8; 32]) -> ([u8; 32], [u8; 32]) { let mk_bytes = hkdf_derive(chain_key, b"", b"warzone-ratchet-mk", 32); let ck_bytes = hkdf_derive(chain_key, b"", b"warzone-ratchet-ck", 32); let mut new_ck = [0u8; 32]; let mut mk = [0u8; 32]; new_ck.copy_from_slice(&ck_bytes); mk.copy_from_slice(&mk_bytes); (new_ck, mk) } #[cfg(test)] mod tests { use super::*; fn make_pair() -> (RatchetState, RatchetState) { let shared_secret = [42u8; 32]; let bob_ratchet = StaticSecret::random_from_rng(rand::rngs::OsRng); let bob_ratchet_pub = PublicKey::from(&bob_ratchet); let alice = RatchetState::init_alice(shared_secret, bob_ratchet_pub); let bob = RatchetState::init_bob(shared_secret, bob_ratchet); (alice, bob) } #[test] fn basic_exchange() { let (mut alice, mut bob) = make_pair(); let msg = alice.encrypt(b"hello bob").unwrap(); let plain = bob.decrypt(&msg).unwrap(); assert_eq!(plain, b"hello bob"); } #[test] fn bidirectional() { let (mut alice, mut bob) = make_pair(); let m1 = alice.encrypt(b"hello bob").unwrap(); assert_eq!(bob.decrypt(&m1).unwrap(), b"hello bob"); let m2 = bob.encrypt(b"hello alice").unwrap(); assert_eq!(alice.decrypt(&m2).unwrap(), b"hello alice"); let m3 = alice.encrypt(b"how are you?").unwrap(); assert_eq!(bob.decrypt(&m3).unwrap(), b"how are you?"); } #[test] fn multiple_messages_same_direction() { let (mut alice, mut bob) = make_pair(); let m1 = alice.encrypt(b"one").unwrap(); let m2 = alice.encrypt(b"two").unwrap(); let m3 = alice.encrypt(b"three").unwrap(); assert_eq!(bob.decrypt(&m1).unwrap(), b"one"); assert_eq!(bob.decrypt(&m2).unwrap(), b"two"); assert_eq!(bob.decrypt(&m3).unwrap(), b"three"); } #[test] fn out_of_order() { let (mut alice, mut bob) = make_pair(); let m1 = alice.encrypt(b"one").unwrap(); let m2 = alice.encrypt(b"two").unwrap(); let m3 = alice.encrypt(b"three").unwrap(); // Deliver out of order assert_eq!(bob.decrypt(&m3).unwrap(), b"three"); assert_eq!(bob.decrypt(&m1).unwrap(), b"one"); assert_eq!(bob.decrypt(&m2).unwrap(), b"two"); } #[test] fn versioned_serialize_roundtrip() { let (mut alice, mut bob) = make_pair(); let msg = alice.encrypt(b"test versioning").unwrap(); // Save alice with versioned format let serialized = alice.serialize_versioned().unwrap(); assert_eq!(serialized[0], 0xFC); // magic byte assert_eq!(serialized[1], 1); // version 1 // Restore and use let mut restored = RatchetState::deserialize_versioned(&serialized).unwrap(); let msg2 = restored.encrypt(b"after restore").unwrap(); let plain = bob.decrypt(&msg).unwrap(); assert_eq!(plain, b"test versioning"); let plain2 = bob.decrypt(&msg2).unwrap(); assert_eq!(plain2, b"after restore"); } #[test] fn legacy_deserialize_works() { let (alice, _) = make_pair(); // Serialize with raw bincode (legacy format) let legacy = bincode::serialize(&alice).unwrap(); // Should still deserialize with versioned reader let restored = RatchetState::deserialize_versioned(&legacy).unwrap(); assert_eq!(bincode::serialize(&restored).unwrap(), legacy); } #[test] fn many_messages() { let (mut alice, mut bob) = make_pair(); for i in 0..100 { let msg = format!("message {}", i); let encrypted = alice.encrypt(msg.as_bytes()).unwrap(); let decrypted = bob.decrypt(&encrypted).unwrap(); assert_eq!(decrypted, msg.as_bytes()); } } }