libsoliton/soliton_py/src/ratchet.rs
Kamal Tufekcic 1d99048c95
Some checks failed
CI / lint (push) Successful in 1m37s
CI / test-python (push) Successful in 1m49s
CI / test-zig (push) Successful in 1m39s
CI / test-wasm (push) Successful in 1m54s
CI / test (push) Successful in 14m44s
CI / miri (push) Successful in 14m18s
CI / build (push) Successful in 1m9s
CI / fuzz-regression (push) Successful in 9m9s
CI / publish (push) Failing after 1m10s
CI / publish-python (push) Failing after 1m46s
CI / publish-wasm (push) Has been cancelled
initial commit
Signed-off-by: Kamal Tufekcic <kamal@lo.sh>
2026-04-02 23:48:10 +03:00

349 lines
12 KiB
Rust

//! Double ratchet session: encrypt, decrypt, serialize, deserialize.
use pyo3::prelude::*;
use pyo3::types::PyBytes;
use zeroize::Zeroize;
use crate::call::CallKeys;
use crate::errors::to_py_err;
/// Double ratchet session state.
///
/// Drives the symmetric ratchet that encrypts and decrypts the ongoing
/// message stream. Create one after the KEX has finished, from its
/// root_key, chain_key and peer_ek outputs.
///
/// Closing the ratchet zeroizes its keys, so using it as a context manager
/// is the recommended pattern::
///
///     with Ratchet.init_alice(root_key, chain_key, local_fp, remote_fp, peer_ek, ek_sk) as r:
///         header, ciphertext = r.encrypt(b"hello")
#[pyclass]
pub struct Ratchet {
    /// Live session state; becomes `None` after `to_bytes()` or `close()`.
    inner: Option<soliton::ratchet::RatchetState>,
}
#[pymethods]
impl Ratchet {
    /// Initialize Alice's side (initiator).
    ///
    /// Args:
    ///     root_key: 32-byte root key.
    ///     chain_key: 32-byte chain key.
    ///     local_fp: 32-byte local fingerprint.
    ///     remote_fp: 32-byte remote fingerprint.
    ///     peer_ek: Peer's X-Wing public key bytes.
    ///     ek_sk: Local X-Wing secret key bytes.
    ///
    /// Raises:
    ///     InvalidLengthError: If a 32-byte argument has the wrong length.
    #[staticmethod]
    fn init_alice(
        root_key: &[u8],
        chain_key: &[u8],
        local_fp: &[u8],
        remote_fp: &[u8],
        peer_ek: &[u8],
        ek_sk: &[u8],
    ) -> PyResult<Self> {
        // `Zeroizing` wipes these local key copies on every exit path,
        // including the early error returns below.
        let rk = zeroizing_32("root_key", root_key)?;
        let ck = zeroizing_32("chain_key", chain_key)?;
        let lfp = to_32("local_fp", local_fp)?;
        let rfp = to_32("remote_fp", remote_fp)?;
        let ek_pub = soliton::primitives::xwing::PublicKey::from_bytes(peer_ek.to_vec())
            .map_err(to_py_err)?;
        let ek_secret =
            soliton::primitives::xwing::SecretKey::from_bytes(ek_sk.to_vec()).map_err(to_py_err)?;
        let state =
            soliton::ratchet::RatchetState::init_alice(*rk, *ck, lfp, rfp, ek_pub, ek_secret)
                .map_err(to_py_err)?;
        Ok(Self { inner: Some(state) })
    }

    /// Initialize Bob's side (responder).
    ///
    /// Args:
    ///     root_key: 32-byte root key.
    ///     chain_key: 32-byte chain key.
    ///     local_fp: 32-byte local fingerprint.
    ///     remote_fp: 32-byte remote fingerprint.
    ///     peer_ek: Peer's X-Wing public key bytes.
    ///
    /// Raises:
    ///     InvalidLengthError: If a 32-byte argument has the wrong length.
    #[staticmethod]
    fn init_bob(
        root_key: &[u8],
        chain_key: &[u8],
        local_fp: &[u8],
        remote_fp: &[u8],
        peer_ek: &[u8],
    ) -> PyResult<Self> {
        // `Zeroizing` wipes these local key copies on every exit path,
        // including the early error returns below.
        let rk = zeroizing_32("root_key", root_key)?;
        let ck = zeroizing_32("chain_key", chain_key)?;
        let lfp = to_32("local_fp", local_fp)?;
        let rfp = to_32("remote_fp", remote_fp)?;
        let ek_pub = soliton::primitives::xwing::PublicKey::from_bytes(peer_ek.to_vec())
            .map_err(to_py_err)?;
        let state = soliton::ratchet::RatchetState::init_bob(*rk, *ck, lfp, rfp, ek_pub)
            .map_err(to_py_err)?;
        Ok(Self { inner: Some(state) })
    }

