mirror of
https://github.com/tlsnotary/tlsn.git
synced 2026-01-09 14:48:13 -05:00
feat(deap): address space mapping (#809)
This commit is contained in:
@@ -4,13 +4,9 @@
|
||||
#![deny(clippy::all)]
|
||||
#![forbid(unsafe_code)]
|
||||
|
||||
use std::{
|
||||
mem,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
mod map;
|
||||
|
||||
use std::{mem, sync::Arc};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use mpz_common::Context;
|
||||
@@ -38,11 +34,15 @@ pub struct Deap<Mpc, Zk> {
|
||||
role: Role,
|
||||
mpc: Arc<Mutex<Mpc>>,
|
||||
zk: Arc<Mutex<Zk>>,
|
||||
/// Private inputs of the follower.
|
||||
follower_inputs: RangeSet<usize>,
|
||||
/// Mapping between the memories of the MPC and ZK VMs.
|
||||
memory_map: map::MemoryMap,
|
||||
/// Ranges of the follower's private inputs in the MPC VM.
|
||||
follower_input_ranges: RangeSet<usize>,
|
||||
/// Private inputs of the follower in the MPC VM.
|
||||
follower_inputs: Vec<Slice>,
|
||||
/// Outputs of the follower from the ZK VM. The references
|
||||
/// correspond to the MPC VM.
|
||||
outputs: Vec<(Slice, DecodeFuture<BitVec>)>,
|
||||
/// Whether the memories of the two VMs are potentially desynchronized.
|
||||
desync: AtomicBool,
|
||||
}
|
||||
|
||||
impl<Mpc, Zk> Deap<Mpc, Zk> {
|
||||
@@ -52,9 +52,10 @@ impl<Mpc, Zk> Deap<Mpc, Zk> {
|
||||
role,
|
||||
mpc: Arc::new(Mutex::new(mpc)),
|
||||
zk: Arc::new(Mutex::new(zk)),
|
||||
follower_inputs: RangeSet::default(),
|
||||
memory_map: map::MemoryMap::default(),
|
||||
follower_input_ranges: RangeSet::default(),
|
||||
follower_inputs: Vec::default(),
|
||||
outputs: Vec::default(),
|
||||
desync: AtomicBool::new(false),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -68,34 +69,28 @@ impl<Mpc, Zk> Deap<Mpc, Zk> {
|
||||
|
||||
/// Returns a mutable reference to the ZK VM.
|
||||
///
|
||||
/// # Note
|
||||
///
|
||||
/// After calling this method, allocations will no longer be allowed in the
|
||||
/// DEAP VM as the memory will potentially be desynchronized.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the mutex is locked by another thread.
|
||||
pub fn zk(&self) -> MutexGuard<'_, Zk> {
|
||||
self.desync.store(true, Ordering::Relaxed);
|
||||
self.zk.try_lock().unwrap()
|
||||
}
|
||||
|
||||
/// Returns an owned mutex guard to the ZK VM.
|
||||
///
|
||||
/// # Note
|
||||
///
|
||||
/// After calling this method, allocations will no longer be allowed in the
|
||||
/// DEAP VM as the memory will potentially be desynchronized.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the mutex is locked by another thread.
|
||||
pub fn zk_owned(&self) -> OwnedMutexGuard<Zk> {
|
||||
self.desync.store(true, Ordering::Relaxed);
|
||||
self.zk.clone().try_lock_owned().unwrap()
|
||||
}
|
||||
|
||||
/// Translates a slice from the MPC VM address space to the ZK VM address
|
||||
/// space.
|
||||
pub fn translate_slice(&self, slice: Slice) -> Result<Slice, VmError> {
|
||||
self.memory_map.try_get(slice)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn mpc(&self) -> MutexGuard<'_, Mpc> {
|
||||
self.mpc.try_lock().unwrap()
|
||||
@@ -124,18 +119,15 @@ where
|
||||
// MACs.
|
||||
let input_futs = self
|
||||
.follower_inputs
|
||||
.iter_ranges()
|
||||
.map(|input| mpc.decode_raw(Slice::from_range_unchecked(input)))
|
||||
.iter()
|
||||
.map(|&input| mpc.decode_raw(input))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
mpc.execute_all(ctx).await?;
|
||||
|
||||
// Assign inputs to the ZK VM.
|
||||
for (mut decode, input) in input_futs
|
||||
.into_iter()
|
||||
.zip(self.follower_inputs.iter_ranges())
|
||||
{
|
||||
let input = Slice::from_range_unchecked(input);
|
||||
for (mut decode, &input) in input_futs.into_iter().zip(&self.follower_inputs) {
|
||||
let input = self.memory_map.try_get(input)?;
|
||||
|
||||
// Follower has already assigned the inputs.
|
||||
if let Role::Leader = self.role {
|
||||
@@ -189,14 +181,12 @@ where
|
||||
}
|
||||
|
||||
fn alloc_raw(&mut self, size: usize) -> Result<Slice, VmError> {
|
||||
if self.desync.load(Ordering::Relaxed) {
|
||||
return Err(VmError::memory(
|
||||
"DEAP VM memories are potentially desynchronized",
|
||||
));
|
||||
}
|
||||
let mpc_slice = self.mpc.try_lock().unwrap().alloc_raw(size)?;
|
||||
let zk_slice = self.zk.try_lock().unwrap().alloc_raw(size)?;
|
||||
|
||||
self.zk.try_lock().unwrap().alloc_raw(size)?;
|
||||
self.mpc.try_lock().unwrap().alloc_raw(size)
|
||||
self.memory_map.insert(mpc_slice, zk_slice);
|
||||
|
||||
Ok(mpc_slice)
|
||||
}
|
||||
|
||||
fn is_assigned_raw(&self, slice: Slice) -> bool {
|
||||
@@ -204,11 +194,15 @@ where
|
||||
}
|
||||
|
||||
fn assign_raw(&mut self, slice: Slice, data: BitVec) -> Result<(), VmError> {
|
||||
self.zk
|
||||
self.mpc
|
||||
.try_lock()
|
||||
.unwrap()
|
||||
.assign_raw(slice, data.clone())?;
|
||||
self.mpc.try_lock().unwrap().assign_raw(slice, data)
|
||||
|
||||
self.zk
|
||||
.try_lock()
|
||||
.unwrap()
|
||||
.assign_raw(self.memory_map.try_get(slice)?, data)
|
||||
}
|
||||
|
||||
fn is_committed_raw(&self, slice: Slice) -> bool {
|
||||
@@ -217,10 +211,13 @@ where
|
||||
|
||||
fn commit_raw(&mut self, slice: Slice) -> Result<(), VmError> {
|
||||
// Follower's private inputs are not committed in the ZK VM until finalization.
|
||||
let input_minus_follower = slice.to_range().difference(&self.follower_inputs);
|
||||
let input_minus_follower = slice.to_range().difference(&self.follower_input_ranges);
|
||||
let mut zk = self.zk.try_lock().unwrap();
|
||||
for input in input_minus_follower.iter_ranges() {
|
||||
zk.commit_raw(Slice::from_range_unchecked(input))?;
|
||||
zk.commit_raw(
|
||||
self.memory_map
|
||||
.try_get(Slice::from_range_unchecked(input))?,
|
||||
)?;
|
||||
}
|
||||
|
||||
self.mpc.try_lock().unwrap().commit_raw(slice)
|
||||
@@ -231,7 +228,11 @@ where
|
||||
}
|
||||
|
||||
fn decode_raw(&mut self, slice: Slice) -> Result<DecodeFuture<BitVec>, VmError> {
|
||||
let fut = self.zk.try_lock().unwrap().decode_raw(slice)?;
|
||||
let fut = self
|
||||
.zk
|
||||
.try_lock()
|
||||
.unwrap()
|
||||
.decode_raw(self.memory_map.try_get(slice)?)?;
|
||||
self.outputs.push((slice, fut));
|
||||
|
||||
self.mpc.try_lock().unwrap().decode_raw(slice)
|
||||
@@ -246,8 +247,11 @@ where
|
||||
type Error = VmError;
|
||||
|
||||
fn mark_public_raw(&mut self, slice: Slice) -> Result<(), VmError> {
|
||||
self.zk.try_lock().unwrap().mark_public_raw(slice)?;
|
||||
self.mpc.try_lock().unwrap().mark_public_raw(slice)
|
||||
self.mpc.try_lock().unwrap().mark_public_raw(slice)?;
|
||||
self.zk
|
||||
.try_lock()
|
||||
.unwrap()
|
||||
.mark_public_raw(self.memory_map.try_get(slice)?)
|
||||
}
|
||||
|
||||
fn mark_private_raw(&mut self, slice: Slice) -> Result<(), VmError> {
|
||||
@@ -255,14 +259,15 @@ where
|
||||
let mut mpc = self.mpc.try_lock().unwrap();
|
||||
match self.role {
|
||||
Role::Leader => {
|
||||
zk.mark_private_raw(slice)?;
|
||||
mpc.mark_private_raw(slice)?;
|
||||
zk.mark_private_raw(self.memory_map.try_get(slice)?)?;
|
||||
}
|
||||
Role::Follower => {
|
||||
// Follower's private inputs will become public during finalization.
|
||||
zk.mark_public_raw(slice)?;
|
||||
mpc.mark_private_raw(slice)?;
|
||||
self.follower_inputs.union_mut(&slice.to_range());
|
||||
// Follower's private inputs will become public during finalization.
|
||||
zk.mark_public_raw(self.memory_map.try_get(slice)?)?;
|
||||
self.follower_input_ranges.union_mut(&slice.to_range());
|
||||
self.follower_inputs.push(slice);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -274,14 +279,15 @@ where
|
||||
let mut mpc = self.mpc.try_lock().unwrap();
|
||||
match self.role {
|
||||
Role::Leader => {
|
||||
// Follower's private inputs will become public during finalization.
|
||||
zk.mark_public_raw(slice)?;
|
||||
mpc.mark_blind_raw(slice)?;
|
||||
self.follower_inputs.union_mut(&slice.to_range());
|
||||
// Follower's private inputs will become public during finalization.
|
||||
zk.mark_public_raw(self.memory_map.try_get(slice)?)?;
|
||||
self.follower_input_ranges.union_mut(&slice.to_range());
|
||||
self.follower_inputs.push(slice);
|
||||
}
|
||||
Role::Follower => {
|
||||
zk.mark_blind_raw(slice)?;
|
||||
mpc.mark_blind_raw(slice)?;
|
||||
zk.mark_blind_raw(self.memory_map.try_get(slice)?)?;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -295,14 +301,21 @@ where
|
||||
Zk: Vm<Binary>,
|
||||
{
|
||||
fn call_raw(&mut self, call: Call) -> Result<Slice, VmError> {
|
||||
if self.desync.load(Ordering::Relaxed) {
|
||||
return Err(VmError::memory(
|
||||
"DEAP VM memories are potentially desynchronized",
|
||||
));
|
||||
let (circ, inputs) = call.clone().into_parts();
|
||||
let mut builder = Call::builder(circ);
|
||||
|
||||
for input in inputs {
|
||||
builder = builder.arg(self.memory_map.try_get(input)?);
|
||||
}
|
||||
|
||||
self.zk.try_lock().unwrap().call_raw(call.clone())?;
|
||||
self.mpc.try_lock().unwrap().call_raw(call)
|
||||
let zk_call = builder.build().expect("call should be valid");
|
||||
|
||||
let output = self.mpc.try_lock().unwrap().call_raw(call)?;
|
||||
let zk_output = self.zk.try_lock().unwrap().call_raw(zk_call)?;
|
||||
|
||||
self.memory_map.insert(output, zk_output);
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -463,6 +476,90 @@ mod tests {
|
||||
assert_eq!(ct_leader, ct_follower);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_deap_desync_memory() {
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
let delta_mpc = Delta::random(&mut rng);
|
||||
let delta_zk = Delta::random(&mut rng);
|
||||
|
||||
let (mut ctx_a, mut ctx_b) = test_st_context(8);
|
||||
let (rcot_send, rcot_recv) = ideal_rcot(Block::ZERO, delta_zk.into_inner());
|
||||
let (cot_send, cot_recv) = ideal_cot(delta_mpc.into_inner());
|
||||
|
||||
let gb = Garbler::new(cot_send, [0u8; 16], delta_mpc);
|
||||
let ev = Evaluator::new(cot_recv);
|
||||
let prover = Prover::new(rcot_recv);
|
||||
let verifier = Verifier::new(delta_zk, rcot_send);
|
||||
|
||||
let mut leader = Deap::new(Role::Leader, gb, prover);
|
||||
let mut follower = Deap::new(Role::Follower, ev, verifier);
|
||||
|
||||
// Desynchronize the memories.
|
||||
let _ = leader.zk().alloc_raw(1).unwrap();
|
||||
let _ = follower.zk().alloc_raw(1).unwrap();
|
||||
|
||||
let (ct_leader, ct_follower) = futures::join!(
|
||||
async {
|
||||
let key: Array<U8, 16> = leader.alloc().unwrap();
|
||||
let msg: Array<U8, 16> = leader.alloc().unwrap();
|
||||
|
||||
leader.mark_private(key).unwrap();
|
||||
leader.mark_blind(msg).unwrap();
|
||||
leader.assign(key, [42u8; 16]).unwrap();
|
||||
leader.commit(key).unwrap();
|
||||
leader.commit(msg).unwrap();
|
||||
|
||||
let ct: Array<U8, 16> = leader
|
||||
.call(
|
||||
Call::builder(AES128.clone())
|
||||
.arg(key)
|
||||
.arg(msg)
|
||||
.build()
|
||||
.unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
let ct = leader.decode(ct).unwrap();
|
||||
|
||||
leader.flush(&mut ctx_a).await.unwrap();
|
||||
leader.execute(&mut ctx_a).await.unwrap();
|
||||
leader.flush(&mut ctx_a).await.unwrap();
|
||||
leader.finalize(&mut ctx_a).await.unwrap();
|
||||
|
||||
ct.await.unwrap()
|
||||
},
|
||||
async {
|
||||
let key: Array<U8, 16> = follower.alloc().unwrap();
|
||||
let msg: Array<U8, 16> = follower.alloc().unwrap();
|
||||
|
||||
follower.mark_blind(key).unwrap();
|
||||
follower.mark_private(msg).unwrap();
|
||||
follower.assign(msg, [69u8; 16]).unwrap();
|
||||
follower.commit(key).unwrap();
|
||||
follower.commit(msg).unwrap();
|
||||
|
||||
let ct: Array<U8, 16> = follower
|
||||
.call(
|
||||
Call::builder(AES128.clone())
|
||||
.arg(key)
|
||||
.arg(msg)
|
||||
.build()
|
||||
.unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
let ct = follower.decode(ct).unwrap();
|
||||
|
||||
follower.flush(&mut ctx_b).await.unwrap();
|
||||
follower.execute(&mut ctx_b).await.unwrap();
|
||||
follower.flush(&mut ctx_b).await.unwrap();
|
||||
follower.finalize(&mut ctx_b).await.unwrap();
|
||||
|
||||
ct.await.unwrap()
|
||||
}
|
||||
);
|
||||
|
||||
assert_eq!(ct_leader, ct_follower);
|
||||
}
|
||||
|
||||
// Tests that the leader can not use different inputs in each VM without
|
||||
// detection by the follower.
|
||||
#[tokio::test]
|
||||
|
||||
111
crates/components/deap/src/map.rs
Normal file
111
crates/components/deap/src/map.rs
Normal file
@@ -0,0 +1,111 @@
|
||||
use std::ops::Range;
|
||||
|
||||
use mpz_vm_core::{memory::Slice, VmError};
|
||||
use rangeset::Subset;
|
||||
|
||||
/// A mapping between the memories of the MPC and ZK VMs.
|
||||
#[derive(Debug, Default)]
|
||||
pub(crate) struct MemoryMap {
|
||||
mpc: Vec<Range<usize>>,
|
||||
zk: Vec<Range<usize>>,
|
||||
}
|
||||
|
||||
impl MemoryMap {
|
||||
/// Inserts a new allocation into the map.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// - If the slices are not inserted in the order they are allocated.
|
||||
/// - If the slices are not the same length.
|
||||
pub(crate) fn insert(&mut self, mpc: Slice, zk: Slice) {
|
||||
let mpc = mpc.to_range();
|
||||
let zk = zk.to_range();
|
||||
|
||||
assert_eq!(mpc.len(), zk.len(), "slices must be the same length");
|
||||
|
||||
if let Some(last) = self.mpc.last() {
|
||||
if last.end > mpc.start {
|
||||
panic!("slices must be provided in ascending order");
|
||||
}
|
||||
}
|
||||
|
||||
self.mpc.push(mpc);
|
||||
self.zk.push(zk);
|
||||
}
|
||||
|
||||
/// Returns the corresponding allocation in the ZK VM.
|
||||
pub(crate) fn try_get(&self, mpc: Slice) -> Result<Slice, VmError> {
|
||||
let mpc_range = mpc.to_range();
|
||||
let pos = match self
|
||||
.mpc
|
||||
.binary_search_by_key(&mpc_range.start, |range| range.start)
|
||||
{
|
||||
Ok(pos) => pos,
|
||||
Err(0) => return Err(VmError::memory(format!("invalid memory slice: {mpc}"))),
|
||||
Err(pos) => pos - 1,
|
||||
};
|
||||
|
||||
let candidate = &self.mpc[pos];
|
||||
if mpc_range.is_subset(candidate) {
|
||||
let offset = mpc_range.start - candidate.start;
|
||||
let start = self.zk[pos].start + offset;
|
||||
let slice = Slice::from_range_unchecked(start..start + mpc_range.len());
|
||||
|
||||
Ok(slice)
|
||||
} else {
|
||||
Err(VmError::memory(format!("invalid memory slice: {mpc}")))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_map() {
|
||||
let mut map = MemoryMap::default();
|
||||
map.insert(
|
||||
Slice::from_range_unchecked(0..10),
|
||||
Slice::from_range_unchecked(10..20),
|
||||
);
|
||||
|
||||
// Range is fully contained.
|
||||
assert_eq!(
|
||||
map.try_get(Slice::from_range_unchecked(0..10)).unwrap(),
|
||||
Slice::from_range_unchecked(10..20)
|
||||
);
|
||||
// Range is subset.
|
||||
assert_eq!(
|
||||
map.try_get(Slice::from_range_unchecked(1..9)).unwrap(),
|
||||
Slice::from_range_unchecked(11..19)
|
||||
);
|
||||
// Range is not subset.
|
||||
assert!(map.try_get(Slice::from_range_unchecked(0..11)).is_err());
|
||||
|
||||
// Insert another range.
|
||||
map.insert(
|
||||
Slice::from_range_unchecked(20..30),
|
||||
Slice::from_range_unchecked(30..40),
|
||||
);
|
||||
assert_eq!(
|
||||
map.try_get(Slice::from_range_unchecked(20..30)).unwrap(),
|
||||
Slice::from_range_unchecked(30..40)
|
||||
);
|
||||
assert_eq!(
|
||||
map.try_get(Slice::from_range_unchecked(21..29)).unwrap(),
|
||||
Slice::from_range_unchecked(31..39)
|
||||
);
|
||||
assert!(map.try_get(Slice::from_range_unchecked(19..21)).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_map_length_mismatch() {
|
||||
let mut map = MemoryMap::default();
|
||||
map.insert(
|
||||
Slice::from_range_unchecked(5..10),
|
||||
Slice::from_range_unchecked(20..30),
|
||||
);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user