mirror of
https://github.com/tlsnotary/tlsn.git
synced 2026-01-08 21:08:04 -05:00
refactor(cipher): remove contiguous memory assumption (#864)
* refactor(cipher): remove contiguous memory assumption * fix mpc-tls and upstream crates
This commit is contained in:
@@ -54,18 +54,13 @@ impl ZkAesCtr {
|
||||
|
||||
let input = vm.alloc_vec::<U8>(len).map_err(ZkAesCtrError::vm)?;
|
||||
let keystream = self.aes.alloc_keystream(vm, len)?;
|
||||
let output = keystream.apply(vm, input)?;
|
||||
|
||||
match self.role {
|
||||
Role::Prover => vm.mark_private(input).map_err(ZkAesCtrError::vm)?,
|
||||
Role::Verifier => vm.mark_blind(input).map_err(ZkAesCtrError::vm)?,
|
||||
}
|
||||
|
||||
self.state = State::Ready {
|
||||
input,
|
||||
keystream,
|
||||
output,
|
||||
};
|
||||
self.state = State::Ready { input, keystream };
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -96,12 +91,7 @@ impl ZkAesCtr {
|
||||
explicit_nonce: Vec<u8>,
|
||||
len: usize,
|
||||
) -> Result<(Vector<U8>, Vector<U8>), ZkAesCtrError> {
|
||||
let State::Ready {
|
||||
input,
|
||||
keystream,
|
||||
output,
|
||||
} = &mut self.state
|
||||
else {
|
||||
let State::Ready { input, keystream } = &mut self.state else {
|
||||
Err(ErrorRepr::State {
|
||||
reason: "must be in ready state to encrypt",
|
||||
})?
|
||||
@@ -128,7 +118,7 @@ impl ZkAesCtr {
|
||||
|
||||
let mut input = input.split_off(input.len() - padded_len);
|
||||
let keystream = keystream.consume(padded_len)?;
|
||||
let mut output = output.split_off(output.len() - padded_len);
|
||||
let mut output = keystream.apply(vm, input)?;
|
||||
|
||||
// Assign counter block inputs.
|
||||
let mut ctr = START_CTR..;
|
||||
@@ -158,7 +148,6 @@ enum State {
|
||||
Ready {
|
||||
input: Vector<U8>,
|
||||
keystream: Keystream<Nonce, Ctr, Block>,
|
||||
output: Vector<U8>,
|
||||
},
|
||||
Error,
|
||||
}
|
||||
|
||||
@@ -15,9 +15,9 @@ use async_trait::async_trait;
|
||||
use mpz_circuits::circuits::xor;
|
||||
use mpz_memory_core::{
|
||||
binary::{Binary, U8},
|
||||
FromRaw, MemoryExt, Repr, Slice, StaticSize, ToRaw, Vector,
|
||||
MemoryExt, Repr, Slice, StaticSize, ToRaw, Vector,
|
||||
};
|
||||
use mpz_vm_core::{prelude::*, CallBuilder, CallError, Vm};
|
||||
use mpz_vm_core::{prelude::*, Call, CallBuilder, CallError, Vm};
|
||||
use std::{collections::VecDeque, sync::Arc};
|
||||
|
||||
/// Provides computation of 2PC ciphers in counter and ECB mode.
|
||||
@@ -116,25 +116,7 @@ where
|
||||
O: Repr<Binary> + StaticSize<Binary> + Copy,
|
||||
{
|
||||
/// Creates a new keystream from the provided blocks.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// * If the output of the keystream is not ordered and contiguous in
|
||||
/// memory.
|
||||
pub fn new(blocks: &[CtrBlock<N, C, O>]) -> Self {
|
||||
let mut pos = blocks
|
||||
.first()
|
||||
.map(|block| block.output.to_raw().ptr().as_usize())
|
||||
.unwrap_or(0);
|
||||
|
||||
for block in blocks {
|
||||
if block.output.to_raw().ptr().as_usize() != pos {
|
||||
panic!("output of keystream blocks must be ordered and contiguous in memory");
|
||||
}
|
||||
|
||||
pos += O::SIZE;
|
||||
}
|
||||
|
||||
Self {
|
||||
blocks: VecDeque::from_iter(blocks.iter().copied()),
|
||||
}
|
||||
@@ -177,7 +159,7 @@ where
|
||||
return Err(CipherError::new("no keystream material available"));
|
||||
}
|
||||
|
||||
let xor = Arc::new(xor(8 * self.block_size()));
|
||||
let xor = Arc::new(xor(self.block_size() * 8));
|
||||
let mut pos = 0;
|
||||
let mut outputs = Vec::with_capacity(self.blocks.len());
|
||||
for block in &self.blocks {
|
||||
@@ -194,20 +176,17 @@ where
|
||||
pos += self.block_size();
|
||||
}
|
||||
|
||||
// Calls were performed contiguously, so the output data is contiguous.
|
||||
let ptr = outputs
|
||||
.first()
|
||||
.map(|output| output.to_raw().ptr())
|
||||
.expect("keystream is not empty");
|
||||
let size = self.blocks.len() * O::SIZE;
|
||||
|
||||
let output = Vector::<U8>::from_raw(Slice::new_unchecked(ptr, size));
|
||||
let output = flatten_blocks(vm, outputs.iter().map(|block| block.to_raw()))?;
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// Returns `len` bytes of the keystream as a vector.
|
||||
pub fn to_vector(&self, len: usize) -> Result<Vector<U8>, CipherError> {
|
||||
pub fn to_vector(
|
||||
&self,
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
len: usize,
|
||||
) -> Result<Vector<U8>, CipherError> {
|
||||
if len == 0 {
|
||||
return Err(CipherError::new("length must be greater than 0"));
|
||||
} else if self.blocks.is_empty() {
|
||||
@@ -219,14 +198,8 @@ where
|
||||
return Err(CipherError::new("length does not match keystream length"));
|
||||
}
|
||||
|
||||
let ptr = self
|
||||
.blocks
|
||||
.front()
|
||||
.map(|block| block.output.to_raw().ptr())
|
||||
.expect("block count should be greater than 0");
|
||||
let size = block_count * O::SIZE;
|
||||
|
||||
let mut keystream = Vector::<U8>::from_raw(Slice::new_unchecked(ptr, size));
|
||||
let mut keystream =
|
||||
flatten_blocks(vm, self.blocks.iter().map(|block| block.output.to_raw()))?;
|
||||
keystream.truncate(len);
|
||||
|
||||
Ok(keystream)
|
||||
@@ -272,6 +245,34 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
fn flatten_blocks(
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
blocks: impl IntoIterator<Item = Slice>,
|
||||
) -> Result<Vector<U8>, CipherError> {
|
||||
use mpz_circuits::CircuitBuilder;
|
||||
|
||||
let blocks = blocks.into_iter().collect::<Vec<_>>();
|
||||
let len: usize = blocks.iter().map(|block| block.len()).sum();
|
||||
|
||||
let mut builder = CircuitBuilder::new();
|
||||
for _ in 0..len {
|
||||
let i = builder.add_input();
|
||||
let o = builder.add_id_gate(i);
|
||||
builder.add_output(o);
|
||||
}
|
||||
|
||||
let circuit = builder.build().expect("flatten circuit should be valid");
|
||||
|
||||
let mut builder = Call::builder(Arc::new(circuit));
|
||||
for block in blocks {
|
||||
builder = builder.arg(block);
|
||||
}
|
||||
|
||||
let call = builder.build().map_err(CipherError::new)?;
|
||||
|
||||
vm.call(call).map_err(CipherError::new)
|
||||
}
|
||||
|
||||
/// A cipher error.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("{source}")]
|
||||
|
||||
@@ -12,7 +12,7 @@ use async_trait::async_trait;
|
||||
use mpz_common::Context;
|
||||
use mpz_core::bitvec::BitVec;
|
||||
use mpz_vm_core::{
|
||||
memory::{binary::Binary, DecodeFuture, Memory, Slice, View},
|
||||
memory::{binary::Binary, DecodeFuture, Memory, Repr, Slice, View},
|
||||
Call, Callable, Execute, Vm, VmError,
|
||||
};
|
||||
use rangeset::{Difference, RangeSet, UnionMut};
|
||||
@@ -85,10 +85,10 @@ impl<Mpc, Zk> Deap<Mpc, Zk> {
|
||||
self.zk.clone().try_lock_owned().unwrap()
|
||||
}
|
||||
|
||||
/// Translates a slice from the MPC VM address space to the ZK VM address
|
||||
/// Translates a value from the MPC VM address space to the ZK VM address
|
||||
/// space.
|
||||
pub fn translate_slice(&self, slice: Slice) -> Result<Slice, VmError> {
|
||||
self.memory_map.try_get(slice)
|
||||
pub fn translate<T: Repr<Binary>>(&self, value: T) -> Result<T, VmError> {
|
||||
self.memory_map.try_get(value.to_raw()).map(T::from_raw)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -33,7 +33,6 @@ enum State {
|
||||
input: Vector<U8>,
|
||||
keystream: Keystream<Nonce, Ctr, Block>,
|
||||
j0s: Vec<(CtrBlock<Nonce, Ctr, Block>, OneTimePadShared<[u8; 16]>)>,
|
||||
output: Vector<U8>,
|
||||
ghash_key: OneTimePadShared<[u8; 16]>,
|
||||
ghash: Box<dyn Ghash + Send + Sync>,
|
||||
},
|
||||
@@ -41,7 +40,6 @@ enum State {
|
||||
input: Vector<U8>,
|
||||
keystream: Keystream<Nonce, Ctr, Block>,
|
||||
j0s: Vec<(CtrBlock<Nonce, Ctr, Block>, OneTimePadShared<[u8; 16]>)>,
|
||||
output: Vector<U8>,
|
||||
ghash: Arc<dyn Ghash + Send + Sync>,
|
||||
},
|
||||
Error,
|
||||
@@ -125,13 +123,11 @@ impl MpcAesGcm {
|
||||
}
|
||||
|
||||
let keystream = self.aes.alloc_keystream(vm, len)?;
|
||||
let output = keystream.apply(vm, input)?;
|
||||
|
||||
self.state = State::Setup {
|
||||
input,
|
||||
keystream,
|
||||
j0s,
|
||||
output,
|
||||
ghash,
|
||||
ghash_key,
|
||||
};
|
||||
@@ -162,7 +158,6 @@ impl MpcAesGcm {
|
||||
input,
|
||||
keystream,
|
||||
j0s,
|
||||
output,
|
||||
mut ghash,
|
||||
ghash_key,
|
||||
} = self.state.take()
|
||||
@@ -178,7 +173,6 @@ impl MpcAesGcm {
|
||||
input,
|
||||
keystream,
|
||||
j0s,
|
||||
output,
|
||||
ghash: Arc::from(ghash),
|
||||
};
|
||||
|
||||
@@ -202,10 +196,7 @@ impl MpcAesGcm {
|
||||
len: usize,
|
||||
) -> Result<(Vector<U8>, Vector<U8>), AeadError> {
|
||||
let State::Ready {
|
||||
input,
|
||||
keystream,
|
||||
output,
|
||||
..
|
||||
input, keystream, ..
|
||||
} = &mut self.state
|
||||
else {
|
||||
return Err(AeadError::state(
|
||||
@@ -235,7 +226,7 @@ impl MpcAesGcm {
|
||||
|
||||
let mut input = input.split_off(input.len() - padded_len);
|
||||
let keystream = keystream.consume(padded_len)?;
|
||||
let mut output = output.split_off(output.len() - padded_len);
|
||||
let mut output = keystream.apply(vm, input)?;
|
||||
|
||||
// Assign counter block inputs.
|
||||
let mut ctr = START_CTR..;
|
||||
@@ -273,10 +264,7 @@ impl MpcAesGcm {
|
||||
len: usize,
|
||||
) -> Result<Vector<U8>, AeadError> {
|
||||
let State::Ready {
|
||||
input,
|
||||
keystream,
|
||||
output,
|
||||
..
|
||||
input, keystream, ..
|
||||
} = &mut self.state
|
||||
else {
|
||||
return Err(AeadError::state("must be in ready state to take keystream"));
|
||||
@@ -301,11 +289,6 @@ impl MpcAesGcm {
|
||||
)));
|
||||
}
|
||||
|
||||
// Drop the input and output text, we won't be needing them.
|
||||
// This leaves them allocated but unassigned in the VM.
|
||||
_ = input.split_off(input.len() - padded_len);
|
||||
_ = output.split_off(output.len() - padded_len);
|
||||
|
||||
let keystream = keystream.consume(len)?;
|
||||
|
||||
// Assign counter block inputs.
|
||||
@@ -314,7 +297,7 @@ impl MpcAesGcm {
|
||||
ctr.next().expect("range is unbounded").to_be_bytes()
|
||||
})?;
|
||||
|
||||
Ok(keystream.to_vector(len)?)
|
||||
Ok(keystream.to_vector(vm, len)?)
|
||||
}
|
||||
|
||||
/// Computes tags for the provided ciphertext. See
|
||||
|
||||
@@ -20,7 +20,7 @@ use mpz_garble_core::Delta;
|
||||
use state::{Notarize, Prove};
|
||||
|
||||
use futures::{AsyncRead, AsyncWrite, TryFutureExt};
|
||||
use mpc_tls::{LeaderCtrl, MpcTlsLeader};
|
||||
use mpc_tls::{LeaderCtrl, MpcTlsLeader, SessionKeys};
|
||||
use rand::Rng;
|
||||
use serio::SinkExt;
|
||||
use std::sync::Arc;
|
||||
@@ -28,7 +28,12 @@ use tls_client::{ClientConnection, ServerName as TlsServerName};
|
||||
use tls_client_async::{bind_client, TlsConnection};
|
||||
use tls_core::msgs::enums::ContentType;
|
||||
use tlsn_common::{
|
||||
commit::commit_records, context::build_mt_context, mux::attach_mux, zk_aes::ZkAesCtr, Role,
|
||||
commit::commit_records,
|
||||
context::build_mt_context,
|
||||
mux::attach_mux,
|
||||
transcript::{Record, TlsTranscript},
|
||||
zk_aes::ZkAesCtr,
|
||||
Role,
|
||||
};
|
||||
use tlsn_core::{
|
||||
connection::{
|
||||
@@ -103,7 +108,9 @@ impl Prover<state::Initialized> {
|
||||
let (vm, mut mpc_tls) = build_mpc_tls(&self.config, ctx);
|
||||
|
||||
// Allocate resources for MPC-TLS in VM.
|
||||
let keys = mpc_tls.alloc()?;
|
||||
let mut keys = mpc_tls.alloc()?;
|
||||
translate_keys(&mut keys, &vm.try_lock().expect("VM is not locked"))?;
|
||||
|
||||
// Allocate for committing to plaintext.
|
||||
let mut zk_aes = ZkAesCtr::new(Role::Prover);
|
||||
zk_aes.set_key(keys.server_write_key, keys.server_write_iv);
|
||||
@@ -205,6 +212,8 @@ impl Prover<state::Setup> {
|
||||
{
|
||||
let mut vm = vm.try_lock().expect("VM should not be locked");
|
||||
|
||||
translate_transcript(&mut data.transcript, &vm)?;
|
||||
|
||||
// Prove received plaintext. Prover drops the proof output, as they trust
|
||||
// themselves.
|
||||
_ = commit_records(
|
||||
@@ -414,3 +423,36 @@ impl ProverControl {
|
||||
.map_err(ProverError::from)
|
||||
}
|
||||
}
|
||||
|
||||
/// Translates VM references to the ZK address space.
|
||||
fn translate_keys<Mpc, Zk>(keys: &mut SessionKeys, vm: &Deap<Mpc, Zk>) -> Result<(), ProverError> {
|
||||
keys.client_write_key = vm
|
||||
.translate(keys.client_write_key)
|
||||
.map_err(ProverError::mpc)?;
|
||||
keys.client_write_iv = vm
|
||||
.translate(keys.client_write_iv)
|
||||
.map_err(ProverError::mpc)?;
|
||||
keys.server_write_key = vm
|
||||
.translate(keys.server_write_key)
|
||||
.map_err(ProverError::mpc)?;
|
||||
keys.server_write_iv = vm
|
||||
.translate(keys.server_write_iv)
|
||||
.map_err(ProverError::mpc)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Translates VM references to the ZK address space.
|
||||
fn translate_transcript<Mpc, Zk>(
|
||||
transcript: &mut TlsTranscript,
|
||||
vm: &Deap<Mpc, Zk>,
|
||||
) -> Result<(), ProverError> {
|
||||
for Record { plaintext_ref, .. } in transcript.sent.iter_mut().chain(transcript.recv.iter_mut())
|
||||
{
|
||||
if let Some(plaintext_ref) = plaintext_ref.as_mut() {
|
||||
*plaintext_ref = vm.translate(*plaintext_ref).map_err(ProverError::mpc)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ pub use config::{VerifierConfig, VerifierConfigBuilder, VerifierConfigBuilderErr
|
||||
pub use error::VerifierError;
|
||||
|
||||
use futures::{AsyncRead, AsyncWrite};
|
||||
use mpc_tls::{FollowerData, MpcTlsFollower};
|
||||
use mpc_tls::{FollowerData, MpcTlsFollower, SessionKeys};
|
||||
use mpz_common::Context;
|
||||
use mpz_core::Block;
|
||||
use mpz_garble_core::Delta;
|
||||
@@ -24,8 +24,13 @@ use serio::stream::IoStreamExt;
|
||||
use state::{Notarize, Verify};
|
||||
use tls_core::msgs::enums::ContentType;
|
||||
use tlsn_common::{
|
||||
commit::commit_records, config::ProtocolConfig, context::build_mt_context, mux::attach_mux,
|
||||
zk_aes::ZkAesCtr, Role,
|
||||
commit::commit_records,
|
||||
config::ProtocolConfig,
|
||||
context::build_mt_context,
|
||||
mux::attach_mux,
|
||||
transcript::{Record, TlsTranscript},
|
||||
zk_aes::ZkAesCtr,
|
||||
Role,
|
||||
};
|
||||
use tlsn_core::{
|
||||
attestation::{Attestation, AttestationConfig},
|
||||
@@ -110,7 +115,9 @@ impl Verifier<state::Initialized> {
|
||||
let (vm, mut mpc_tls) = build_mpc_tls(&self.config, &protocol_config, delta, ctx);
|
||||
|
||||
// Allocate resources for MPC-TLS in VM.
|
||||
let keys = mpc_tls.alloc()?;
|
||||
let mut keys = mpc_tls.alloc()?;
|
||||
translate_keys(&mut keys, &vm.try_lock().expect("VM is not locked"))?;
|
||||
|
||||
// Allocate for committing to plaintext.
|
||||
let mut zk_aes = ZkAesCtr::new(Role::Verifier);
|
||||
zk_aes.set_key(keys.server_write_key, keys.server_write_iv);
|
||||
@@ -222,6 +229,8 @@ impl Verifier<state::Setup> {
|
||||
{
|
||||
let mut vm = vm.try_lock().expect("VM should not be locked");
|
||||
|
||||
translate_transcript(&mut transcript, &vm)?;
|
||||
|
||||
// Prepare for the prover to prove received plaintext.
|
||||
let proof = commit_records(
|
||||
&mut (*vm.zk()),
|
||||
@@ -370,3 +379,39 @@ fn build_mpc_tls(
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
/// Translates VM references to the ZK address space.
|
||||
fn translate_keys<Mpc, Zk>(
|
||||
keys: &mut SessionKeys,
|
||||
vm: &Deap<Mpc, Zk>,
|
||||
) -> Result<(), VerifierError> {
|
||||
keys.client_write_key = vm
|
||||
.translate(keys.client_write_key)
|
||||
.map_err(VerifierError::mpc)?;
|
||||
keys.client_write_iv = vm
|
||||
.translate(keys.client_write_iv)
|
||||
.map_err(VerifierError::mpc)?;
|
||||
keys.server_write_key = vm
|
||||
.translate(keys.server_write_key)
|
||||
.map_err(VerifierError::mpc)?;
|
||||
keys.server_write_iv = vm
|
||||
.translate(keys.server_write_iv)
|
||||
.map_err(VerifierError::mpc)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Translates VM references to the ZK address space.
|
||||
fn translate_transcript<Mpc, Zk>(
|
||||
transcript: &mut TlsTranscript,
|
||||
vm: &Deap<Mpc, Zk>,
|
||||
) -> Result<(), VerifierError> {
|
||||
for Record { plaintext_ref, .. } in transcript.sent.iter_mut().chain(transcript.recv.iter_mut())
|
||||
{
|
||||
if let Some(plaintext_ref) = plaintext_ref.as_mut() {
|
||||
*plaintext_ref = vm.translate(*plaintext_ref).map_err(VerifierError::mpc)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user