    /// Encrypt a plaintext message.
    ///
    /// Returns:
    ///     Tuple of (header_bytes, ciphertext). header_bytes is the serialized
    ///     RatchetHeader (ratchet_pk + optional kem_ct + n + pn); ciphertext is
    ///     the AEAD output. Both must be sent to the recipient.
    ///
    /// Raises:
    ///     InvalidDataError: If the ratchet was consumed or closed.
    fn encrypt<'py>(
        &mut self,
        py: Python<'py>,
        plaintext: &[u8],
    ) -> PyResult<(Py<PyBytes>, Py<PyBytes>)> {
        let msg = self
            .live_state_mut()?
            .encrypt(plaintext)
            .map_err(to_py_err)?;
        // Serialize the header to bytes for transport.
        let header_bytes = encode_header(&msg.header);
        Ok((
            PyBytes::new(py, &header_bytes).into(),
            PyBytes::new(py, &msg.ciphertext).into(),
        ))
    }

    /// Decrypt a received message.
    ///
    /// Args:
    ///     header: Serialized RatchetHeader bytes (from sender's encrypt()).
    ///     ciphertext: AEAD ciphertext bytes.
    ///
    /// Returns:
    ///     Decrypted plaintext bytes.
    ///
    /// Raises:
    ///     InvalidDataError: If the ratchet was consumed or closed, or the
    ///         header bytes are malformed.
    fn decrypt<'py>(
        &mut self,
        py: Python<'py>,
        header: &[u8],
        ciphertext: &[u8],
    ) -> PyResult<Py<PyBytes>> {
        // The closed-state check runs before header parsing, so a closed
        // ratchet always reports "consumed or closed".
        let state = self.live_state_mut()?;
        let rh = decode_header(header)?;
        let pt = state.decrypt(&rh, ciphertext).map_err(to_py_err)?;
        Ok(PyBytes::new(py, &pt).into())
    }

    /// Encrypt the first message (pre-ratchet, uses initial chain key).
    ///
    /// This is a static method — called before the ratchet is initialized.
    ///
    /// Args:
    ///     chain_key: 32-byte initial chain key from KEX.
    ///     plaintext: First application message.
    ///     aad: Additional authenticated data (from build_first_message_aad).
    ///
    /// Returns:
    ///     Tuple of (encrypted_payload, ratchet_init_key). Pass ratchet_init_key
    ///     as chain_key to init_alice/init_bob.
    #[staticmethod]
    fn encrypt_first_message<'py>(
        py: Python<'py>,
        chain_key: &[u8],
        plaintext: &[u8],
        aad: &[u8],
    ) -> PyResult<(Py<PyBytes>, Py<PyBytes>)> {
        let ck = zeroizing_32("chain_key", chain_key)?;
        let (ct, rik) = soliton::ratchet::RatchetState::encrypt_first_message(ck, plaintext, aad)
            .map_err(to_py_err)?;
        Ok((PyBytes::new(py, &ct).into(), PyBytes::new(py, &*rik).into()))
    }

    /// Decrypt the first message (pre-ratchet).
    ///
    /// Args:
    ///     chain_key: 32-byte initial chain key from KEX.
    ///     encrypted_payload: Output of encrypt_first_message.
    ///     aad: Additional authenticated data used at encryption time.
    ///
    /// Returns:
    ///     Tuple of (plaintext, ratchet_init_key).
    #[staticmethod]
    fn decrypt_first_message<'py>(
        py: Python<'py>,
        chain_key: &[u8],
        encrypted_payload: &[u8],
        aad: &[u8],
    ) -> PyResult<(Py<PyBytes>, Py<PyBytes>)> {
        let ck = zeroizing_32("chain_key", chain_key)?;
        let (pt, rik) =
            soliton::ratchet::RatchetState::decrypt_first_message(ck, encrypted_payload, aad)
                .map_err(to_py_err)?;
        Ok((PyBytes::new(py, &pt).into(), PyBytes::new(py, &*rik).into()))
    }

    /// Serialize the ratchet state. Consumes the ratchet.
    ///
    /// The state is taken out before serialization is attempted, so the
    /// ratchet is consumed even if serialization fails.
    ///
    /// Returns:
    ///     Tuple of (blob, epoch). Persist the blob encrypted (e.g., with
    ///     StorageKeyRing). Store epoch separately for anti-rollback.
    // The Python-facing name mirrors `from_bytes`, but it takes `&mut self`
    // and empties `inner`, which clippy's naming lint flags.
    #[allow(clippy::wrong_self_convention)]
    fn to_bytes<'py>(&mut self, py: Python<'py>) -> PyResult<(Py<PyBytes>, u64)> {
        let state = self
            .inner
            .take()
            .ok_or_else(|| crate::errors::InvalidDataError::new_err("ratchet already consumed"))?;
        let (blob, epoch) = state.to_bytes().map_err(to_py_err)?;
        Ok((PyBytes::new(py, &blob).into(), epoch))
    }

    /// Deserialize ratchet state with anti-rollback protection.
    ///
    /// Args:
    ///     data: Serialized ratchet blob.
    ///     min_epoch: Minimum acceptable epoch. Use saved_epoch - 1.
    #[staticmethod]
    fn from_bytes(data: &[u8], min_epoch: u64) -> PyResult<Self> {
        let state = soliton::ratchet::RatchetState::from_bytes_with_min_epoch(data, min_epoch)
            .map_err(to_py_err)?;
        Ok(Self { inner: Some(state) })
    }

    /// Whether the ratchet can be serialized (counters not exhausted).
    fn can_serialize(&self) -> PyResult<bool> {
        Ok(self.live_state()?.can_serialize())
    }

    /// Current epoch number.
    fn epoch(&self) -> PyResult<u64> {
        Ok(self.live_state()?.epoch())
    }

    /// Reset the ratchet (zeroize all keys). The session is dead after this.
    fn reset(&mut self) -> PyResult<()> {
        self.live_state_mut()?.reset();
        Ok(())
    }

    /// Derive call keys for encrypted voice/video.
    ///
    /// Args:
    ///     kem_ss: 32-byte KEM shared secret.
    ///     call_id: 16-byte call identifier.
    ///
    /// Raises:
    ///     InvalidDataError: If the ratchet was consumed or closed.
    ///     InvalidLengthError: If kem_ss or call_id has the wrong length.
    fn derive_call_keys(&self, kem_ss: &[u8], call_id: &[u8]) -> PyResult<CallKeys> {
        let state = self.live_state()?;
        // Borrow as fixed-size array references (no copy of the secret).
        let ss: &[u8; 32] = kem_ss.try_into().map_err(|_| {
            crate::errors::InvalidLengthError::new_err("kem_ss must be 32 bytes")
        })?;
        let cid: &[u8; 16] = call_id.try_into().map_err(|_| {
            crate::errors::InvalidLengthError::new_err("call_id must be 16 bytes")
        })?;
        let keys = state.derive_call_keys(ss, cid).map_err(to_py_err)?;
        Ok(CallKeys::from_inner(keys))
    }

    /// Reset and drop the ratchet state.
    ///
    /// Safe to call more than once; later calls do nothing.
    fn close(&mut self) {
        if let Some(mut state) = self.inner.take() {
            state.reset();
        }
    }

    /// Context-manager entry: returns the ratchet itself.
    fn __enter__(slf: Py<Self>) -> Py<Self> {
        slf
    }

    /// Context-manager exit: closes the ratchet. Returns None, so any
    /// exception raised inside the `with` block still propagates.
    fn __exit__(
        &mut self,
        _exc_type: Option<&Bound<'_, PyAny>>,
        _exc_val: Option<&Bound<'_, PyAny>>,
        _exc_tb: Option<&Bound<'_, PyAny>>,
    ) {
        self.close();
    }
}

