diff --git a/crates/common/src/zk_aes.rs b/crates/common/src/zk_aes.rs index ecc02035a..6c7a403a8 100644 --- a/crates/common/src/zk_aes.rs +++ b/crates/common/src/zk_aes.rs @@ -54,18 +54,13 @@ impl ZkAesCtr { let input = vm.alloc_vec::(len).map_err(ZkAesCtrError::vm)?; let keystream = self.aes.alloc_keystream(vm, len)?; - let output = keystream.apply(vm, input)?; match self.role { Role::Prover => vm.mark_private(input).map_err(ZkAesCtrError::vm)?, Role::Verifier => vm.mark_blind(input).map_err(ZkAesCtrError::vm)?, } - self.state = State::Ready { - input, - keystream, - output, - }; + self.state = State::Ready { input, keystream }; Ok(()) } @@ -96,12 +91,7 @@ impl ZkAesCtr { explicit_nonce: Vec, len: usize, ) -> Result<(Vector, Vector), ZkAesCtrError> { - let State::Ready { - input, - keystream, - output, - } = &mut self.state - else { + let State::Ready { input, keystream } = &mut self.state else { Err(ErrorRepr::State { reason: "must be in ready state to encrypt", })? @@ -128,7 +118,7 @@ impl ZkAesCtr { let mut input = input.split_off(input.len() - padded_len); let keystream = keystream.consume(padded_len)?; - let mut output = output.split_off(output.len() - padded_len); + let mut output = keystream.apply(vm, input)?; // Assign counter block inputs. let mut ctr = START_CTR..; @@ -158,7 +148,6 @@ enum State { Ready { input: Vector, keystream: Keystream, - output: Vector, }, Error, } diff --git a/crates/components/cipher/src/lib.rs b/crates/components/cipher/src/lib.rs index b824cb049..9f7b2b3f5 100644 --- a/crates/components/cipher/src/lib.rs +++ b/crates/components/cipher/src/lib.rs @@ -15,9 +15,9 @@ use async_trait::async_trait; use mpz_circuits::circuits::xor; use mpz_memory_core::{ binary::{Binary, U8}, - FromRaw, MemoryExt, Repr, Slice, StaticSize, ToRaw, Vector, + MemoryExt, Repr, Slice, StaticSize, ToRaw, Vector, }; -use mpz_vm_core::{prelude::*, CallBuilder, CallError, Vm}; +use mpz_vm_core::{prelude::*, Call, CallBuilder, CallError, Vm}; use std::{collections::VecDeque, sync::Arc}; /// Provides computation of 2PC ciphers in counter and ECB mode. @@ -116,25 +116,7 @@ where O: Repr + StaticSize + Copy, { /// Creates a new keystream from the provided blocks. - /// - /// # Panics - /// - /// * If the output of the keystream is not ordered and contiguous in - /// memory. pub fn new(blocks: &[CtrBlock]) -> Self { - let mut pos = blocks - .first() - .map(|block| block.output.to_raw().ptr().as_usize()) - .unwrap_or(0); - - for block in blocks { - if block.output.to_raw().ptr().as_usize() != pos { - panic!("output of keystream blocks must be ordered and contiguous in memory"); - } - - pos += O::SIZE; - } - Self { blocks: VecDeque::from_iter(blocks.iter().copied()), } @@ -177,7 +159,7 @@ where return Err(CipherError::new("no keystream material available")); } - let xor = Arc::new(xor(8 * self.block_size())); + let xor = Arc::new(xor(self.block_size() * 8)); let mut pos = 0; let mut outputs = Vec::with_capacity(self.blocks.len()); for block in &self.blocks { @@ -194,20 +176,17 @@ where pos += self.block_size(); } - // Calls were performed contiguously, so the output data is contiguous. - let ptr = outputs - .first() - .map(|output| output.to_raw().ptr()) - .expect("keystream is not empty"); - let size = self.blocks.len() * O::SIZE; - - let output = Vector::::from_raw(Slice::new_unchecked(ptr, size)); + let output = flatten_blocks(vm, outputs.iter().map(|block| block.to_raw()))?; Ok(output) } /// Returns `len` bytes of the keystream as a vector. - pub fn to_vector(&self, len: usize) -> Result, CipherError> { + pub fn to_vector( + &self, + vm: &mut dyn Vm, + len: usize, + ) -> Result, CipherError> { if len == 0 { return Err(CipherError::new("length must be greater than 0")); } else if self.blocks.is_empty() { @@ -219,14 +198,8 @@ where return Err(CipherError::new("length does not match keystream length")); } - let ptr = self - .blocks - .front() - .map(|block| block.output.to_raw().ptr()) - .expect("block count should be greater than 0"); - let size = block_count * O::SIZE; - - let mut keystream = Vector::::from_raw(Slice::new_unchecked(ptr, size)); + let mut keystream = + flatten_blocks(vm, self.blocks.iter().map(|block| block.output.to_raw()))?; keystream.truncate(len); Ok(keystream) @@ -272,6 +245,34 @@ where } } +fn flatten_blocks( + vm: &mut dyn Vm, + blocks: impl IntoIterator, +) -> Result, CipherError> { + use mpz_circuits::CircuitBuilder; + + let blocks = blocks.into_iter().collect::>(); + let len: usize = blocks.iter().map(|block| block.len()).sum(); + + let mut builder = CircuitBuilder::new(); + for _ in 0..len { + let i = builder.add_input(); + let o = builder.add_id_gate(i); + builder.add_output(o); + } + + let circuit = builder.build().expect("flatten circuit should be valid"); + + let mut builder = Call::builder(Arc::new(circuit)); + for block in blocks { + builder = builder.arg(block); + } + + let call = builder.build().map_err(CipherError::new)?; + + vm.call(call).map_err(CipherError::new) +} + /// A cipher error. #[derive(Debug, thiserror::Error)] #[error("{source}")] diff --git a/crates/components/deap/src/lib.rs b/crates/components/deap/src/lib.rs index b0e88a99d..03598cd24 100644 --- a/crates/components/deap/src/lib.rs +++ b/crates/components/deap/src/lib.rs @@ -12,7 +12,7 @@ use async_trait::async_trait; use mpz_common::Context; use mpz_core::bitvec::BitVec; use mpz_vm_core::{ - memory::{binary::Binary, DecodeFuture, Memory, Slice, View}, + memory::{binary::Binary, DecodeFuture, Memory, Repr, Slice, View}, Call, Callable, Execute, Vm, VmError, }; use rangeset::{Difference, RangeSet, UnionMut}; @@ -85,10 +85,10 @@ impl Deap { self.zk.clone().try_lock_owned().unwrap() } - /// Translates a slice from the MPC VM address space to the ZK VM address + /// Translates a value from the MPC VM address space to the ZK VM address /// space. - pub fn translate_slice(&self, slice: Slice) -> Result { - self.memory_map.try_get(slice) + pub fn translate>(&self, value: T) -> Result { + self.memory_map.try_get(value.to_raw()).map(T::from_raw) } #[cfg(test)] diff --git a/crates/mpc-tls/src/record_layer/aead/aes_gcm.rs b/crates/mpc-tls/src/record_layer/aead/aes_gcm.rs index 72e4362ea..8e678dc75 100644 --- a/crates/mpc-tls/src/record_layer/aead/aes_gcm.rs +++ b/crates/mpc-tls/src/record_layer/aead/aes_gcm.rs @@ -33,7 +33,6 @@ enum State { input: Vector, keystream: Keystream, j0s: Vec<(CtrBlock, OneTimePadShared<[u8; 16]>)>, - output: Vector, ghash_key: OneTimePadShared<[u8; 16]>, ghash: Box, }, @@ -41,7 +40,6 @@ enum State { input: Vector, keystream: Keystream, j0s: Vec<(CtrBlock, OneTimePadShared<[u8; 16]>)>, - output: Vector, ghash: Arc, }, Error, @@ -125,13 +123,11 @@ impl MpcAesGcm { } let keystream = self.aes.alloc_keystream(vm, len)?; - let output = keystream.apply(vm, input)?; self.state = State::Setup { input, keystream, j0s, - output, ghash, ghash_key, }; @@ -162,7 +158,6 @@ impl MpcAesGcm { input, keystream, j0s, - output, mut ghash, ghash_key, } = self.state.take() @@ -178,7 +173,6 @@ impl MpcAesGcm { input, keystream, j0s, - output, ghash: Arc::from(ghash), }; @@ -202,10 +196,7 @@ impl MpcAesGcm { len: usize, ) -> Result<(Vector, Vector), AeadError> { let State::Ready { - input, - keystream, - output, - .. + input, keystream, .. } = &mut self.state else { return Err(AeadError::state( @@ -235,7 +226,7 @@ impl MpcAesGcm { let mut input = input.split_off(input.len() - padded_len); let keystream = keystream.consume(padded_len)?; - let mut output = output.split_off(output.len() - padded_len); + let mut output = keystream.apply(vm, input)?; // Assign counter block inputs. let mut ctr = START_CTR..; @@ -273,10 +264,7 @@ impl MpcAesGcm { len: usize, ) -> Result, AeadError> { let State::Ready { - input, - keystream, - output, - .. + input, keystream, .. } = &mut self.state else { return Err(AeadError::state("must be in ready state to take keystream")); @@ -301,11 +289,6 @@ impl MpcAesGcm { ))); } - // Drop the input and output text, we won't be needing them. - // This leaves them allocated but unassigned in the VM. - _ = input.split_off(input.len() - padded_len); - _ = output.split_off(output.len() - padded_len); - let keystream = keystream.consume(len)?; // Assign counter block inputs. @@ -314,7 +297,7 @@ impl MpcAesGcm { ctr.next().expect("range is unbounded").to_be_bytes() })?; - Ok(keystream.to_vector(len)?) + Ok(keystream.to_vector(vm, len)?) } /// Computes tags for the provided ciphertext. See diff --git a/crates/prover/src/lib.rs b/crates/prover/src/lib.rs index f7804e969..51b2b6bb7 100644 --- a/crates/prover/src/lib.rs +++ b/crates/prover/src/lib.rs @@ -20,7 +20,7 @@ use mpz_garble_core::Delta; use state::{Notarize, Prove}; use futures::{AsyncRead, AsyncWrite, TryFutureExt}; -use mpc_tls::{LeaderCtrl, MpcTlsLeader}; +use mpc_tls::{LeaderCtrl, MpcTlsLeader, SessionKeys}; use rand::Rng; use serio::SinkExt; use std::sync::Arc; @@ -28,7 +28,12 @@ use tls_client::{ClientConnection, ServerName as TlsServerName}; use tls_client_async::{bind_client, TlsConnection}; use tls_core::msgs::enums::ContentType; use tlsn_common::{ - commit::commit_records, context::build_mt_context, mux::attach_mux, zk_aes::ZkAesCtr, Role, + commit::commit_records, + context::build_mt_context, + mux::attach_mux, + transcript::{Record, TlsTranscript}, + zk_aes::ZkAesCtr, + Role, }; use tlsn_core::{ connection::{ @@ -103,7 +108,9 @@ impl Prover { let (vm, mut mpc_tls) = build_mpc_tls(&self.config, ctx); // Allocate resources for MPC-TLS in VM. - let keys = mpc_tls.alloc()?; + let mut keys = mpc_tls.alloc()?; + translate_keys(&mut keys, &vm.try_lock().expect("VM is not locked"))?; + // Allocate for committing to plaintext. let mut zk_aes = ZkAesCtr::new(Role::Prover); zk_aes.set_key(keys.server_write_key, keys.server_write_iv); @@ -205,6 +212,8 @@ impl Prover { { let mut vm = vm.try_lock().expect("VM should not be locked"); + translate_transcript(&mut data.transcript, &vm)?; + // Prove received plaintext. Prover drops the proof output, as they trust // themselves. _ = commit_records( @@ -414,3 +423,36 @@ impl ProverControl { .map_err(ProverError::from) } } + +/// Translates VM references to the ZK address space. +fn translate_keys(keys: &mut SessionKeys, vm: &Deap) -> Result<(), ProverError> { + keys.client_write_key = vm + .translate(keys.client_write_key) + .map_err(ProverError::mpc)?; + keys.client_write_iv = vm + .translate(keys.client_write_iv) + .map_err(ProverError::mpc)?; + keys.server_write_key = vm + .translate(keys.server_write_key) + .map_err(ProverError::mpc)?; + keys.server_write_iv = vm + .translate(keys.server_write_iv) + .map_err(ProverError::mpc)?; + + Ok(()) +} + +/// Translates VM references to the ZK address space. +fn translate_transcript( + transcript: &mut TlsTranscript, + vm: &Deap, +) -> Result<(), ProverError> { + for Record { plaintext_ref, .. } in transcript.sent.iter_mut().chain(transcript.recv.iter_mut()) + { + if let Some(plaintext_ref) = plaintext_ref.as_mut() { + *plaintext_ref = vm.translate(*plaintext_ref).map_err(ProverError::mpc)?; + } + } + + Ok(()) +} diff --git a/crates/verifier/src/lib.rs b/crates/verifier/src/lib.rs index 34a44752d..69763dae9 100644 --- a/crates/verifier/src/lib.rs +++ b/crates/verifier/src/lib.rs @@ -16,7 +16,7 @@ pub use config::{VerifierConfig, VerifierConfigBuilder, VerifierConfigBuilderErr pub use error::VerifierError; use futures::{AsyncRead, AsyncWrite}; -use mpc_tls::{FollowerData, MpcTlsFollower}; +use mpc_tls::{FollowerData, MpcTlsFollower, SessionKeys}; use mpz_common::Context; use mpz_core::Block; use mpz_garble_core::Delta; @@ -24,8 +24,13 @@ use serio::stream::IoStreamExt; use state::{Notarize, Verify}; use tls_core::msgs::enums::ContentType; use tlsn_common::{ - commit::commit_records, config::ProtocolConfig, context::build_mt_context, mux::attach_mux, - zk_aes::ZkAesCtr, Role, + commit::commit_records, + config::ProtocolConfig, + context::build_mt_context, + mux::attach_mux, + transcript::{Record, TlsTranscript}, + zk_aes::ZkAesCtr, + Role, }; use tlsn_core::{ attestation::{Attestation, AttestationConfig}, @@ -110,7 +115,9 @@ impl Verifier { let (vm, mut mpc_tls) = build_mpc_tls(&self.config, &protocol_config, delta, ctx); // Allocate resources for MPC-TLS in VM. - let keys = mpc_tls.alloc()?; + let mut keys = mpc_tls.alloc()?; + translate_keys(&mut keys, &vm.try_lock().expect("VM is not locked"))?; + // Allocate for committing to plaintext. let mut zk_aes = ZkAesCtr::new(Role::Verifier); zk_aes.set_key(keys.server_write_key, keys.server_write_iv); @@ -222,6 +229,8 @@ impl Verifier { { let mut vm = vm.try_lock().expect("VM should not be locked"); + translate_transcript(&mut transcript, &vm)?; + // Prepare for the prover to prove received plaintext. let proof = commit_records( &mut (*vm.zk()), @@ -370,3 +379,39 @@ fn build_mpc_tls( ), ) } + +/// Translates VM references to the ZK address space. +fn translate_keys( + keys: &mut SessionKeys, + vm: &Deap, +) -> Result<(), VerifierError> { + keys.client_write_key = vm + .translate(keys.client_write_key) + .map_err(VerifierError::mpc)?; + keys.client_write_iv = vm + .translate(keys.client_write_iv) + .map_err(VerifierError::mpc)?; + keys.server_write_key = vm + .translate(keys.server_write_key) + .map_err(VerifierError::mpc)?; + keys.server_write_iv = vm + .translate(keys.server_write_iv) + .map_err(VerifierError::mpc)?; + + Ok(()) +} + +/// Translates VM references to the ZK address space. +fn translate_transcript( + transcript: &mut TlsTranscript, + vm: &Deap, +) -> Result<(), VerifierError> { + for Record { plaintext_ref, .. } in transcript.sent.iter_mut().chain(transcript.recv.iter_mut()) + { + if let Some(plaintext_ref) = plaintext_ref.as_mut() { + *plaintext_ref = vm.translate(*plaintext_ref).map_err(VerifierError::mpc)?; + } + } + + Ok(()) +}