//! Double ratchet session: encrypt, decrypt, serialize, deserialize. use pyo3::prelude::*; use pyo3::types::PyBytes; use zeroize::Zeroize; use crate::call::CallKeys; use crate::errors::to_py_err; /// Double ratchet session state. /// /// Manages the symmetric ratchet for ongoing message encryption/decryption. /// Initialized after KEX completes (from root_key + chain_key + peer_ek). /// /// Use as a context manager for automatic zeroization:: /// /// with Ratchet.init_alice(root_key, chain_key, local_fp, remote_fp, peer_ek, ek_sk) as r: /// header, ciphertext = r.encrypt(b"hello") #[pyclass] pub struct Ratchet { inner: Option, } #[pymethods] impl Ratchet { /// Initialize Alice's side (initiator). #[staticmethod] fn init_alice( root_key: &[u8], chain_key: &[u8], local_fp: &[u8], remote_fp: &[u8], peer_ek: &[u8], ek_sk: &[u8], ) -> PyResult { let mut rk = to_32("root_key", root_key)?; let mut ck = to_32("chain_key", chain_key)?; let lfp = to_32("local_fp", local_fp)?; let rfp = to_32("remote_fp", remote_fp)?; let ek_pub = soliton::primitives::xwing::PublicKey::from_bytes(peer_ek.to_vec()) .map_err(to_py_err)?; let ek_secret = soliton::primitives::xwing::SecretKey::from_bytes(ek_sk.to_vec()).map_err(to_py_err)?; let state = soliton::ratchet::RatchetState::init_alice(rk, ck, lfp, rfp, ek_pub, ek_secret) .map_err(to_py_err)?; rk.zeroize(); ck.zeroize(); Ok(Self { inner: Some(state) }) } /// Initialize Bob's side (responder). #[staticmethod] fn init_bob( root_key: &[u8], chain_key: &[u8], local_fp: &[u8], remote_fp: &[u8], peer_ek: &[u8], ) -> PyResult { let mut rk = to_32("root_key", root_key)?; let mut ck = to_32("chain_key", chain_key)?; let lfp = to_32("local_fp", local_fp)?; let rfp = to_32("remote_fp", remote_fp)?; let ek_pub = soliton::primitives::xwing::PublicKey::from_bytes(peer_ek.to_vec()) .map_err(to_py_err)?; let state = soliton::ratchet::RatchetState::init_bob(rk, ck, lfp, rfp, ek_pub) .map_err(to_py_err)?; rk.zeroize(); ck.zeroize(); Ok(Self { inner: Some(state) }) } /// Encrypt a plaintext message. /// /// Returns: /// Tuple of (header_bytes, ciphertext). header_bytes is the serialized /// RatchetHeader (ratchet_pk + optional kem_ct + n + pn); ciphertext is /// the AEAD output. Both must be sent to the recipient. fn encrypt<'py>( &mut self, py: Python<'py>, plaintext: &[u8], ) -> PyResult<(Py, Py)> { let state = self.inner.as_mut().ok_or_else(|| { crate::errors::InvalidDataError::new_err("ratchet consumed or closed") })?; let msg = state.encrypt(plaintext).map_err(to_py_err)?; // Serialize the header to bytes for transport. let header_bytes = encode_header(&msg.header); Ok(( PyBytes::new(py, &header_bytes).into(), PyBytes::new(py, &msg.ciphertext).into(), )) } /// Decrypt a received message. /// /// Args: /// header: Serialized RatchetHeader bytes (from sender's encrypt()). /// ciphertext: AEAD ciphertext bytes. /// /// Returns: /// Decrypted plaintext bytes. fn decrypt<'py>( &mut self, py: Python<'py>, header: &[u8], ciphertext: &[u8], ) -> PyResult> { let state = self.inner.as_mut().ok_or_else(|| { crate::errors::InvalidDataError::new_err("ratchet consumed or closed") })?; let rh = decode_header(header)?; let pt = state.decrypt(&rh, ciphertext).map_err(to_py_err)?; Ok(PyBytes::new(py, &pt).into()) } /// Encrypt the first message (pre-ratchet, uses initial chain key). /// /// This is a static method — called before the ratchet is initialized. /// /// Args: /// chain_key: 32-byte initial chain key from KEX. /// plaintext: First application message. /// aad: Additional authenticated data (from build_first_message_aad). /// /// Returns: /// Tuple of (encrypted_payload, ratchet_init_key). Pass ratchet_init_key /// as chain_key to init_alice/init_bob. #[staticmethod] fn encrypt_first_message<'py>( py: Python<'py>, chain_key: &[u8], plaintext: &[u8], aad: &[u8], ) -> PyResult<(Py, Py)> { let ck = zeroizing_32("chain_key", chain_key)?; let (ct, rik) = soliton::ratchet::RatchetState::encrypt_first_message(ck, plaintext, aad) .map_err(to_py_err)?; Ok((PyBytes::new(py, &ct).into(), PyBytes::new(py, &*rik).into())) } /// Decrypt the first message (pre-ratchet). /// /// Returns: /// Tuple of (plaintext, ratchet_init_key). #[staticmethod] fn decrypt_first_message<'py>( py: Python<'py>, chain_key: &[u8], encrypted_payload: &[u8], aad: &[u8], ) -> PyResult<(Py, Py)> { let ck = zeroizing_32("chain_key", chain_key)?; let (pt, rik) = soliton::ratchet::RatchetState::decrypt_first_message(ck, encrypted_payload, aad) .map_err(to_py_err)?; Ok((PyBytes::new(py, &pt).into(), PyBytes::new(py, &*rik).into())) } /// Serialize the ratchet state. Consumes the ratchet. /// /// Returns: /// Tuple of (blob, epoch). Persist the blob encrypted (e.g., with /// StorageKeyRing). Store epoch separately for anti-rollback. #[allow(clippy::wrong_self_convention)] fn to_bytes<'py>(&mut self, py: Python<'py>) -> PyResult<(Py, u64)> { let state = self .inner .take() .ok_or_else(|| crate::errors::InvalidDataError::new_err("ratchet already consumed"))?; let (blob, epoch) = state.to_bytes().map_err(to_py_err)?; Ok((PyBytes::new(py, &blob).into(), epoch)) } /// Deserialize ratchet state with anti-rollback protection. /// /// Args: /// data: Serialized ratchet blob. /// min_epoch: Minimum acceptable epoch. Use saved_epoch - 1. #[staticmethod] fn from_bytes(data: &[u8], min_epoch: u64) -> PyResult { let state = soliton::ratchet::RatchetState::from_bytes_with_min_epoch(data, min_epoch) .map_err(to_py_err)?; Ok(Self { inner: Some(state) }) } /// Whether the ratchet can be serialized (counters not exhausted). fn can_serialize(&self) -> PyResult { let state = self.inner.as_ref().ok_or_else(|| { crate::errors::InvalidDataError::new_err("ratchet consumed or closed") })?; Ok(state.can_serialize()) } /// Current epoch number. fn epoch(&self) -> PyResult { let state = self.inner.as_ref().ok_or_else(|| { crate::errors::InvalidDataError::new_err("ratchet consumed or closed") })?; Ok(state.epoch()) } /// Reset the ratchet (zeroize all keys). The session is dead after this. fn reset(&mut self) -> PyResult<()> { let state = self.inner.as_mut().ok_or_else(|| { crate::errors::InvalidDataError::new_err("ratchet consumed or closed") })?; state.reset(); Ok(()) } /// Derive call keys for encrypted voice/video. fn derive_call_keys(&self, kem_ss: &[u8], call_id: &[u8]) -> PyResult { let state = self.inner.as_ref().ok_or_else(|| { crate::errors::InvalidDataError::new_err("ratchet consumed or closed") })?; if kem_ss.len() != 32 { return Err(crate::errors::InvalidLengthError::new_err( "kem_ss must be 32 bytes", )); } if call_id.len() != 16 { return Err(crate::errors::InvalidLengthError::new_err( "call_id must be 16 bytes", )); } let ss: &[u8; 32] = kem_ss.try_into().unwrap(); let cid: &[u8; 16] = call_id.try_into().unwrap(); let keys = state.derive_call_keys(ss, cid).map_err(to_py_err)?; Ok(CallKeys::from_inner(keys)) } fn close(&mut self) { if let Some(mut state) = self.inner.take() { state.reset(); } } fn __enter__(slf: Py) -> Py { slf } fn __exit__( &mut self, _exc_type: Option<&Bound<'_, PyAny>>, _exc_val: Option<&Bound<'_, PyAny>>, _exc_tb: Option<&Bound<'_, PyAny>>, ) { self.close(); } } // ── Header serialization ──────────────────────────────────────────────── // // Simple wire format for Python: ratchet_pk (1216) + has_kem_ct (1) + // [kem_ct (1120) if present] + n (4 BE) + pn (4 BE). // This matches the CAPI's encode_ratchet_header layout. fn encode_header(h: &soliton::ratchet::RatchetHeader) -> Vec { let pk_bytes = h.ratchet_pk.as_bytes(); let has_ct = h.kem_ct.is_some(); let size = 1216 + 1 + if has_ct { 2 + 1120 } else { 0 } + 4 + 4; let mut buf = Vec::with_capacity(size); buf.extend_from_slice(pk_bytes); if let Some(ref ct) = h.kem_ct { buf.push(0x01); let ct_bytes = ct.as_bytes(); buf.extend_from_slice(&(ct_bytes.len() as u16).to_be_bytes()); buf.extend_from_slice(ct_bytes); } else { buf.push(0x00); } buf.extend_from_slice(&h.n.to_be_bytes()); buf.extend_from_slice(&h.pn.to_be_bytes()); buf } fn decode_header(data: &[u8]) -> PyResult { if data.len() < 1216 + 1 + 4 + 4 { return Err(crate::errors::InvalidDataError::new_err("header too short")); } let ratchet_pk = soliton::primitives::xwing::PublicKey::from_bytes(data[..1216].to_vec()) .map_err(to_py_err)?; let has_ct = data[1216]; if has_ct != 0x00 && has_ct != 0x01 { return Err(crate::errors::InvalidDataError::new_err( "invalid has_kem_ct flag (expected 0x00 or 0x01)", )); } let rest = if has_ct == 0x01 { if data.len() < 1216 + 1 + 2 + 1120 + 4 + 4 { return Err(crate::errors::InvalidDataError::new_err( "header too short for kem_ct", )); } &data[1216 + 1 + 2 + 1120..] } else { &data[1216 + 1..] }; let kem_ct = if has_ct == 0x01 { Some( soliton::primitives::xwing::Ciphertext::from_bytes( data[1216 + 1 + 2..1216 + 1 + 2 + 1120].to_vec(), ) .map_err(to_py_err)?, ) } else { None }; if rest.len() < 8 { return Err(crate::errors::InvalidDataError::new_err( "header missing counters", )); } let n = u32::from_be_bytes(rest[..4].try_into().unwrap()); let pn = u32::from_be_bytes(rest[4..8].try_into().unwrap()); Ok(soliton::ratchet::RatchetHeader { ratchet_pk, kem_ct, n, pn, }) } // ── Helpers ───────────────────────────────────────────────────────────── fn to_32(name: &str, data: &[u8]) -> PyResult<[u8; 32]> { data.try_into() .map_err(|_| crate::errors::InvalidLengthError::new_err(format!("{name} must be 32 bytes"))) } fn zeroizing_32(name: &str, data: &[u8]) -> PyResult> { let arr = to_32(name, data)?; Ok(zeroize::Zeroizing::new(arr)) } pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; Ok(()) }