// Rust-only helpers. They live outside `#[pymethods]` so they are not
// exposed to Python.
impl Ratchet {
    /// Error raised when a method needs the state after `to_bytes()`/`close()`.
    fn closed_err() -> PyErr {
        crate::errors::InvalidDataError::new_err("ratchet consumed or closed")
    }

    /// Shared access to the live state, or the "consumed or closed" error.
    fn live_state(&self) -> PyResult<&soliton::ratchet::RatchetState> {
        self.inner.as_ref().ok_or_else(Self::closed_err)
    }

    /// Mutable access to the live state, or the "consumed or closed" error.
    fn live_state_mut(&mut self) -> PyResult<&mut soliton::ratchet::RatchetState> {
        self.inner.as_mut().ok_or_else(Self::closed_err)
    }
}
// ── Header serialization ────────────────────────────────────────────────
//
// Simple wire format for Python: ratchet_pk (1216) + has_kem_ct (1) +
// [kem_ct_len (2 BE) + kem_ct (1120) if present] + n (4 BE) + pn (4 BE).
// This matches the CAPI's encode_ratchet_header layout.
fn encode_header(h: &soliton::ratchet::RatchetHeader) -> Vec<u8> {
let pk_bytes = h.ratchet_pk.as_bytes();
let has_ct = h.kem_ct.is_some();
let size = 1216 + 1 + if has_ct { 2 + 1120 } else { 0 } + 4 + 4;
let mut buf = Vec::with_capacity(size);
buf.extend_from_slice(pk_bytes);
if let Some(ref ct) = h.kem_ct {
buf.push(0x01);
let ct_bytes = ct.as_bytes();
buf.extend_from_slice(&(ct_bytes.len() as u16).to_be_bytes());
buf.extend_from_slice(ct_bytes);
} else {
buf.push(0x00);
}
buf.extend_from_slice(&h.n.to_be_bytes());
buf.extend_from_slice(&h.pn.to_be_bytes());
buf
}
/// Parse the Python wire format written by `encode_header` back into a
/// `RatchetHeader`.
///
/// Parsing is strict. The kem_ct length prefix must equal the fixed
/// ciphertext size that is sliced out, and no bytes may follow the `pn`
/// counter.
///
/// # Errors
///
/// `InvalidDataError` for truncated input, an invalid `has_kem_ct` flag, a
/// wrong kem_ct length prefix, or trailing bytes. Failures while parsing the
/// public key or ciphertext are mapped through `to_py_err`.
fn decode_header(data: &[u8]) -> PyResult<soliton::ratchet::RatchetHeader> {
    // Field sizes of the wire format.
    const PK_LEN: usize = 1216;
    const CT_LEN_PREFIX: usize = 2;
    const CT_LEN: usize = 1120;
    const COUNTERS_LEN: usize = 4 + 4;

    if data.len() < PK_LEN + 1 + COUNTERS_LEN {
        return Err(crate::errors::InvalidDataError::new_err("header too short"));
    }
    let (pk_bytes, rest) = data.split_at(PK_LEN);
    let ratchet_pk = soliton::primitives::xwing::PublicKey::from_bytes(pk_bytes.to_vec())
        .map_err(to_py_err)?;

    // `rest` is non-empty: the length check above guarantees the flag byte.
    let (flag, rest) = (rest[0], &rest[1..]);
    let (kem_ct, counters) = match flag {
        0x00 => (None, rest),
        0x01 => {
            if rest.len() < CT_LEN_PREFIX + CT_LEN + COUNTERS_LEN {
                return Err(crate::errors::InvalidDataError::new_err(
                    "header too short for kem_ct",
                ));
            }
            let (len_prefix, rest) = rest.split_at(CT_LEN_PREFIX);
            // The encoder writes the ciphertext length. Reject any value that
            // disagrees with the fixed size sliced below, rather than ignoring it.
            let declared = u16::from_be_bytes([len_prefix[0], len_prefix[1]]);
            if usize::from(declared) != CT_LEN {
                return Err(crate::errors::InvalidDataError::new_err(
                    "invalid kem_ct length prefix",
                ));
            }
            let (ct_bytes, rest) = rest.split_at(CT_LEN);
            let ct = soliton::primitives::xwing::Ciphertext::from_bytes(ct_bytes.to_vec())
                .map_err(to_py_err)?;
            (Some(ct), rest)
        }
        _ => {
            return Err(crate::errors::InvalidDataError::new_err(
                "invalid has_kem_ct flag (expected 0x00 or 0x01)",
            ))
        }
    };

    if counters.len() < COUNTERS_LEN {
        return Err(crate::errors::InvalidDataError::new_err(
            "header missing counters",
        ));
    }
    if counters.len() > COUNTERS_LEN {
        return Err(crate::errors::InvalidDataError::new_err(
            "trailing bytes after header counters",
        ));
    }
    let (n_bytes, pn_bytes) = counters.split_at(4);
    let n = u32::from_be_bytes(n_bytes.try_into().expect("n is exactly 4 bytes"));
    let pn = u32::from_be_bytes(pn_bytes.try_into().expect("pn is exactly 4 bytes"));
    Ok(soliton::ratchet::RatchetHeader {
        ratchet_pk,
        kem_ct,
        n,
        pn,
    })
}
// ── Helpers ─────────────────────────────────────────────────────────────
/// Copy `data` into a fixed 32-byte array.
///
/// Raises `InvalidLengthError` naming the argument (`name`) when `data` is
/// not exactly 32 bytes long.
fn to_32(name: &str, data: &[u8]) -> PyResult<[u8; 32]> {
    match <[u8; 32]>::try_from(data) {
        Ok(array) => Ok(array),
        Err(_) => Err(crate::errors::InvalidLengthError::new_err(format!(
            "{name} must be 32 bytes"
        ))),
    }
}
fn zeroizing_32(name: &str, data: &[u8]) -> PyResult<zeroize::Zeroizing<[u8; 32]>> {
let arr = to_32(name, data)?;
Ok(zeroize::Zeroizing::new(arr))
}
/// Add the `Ratchet` class to the given Python module.
pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<Ratchet>()
}