//! Sender Keys for efficient group encryption. //! //! Instead of encrypting per-member (O(N)), each member generates a //! symmetric "sender key" and distributes it to all group members via //! 1:1 encrypted channels. Group messages are encrypted ONCE with the //! sender's key, and the same ciphertext is delivered to all members. //! //! Key rotation: on member join/leave, all members rotate their sender keys. use serde::{Deserialize, Serialize}; use crate::crypto::{aead_decrypt, aead_encrypt, hkdf_derive}; use crate::errors::ProtocolError; /// A sender key: symmetric key + chain for forward ratcheting. #[derive(Clone, Serialize, Deserialize)] pub struct SenderKey { /// Who owns this key. pub owner_fingerprint: String, /// Group this key belongs to. pub group_name: String, /// Current chain key (ratchets forward on each message). pub chain_key: [u8; 32], /// Message counter. pub counter: u32, /// Generation (incremented on rotation). pub generation: u32, } impl SenderKey { /// Generate a new sender key for a group. pub fn generate(owner_fingerprint: &str, group_name: &str) -> Self { let mut chain_key = [0u8; 32]; rand::RngCore::fill_bytes(&mut rand::rngs::OsRng, &mut chain_key); SenderKey { owner_fingerprint: owner_fingerprint.to_string(), group_name: group_name.to_string(), chain_key, counter: 0, generation: 0, } } /// Rotate: new random chain key, increment generation. pub fn rotate(&mut self) { rand::RngCore::fill_bytes(&mut rand::rngs::OsRng, &mut self.chain_key); self.counter = 0; self.generation += 1; } /// Derive a message key from the current chain key, then ratchet forward. fn derive_message_key(&mut self) -> [u8; 32] { let info = format!("wz-sk-msg-{}-{}", self.generation, self.counter); let mk_bytes = hkdf_derive(&self.chain_key, b"", info.as_bytes(), 32); let mut message_key = [0u8; 32]; message_key.copy_from_slice(&mk_bytes); // Ratchet chain key forward let ck_bytes = hkdf_derive(&self.chain_key, b"", b"wz-sk-chain", 32); self.chain_key.copy_from_slice(&ck_bytes); self.counter += 1; message_key } /// Encrypt a message with this sender key. pub fn encrypt(&mut self, plaintext: &[u8]) -> SenderKeyMessage { let message_key = self.derive_message_key(); let aad = format!("{}:{}:{}", self.group_name, self.generation, self.counter - 1); let ciphertext = aead_encrypt(&message_key, plaintext, aad.as_bytes()); SenderKeyMessage { sender_fingerprint: self.owner_fingerprint.clone(), group_name: self.group_name.clone(), generation: self.generation, counter: self.counter - 1, ciphertext, } } /// Decrypt a message from another member using their sender key. /// `self` is the RECEIVER's copy of the SENDER's key. pub fn decrypt(&mut self, msg: &SenderKeyMessage) -> Result, ProtocolError> { // Fast-forward chain if needed (handle skipped messages) if msg.generation != self.generation { return Err(ProtocolError::RatchetError(format!( "generation mismatch: expected {}, got {}", self.generation, msg.generation ))); } // We need to advance to the right counter while self.counter < msg.counter { // Skip this message key (lost message) let _ = self.derive_message_key(); } if self.counter != msg.counter { return Err(ProtocolError::RatchetError("counter mismatch".into())); } let message_key = self.derive_message_key(); let aad = format!("{}:{}:{}", msg.group_name, msg.generation, msg.counter); aead_decrypt(&message_key, &msg.ciphertext, aad.as_bytes()) } } /// An encrypted group message using sender keys. #[derive(Clone, Serialize, Deserialize)] pub struct SenderKeyMessage { pub sender_fingerprint: String, pub group_name: String, pub generation: u32, pub counter: u32, pub ciphertext: Vec, } /// Distribution message: sent via 1:1 encrypted channel to share a sender key. #[derive(Clone, Serialize, Deserialize)] pub struct SenderKeyDistribution { pub sender_fingerprint: String, pub group_name: String, pub chain_key: [u8; 32], pub generation: u32, } impl From<&SenderKey> for SenderKeyDistribution { fn from(sk: &SenderKey) -> Self { SenderKeyDistribution { sender_fingerprint: sk.owner_fingerprint.clone(), group_name: sk.group_name.clone(), chain_key: sk.chain_key, generation: sk.generation, } } } impl SenderKeyDistribution { /// Convert distribution into a receiver's copy of the sender key. pub fn into_sender_key(self) -> SenderKey { SenderKey { owner_fingerprint: self.sender_fingerprint, group_name: self.group_name, chain_key: self.chain_key, counter: 0, generation: self.generation, } } } #[cfg(test)] mod tests { use super::*; #[test] fn basic_encrypt_decrypt() { let mut alice_key = SenderKey::generate("alice", "ops"); // Bob gets a copy of Alice's key (via distribution) let dist = SenderKeyDistribution::from(&alice_key); let mut bob_copy = dist.into_sender_key(); let msg = alice_key.encrypt(b"hello group"); let plain = bob_copy.decrypt(&msg).unwrap(); assert_eq!(plain, b"hello group"); } #[test] fn multiple_messages() { let mut alice_key = SenderKey::generate("alice", "ops"); let dist = SenderKeyDistribution::from(&alice_key); let mut bob_copy = dist.into_sender_key(); for i in 0..10 { let msg = alice_key.encrypt(format!("msg {}", i).as_bytes()); let plain = bob_copy.decrypt(&msg).unwrap(); assert_eq!(plain, format!("msg {}", i).as_bytes()); } } #[test] fn rotation() { let mut alice_key = SenderKey::generate("alice", "ops"); let dist1 = SenderKeyDistribution::from(&alice_key); let mut bob_copy = dist1.into_sender_key(); let msg1 = alice_key.encrypt(b"before rotation"); let _ = bob_copy.decrypt(&msg1).unwrap(); // Rotate alice_key.rotate(); let dist2 = SenderKeyDistribution::from(&alice_key); let mut bob_copy2 = dist2.into_sender_key(); let msg2 = alice_key.encrypt(b"after rotation"); let plain = bob_copy2.decrypt(&msg2).unwrap(); assert_eq!(plain, b"after rotation"); } #[test] fn old_key_cant_decrypt_new() { let mut alice_key = SenderKey::generate("alice", "ops"); let dist = SenderKeyDistribution::from(&alice_key); let mut bob_old = dist.into_sender_key(); alice_key.rotate(); let msg = alice_key.encrypt(b"new generation"); assert!(bob_old.decrypt(&msg).is_err()); } }