use chacha20poly1305::{
    aead::{Aead, KeyInit},
    ChaCha20Poly1305, Nonce,
};
use hkdf::Hkdf;
use sha2::Sha256;

use crate::errors::ProtocolError;

/// Length in bytes of the ChaCha20-Poly1305 nonce that [`aead_encrypt`] prepends
/// to every ciphertext.
const NONCE_LEN: usize = 12;

/// HKDF-SHA256 key derivation (RFC 5869): extract from `ikm` using `salt`, then
/// expand with `info` into `len` bytes of output keying material.
///
/// An empty `salt` is passed to HKDF as "no salt" (`None`).
///
/// # Panics
///
/// Panics if `len` exceeds the HKDF-SHA256 expand limit of 255 * 32 = 8160 bytes.
pub fn hkdf_derive(ikm: &[u8], salt: &[u8], info: &[u8], len: usize) -> Vec<u8> {
    let salt = if salt.is_empty() { None } else { Some(salt) };
    let hk = Hkdf::<Sha256>::new(salt, ikm);

    let mut output = vec![0u8; len];
    // The only failure mode of `expand` is an over-long output request.
    hk.expand(info, &mut output)
        .expect("HKDF output length should be valid");
    output
}

/// Encrypt `plaintext` with ChaCha20-Poly1305 under `key`, authenticating `aad`.
///
/// Returns `nonce (12 bytes) || ciphertext`, where the ciphertext carries the
/// Poly1305 tag appended by the cipher. This is the layout [`aead_decrypt`] expects.
///
/// A fresh random nonce is drawn from the OS RNG on every call. Random 96-bit
/// nonces make a collision (and thus nonce reuse) increasingly likely as the
/// number of messages under one key grows, so keys should be rotated well before
/// about 2^32 messages.
///
/// # Panics
///
/// Panics if the cipher rejects the input. This is expected only for plaintexts
/// beyond the cipher's maximum message length.
pub fn aead_encrypt(key: &[u8; 32], plaintext: &[u8], aad: &[u8]) -> Vec<u8> {
    let cipher = ChaCha20Poly1305::new(key.into());

    let mut nonce_bytes = [0u8; NONCE_LEN];
    rand::RngCore::fill_bytes(&mut rand::rngs::OsRng, &mut nonce_bytes);
    let nonce = Nonce::from_slice(&nonce_bytes);

    let ciphertext = cipher
        .encrypt(nonce, chacha20poly1305::aead::Payload { msg: plaintext, aad })
        .expect("encryption should not fail");

    let mut result = Vec::with_capacity(NONCE_LEN + ciphertext.len());
    result.extend_from_slice(&nonce_bytes);
    result.extend_from_slice(&ciphertext);
    result
}

/// Decrypt ChaCha20-Poly1305. Input: nonce (12 bytes) || ciphertext.
pub fn aead_decrypt(key: &[u8; 32], data: &[u8], aad: &[u8]) -> Result, ProtocolError> { if data.len() < 12 { return Err(ProtocolError::DecryptionFailed); } let (nonce_bytes, ciphertext) = data.split_at(12); let cipher = ChaCha20Poly1305::new(key.into()); let nonce = Nonce::from_slice(nonce_bytes); cipher .decrypt(nonce, chacha20poly1305::aead::Payload { msg: ciphertext, aad }) .map_err(|_| ProtocolError::DecryptionFailed) } #[cfg(test)] mod tests { use super::*; #[test] fn aead_roundtrip() { let key = [42u8; 32]; let plaintext = b"hello warzone"; let aad = b"associated data"; let encrypted = aead_encrypt(&key, plaintext, aad); let decrypted = aead_decrypt(&key, &encrypted, aad).unwrap(); assert_eq!(decrypted, plaintext); } #[test] fn aead_wrong_key_fails() { let key = [42u8; 32]; let wrong_key = [99u8; 32]; let encrypted = aead_encrypt(&key, b"secret", b""); assert!(aead_decrypt(&wrong_key, &encrypted, b"").is_err()); } #[test] fn aead_wrong_aad_fails() { let key = [42u8; 32]; let encrypted = aead_encrypt(&key, b"secret", b"aad1"); assert!(aead_decrypt(&key, &encrypted, b"aad2").is_err()); } #[test] fn hkdf_deterministic() { let a = hkdf_derive(b"input", b"salt", b"info", 32); let b = hkdf_derive(b"input", b"salt", b"info", 32); assert_eq!(a, b); } }