refactor(cipher): remove contiguous memory assumption (#864)

* refactor(cipher): remove contiguous memory assumption

* fix mpc-tls and upstream crates
This commit is contained in:
sinu.eth
2025-05-13 18:41:55 +02:00
committed by GitHub
parent a8bf1026ca
commit 5a188e75c7
6 changed files with 144 additions and 84 deletions

View File

@@ -54,18 +54,13 @@ impl ZkAesCtr {
let input = vm.alloc_vec::<U8>(len).map_err(ZkAesCtrError::vm)?;
let keystream = self.aes.alloc_keystream(vm, len)?;
let output = keystream.apply(vm, input)?;
match self.role {
Role::Prover => vm.mark_private(input).map_err(ZkAesCtrError::vm)?,
Role::Verifier => vm.mark_blind(input).map_err(ZkAesCtrError::vm)?,
}
self.state = State::Ready {
input,
keystream,
output,
};
self.state = State::Ready { input, keystream };
Ok(())
}
@@ -96,12 +91,7 @@ impl ZkAesCtr {
explicit_nonce: Vec<u8>,
len: usize,
) -> Result<(Vector<U8>, Vector<U8>), ZkAesCtrError> {
let State::Ready {
input,
keystream,
output,
} = &mut self.state
else {
let State::Ready { input, keystream } = &mut self.state else {
Err(ErrorRepr::State {
reason: "must be in ready state to encrypt",
})?
@@ -128,7 +118,7 @@ impl ZkAesCtr {
let mut input = input.split_off(input.len() - padded_len);
let keystream = keystream.consume(padded_len)?;
let mut output = output.split_off(output.len() - padded_len);
let mut output = keystream.apply(vm, input)?;
// Assign counter block inputs.
let mut ctr = START_CTR..;
@@ -158,7 +148,6 @@ enum State {
Ready {
input: Vector<U8>,
keystream: Keystream<Nonce, Ctr, Block>,
output: Vector<U8>,
},
Error,
}

View File

@@ -15,9 +15,9 @@ use async_trait::async_trait;
use mpz_circuits::circuits::xor;
use mpz_memory_core::{
binary::{Binary, U8},
FromRaw, MemoryExt, Repr, Slice, StaticSize, ToRaw, Vector,
MemoryExt, Repr, Slice, StaticSize, ToRaw, Vector,
};
use mpz_vm_core::{prelude::*, CallBuilder, CallError, Vm};
use mpz_vm_core::{prelude::*, Call, CallBuilder, CallError, Vm};
use std::{collections::VecDeque, sync::Arc};
/// Provides computation of 2PC ciphers in counter and ECB mode.
@@ -116,25 +116,7 @@ where
O: Repr<Binary> + StaticSize<Binary> + Copy,
{
/// Creates a new keystream from the provided blocks.
///
/// # Panics
///
/// * If the output of the keystream is not ordered and contiguous in
/// memory.
pub fn new(blocks: &[CtrBlock<N, C, O>]) -> Self {
let mut pos = blocks
.first()
.map(|block| block.output.to_raw().ptr().as_usize())
.unwrap_or(0);
for block in blocks {
if block.output.to_raw().ptr().as_usize() != pos {
panic!("output of keystream blocks must be ordered and contiguous in memory");
}
pos += O::SIZE;
}
Self {
blocks: VecDeque::from_iter(blocks.iter().copied()),
}
@@ -177,7 +159,7 @@ where
return Err(CipherError::new("no keystream material available"));
}
let xor = Arc::new(xor(8 * self.block_size()));
let xor = Arc::new(xor(self.block_size() * 8));
let mut pos = 0;
let mut outputs = Vec::with_capacity(self.blocks.len());
for block in &self.blocks {
@@ -194,20 +176,17 @@ where
pos += self.block_size();
}
// Calls were performed contiguously, so the output data is contiguous.
let ptr = outputs
.first()
.map(|output| output.to_raw().ptr())
.expect("keystream is not empty");
let size = self.blocks.len() * O::SIZE;
let output = Vector::<U8>::from_raw(Slice::new_unchecked(ptr, size));
let output = flatten_blocks(vm, outputs.iter().map(|block| block.to_raw()))?;
Ok(output)
}
/// Returns `len` bytes of the keystream as a vector.
pub fn to_vector(&self, len: usize) -> Result<Vector<U8>, CipherError> {
pub fn to_vector(
&self,
vm: &mut dyn Vm<Binary>,
len: usize,
) -> Result<Vector<U8>, CipherError> {
if len == 0 {
return Err(CipherError::new("length must be greater than 0"));
} else if self.blocks.is_empty() {
@@ -219,14 +198,8 @@ where
return Err(CipherError::new("length does not match keystream length"));
}
let ptr = self
.blocks
.front()
.map(|block| block.output.to_raw().ptr())
.expect("block count should be greater than 0");
let size = block_count * O::SIZE;
let mut keystream = Vector::<U8>::from_raw(Slice::new_unchecked(ptr, size));
let mut keystream =
flatten_blocks(vm, self.blocks.iter().map(|block| block.output.to_raw()))?;
keystream.truncate(len);
Ok(keystream)
@@ -272,6 +245,34 @@ where
}
}
fn flatten_blocks(
vm: &mut dyn Vm<Binary>,
blocks: impl IntoIterator<Item = Slice>,
) -> Result<Vector<U8>, CipherError> {
use mpz_circuits::CircuitBuilder;
let blocks = blocks.into_iter().collect::<Vec<_>>();
let len: usize = blocks.iter().map(|block| block.len()).sum();
let mut builder = CircuitBuilder::new();
for _ in 0..len {
let i = builder.add_input();
let o = builder.add_id_gate(i);
builder.add_output(o);
}
let circuit = builder.build().expect("flatten circuit should be valid");
let mut builder = Call::builder(Arc::new(circuit));
for block in blocks {
builder = builder.arg(block);
}
let call = builder.build().map_err(CipherError::new)?;
vm.call(call).map_err(CipherError::new)
}
/// A cipher error.
#[derive(Debug, thiserror::Error)]
#[error("{source}")]

View File

@@ -12,7 +12,7 @@ use async_trait::async_trait;
use mpz_common::Context;
use mpz_core::bitvec::BitVec;
use mpz_vm_core::{
memory::{binary::Binary, DecodeFuture, Memory, Slice, View},
memory::{binary::Binary, DecodeFuture, Memory, Repr, Slice, View},
Call, Callable, Execute, Vm, VmError,
};
use rangeset::{Difference, RangeSet, UnionMut};
@@ -85,10 +85,10 @@ impl<Mpc, Zk> Deap<Mpc, Zk> {
self.zk.clone().try_lock_owned().unwrap()
}
/// Translates a slice from the MPC VM address space to the ZK VM address
/// Translates a value from the MPC VM address space to the ZK VM address
/// space.
pub fn translate_slice(&self, slice: Slice) -> Result<Slice, VmError> {
self.memory_map.try_get(slice)
pub fn translate<T: Repr<Binary>>(&self, value: T) -> Result<T, VmError> {
self.memory_map.try_get(value.to_raw()).map(T::from_raw)
}
#[cfg(test)]

View File

@@ -33,7 +33,6 @@ enum State {
input: Vector<U8>,
keystream: Keystream<Nonce, Ctr, Block>,
j0s: Vec<(CtrBlock<Nonce, Ctr, Block>, OneTimePadShared<[u8; 16]>)>,
output: Vector<U8>,
ghash_key: OneTimePadShared<[u8; 16]>,
ghash: Box<dyn Ghash + Send + Sync>,
},
@@ -41,7 +40,6 @@ enum State {
input: Vector<U8>,
keystream: Keystream<Nonce, Ctr, Block>,
j0s: Vec<(CtrBlock<Nonce, Ctr, Block>, OneTimePadShared<[u8; 16]>)>,
output: Vector<U8>,
ghash: Arc<dyn Ghash + Send + Sync>,
},
Error,
@@ -125,13 +123,11 @@ impl MpcAesGcm {
}
let keystream = self.aes.alloc_keystream(vm, len)?;
let output = keystream.apply(vm, input)?;
self.state = State::Setup {
input,
keystream,
j0s,
output,
ghash,
ghash_key,
};
@@ -162,7 +158,6 @@ impl MpcAesGcm {
input,
keystream,
j0s,
output,
mut ghash,
ghash_key,
} = self.state.take()
@@ -178,7 +173,6 @@ impl MpcAesGcm {
input,
keystream,
j0s,
output,
ghash: Arc::from(ghash),
};
@@ -202,10 +196,7 @@ impl MpcAesGcm {
len: usize,
) -> Result<(Vector<U8>, Vector<U8>), AeadError> {
let State::Ready {
input,
keystream,
output,
..
input, keystream, ..
} = &mut self.state
else {
return Err(AeadError::state(
@@ -235,7 +226,7 @@ impl MpcAesGcm {
let mut input = input.split_off(input.len() - padded_len);
let keystream = keystream.consume(padded_len)?;
let mut output = output.split_off(output.len() - padded_len);
let mut output = keystream.apply(vm, input)?;
// Assign counter block inputs.
let mut ctr = START_CTR..;
@@ -273,10 +264,7 @@ impl MpcAesGcm {
len: usize,
) -> Result<Vector<U8>, AeadError> {
let State::Ready {
input,
keystream,
output,
..
input, keystream, ..
} = &mut self.state
else {
return Err(AeadError::state("must be in ready state to take keystream"));
@@ -301,11 +289,6 @@ impl MpcAesGcm {
)));
}
// Drop the input and output text, we won't be needing them.
// This leaves them allocated but unassigned in the VM.
_ = input.split_off(input.len() - padded_len);
_ = output.split_off(output.len() - padded_len);
let keystream = keystream.consume(len)?;
// Assign counter block inputs.
@@ -314,7 +297,7 @@ impl MpcAesGcm {
ctr.next().expect("range is unbounded").to_be_bytes()
})?;
Ok(keystream.to_vector(len)?)
Ok(keystream.to_vector(vm, len)?)
}
/// Computes tags for the provided ciphertext. See

View File

@@ -20,7 +20,7 @@ use mpz_garble_core::Delta;
use state::{Notarize, Prove};
use futures::{AsyncRead, AsyncWrite, TryFutureExt};
use mpc_tls::{LeaderCtrl, MpcTlsLeader};
use mpc_tls::{LeaderCtrl, MpcTlsLeader, SessionKeys};
use rand::Rng;
use serio::SinkExt;
use std::sync::Arc;
@@ -28,7 +28,12 @@ use tls_client::{ClientConnection, ServerName as TlsServerName};
use tls_client_async::{bind_client, TlsConnection};
use tls_core::msgs::enums::ContentType;
use tlsn_common::{
commit::commit_records, context::build_mt_context, mux::attach_mux, zk_aes::ZkAesCtr, Role,
commit::commit_records,
context::build_mt_context,
mux::attach_mux,
transcript::{Record, TlsTranscript},
zk_aes::ZkAesCtr,
Role,
};
use tlsn_core::{
connection::{
@@ -103,7 +108,9 @@ impl Prover<state::Initialized> {
let (vm, mut mpc_tls) = build_mpc_tls(&self.config, ctx);
// Allocate resources for MPC-TLS in VM.
let keys = mpc_tls.alloc()?;
let mut keys = mpc_tls.alloc()?;
translate_keys(&mut keys, &vm.try_lock().expect("VM is not locked"))?;
// Allocate for committing to plaintext.
let mut zk_aes = ZkAesCtr::new(Role::Prover);
zk_aes.set_key(keys.server_write_key, keys.server_write_iv);
@@ -205,6 +212,8 @@ impl Prover<state::Setup> {
{
let mut vm = vm.try_lock().expect("VM should not be locked");
translate_transcript(&mut data.transcript, &vm)?;
// Prove received plaintext. Prover drops the proof output, as they trust
// themselves.
_ = commit_records(
@@ -414,3 +423,36 @@ impl ProverControl {
.map_err(ProverError::from)
}
}
/// Translates VM references to the ZK address space.
fn translate_keys<Mpc, Zk>(keys: &mut SessionKeys, vm: &Deap<Mpc, Zk>) -> Result<(), ProverError> {
keys.client_write_key = vm
.translate(keys.client_write_key)
.map_err(ProverError::mpc)?;
keys.client_write_iv = vm
.translate(keys.client_write_iv)
.map_err(ProverError::mpc)?;
keys.server_write_key = vm
.translate(keys.server_write_key)
.map_err(ProverError::mpc)?;
keys.server_write_iv = vm
.translate(keys.server_write_iv)
.map_err(ProverError::mpc)?;
Ok(())
}
/// Translates VM references to the ZK address space.
fn translate_transcript<Mpc, Zk>(
transcript: &mut TlsTranscript,
vm: &Deap<Mpc, Zk>,
) -> Result<(), ProverError> {
for Record { plaintext_ref, .. } in transcript.sent.iter_mut().chain(transcript.recv.iter_mut())
{
if let Some(plaintext_ref) = plaintext_ref.as_mut() {
*plaintext_ref = vm.translate(*plaintext_ref).map_err(ProverError::mpc)?;
}
}
Ok(())
}

View File

@@ -16,7 +16,7 @@ pub use config::{VerifierConfig, VerifierConfigBuilder, VerifierConfigBuilderErr
pub use error::VerifierError;
use futures::{AsyncRead, AsyncWrite};
use mpc_tls::{FollowerData, MpcTlsFollower};
use mpc_tls::{FollowerData, MpcTlsFollower, SessionKeys};
use mpz_common::Context;
use mpz_core::Block;
use mpz_garble_core::Delta;
@@ -24,8 +24,13 @@ use serio::stream::IoStreamExt;
use state::{Notarize, Verify};
use tls_core::msgs::enums::ContentType;
use tlsn_common::{
commit::commit_records, config::ProtocolConfig, context::build_mt_context, mux::attach_mux,
zk_aes::ZkAesCtr, Role,
commit::commit_records,
config::ProtocolConfig,
context::build_mt_context,
mux::attach_mux,
transcript::{Record, TlsTranscript},
zk_aes::ZkAesCtr,
Role,
};
use tlsn_core::{
attestation::{Attestation, AttestationConfig},
@@ -110,7 +115,9 @@ impl Verifier<state::Initialized> {
let (vm, mut mpc_tls) = build_mpc_tls(&self.config, &protocol_config, delta, ctx);
// Allocate resources for MPC-TLS in VM.
let keys = mpc_tls.alloc()?;
let mut keys = mpc_tls.alloc()?;
translate_keys(&mut keys, &vm.try_lock().expect("VM is not locked"))?;
// Allocate for committing to plaintext.
let mut zk_aes = ZkAesCtr::new(Role::Verifier);
zk_aes.set_key(keys.server_write_key, keys.server_write_iv);
@@ -222,6 +229,8 @@ impl Verifier<state::Setup> {
{
let mut vm = vm.try_lock().expect("VM should not be locked");
translate_transcript(&mut transcript, &vm)?;
// Prepare for the prover to prove received plaintext.
let proof = commit_records(
&mut (*vm.zk()),
@@ -370,3 +379,39 @@ fn build_mpc_tls(
),
)
}
/// Translates VM references to the ZK address space.
fn translate_keys<Mpc, Zk>(
keys: &mut SessionKeys,
vm: &Deap<Mpc, Zk>,
) -> Result<(), VerifierError> {
keys.client_write_key = vm
.translate(keys.client_write_key)
.map_err(VerifierError::mpc)?;
keys.client_write_iv = vm
.translate(keys.client_write_iv)
.map_err(VerifierError::mpc)?;
keys.server_write_key = vm
.translate(keys.server_write_key)
.map_err(VerifierError::mpc)?;
keys.server_write_iv = vm
.translate(keys.server_write_iv)
.map_err(VerifierError::mpc)?;
Ok(())
}
/// Translates VM references to the ZK address space.
fn translate_transcript<Mpc, Zk>(
transcript: &mut TlsTranscript,
vm: &Deap<Mpc, Zk>,
) -> Result<(), VerifierError> {
for Record { plaintext_ref, .. } in transcript.sent.iter_mut().chain(transcript.recv.iter_mut())
{
if let Some(plaintext_ref) = plaintext_ref.as_mut() {
*plaintext_ref = vm.translate(*plaintext_ref).map_err(VerifierError::mpc)?;
}
}
Ok(())
}