feat(deap): address space mapping (#809)

This commit is contained in:
sinu.eth
2025-05-13 18:38:39 +02:00
committed by GitHub
parent f900fc51cd
commit a8bf1026ca
2 changed files with 268 additions and 60 deletions

View File

@@ -4,13 +4,9 @@
#![deny(clippy::all)]
#![forbid(unsafe_code)]
use std::{
mem,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
mod map;
use std::{mem, sync::Arc};
use async_trait::async_trait;
use mpz_common::Context;
@@ -38,11 +34,15 @@ pub struct Deap<Mpc, Zk> {
role: Role,
mpc: Arc<Mutex<Mpc>>,
zk: Arc<Mutex<Zk>>,
/// Private inputs of the follower.
follower_inputs: RangeSet<usize>,
/// Mapping between the memories of the MPC and ZK VMs.
memory_map: map::MemoryMap,
/// Ranges of the follower's private inputs in the MPC VM.
follower_input_ranges: RangeSet<usize>,
/// Private inputs of the follower in the MPC VM.
follower_inputs: Vec<Slice>,
/// Outputs of the follower from the ZK VM. The references
/// correspond to the MPC VM.
outputs: Vec<(Slice, DecodeFuture<BitVec>)>,
/// Whether the memories of the two VMs are potentially desynchronized.
desync: AtomicBool,
}
impl<Mpc, Zk> Deap<Mpc, Zk> {
@@ -52,9 +52,10 @@ impl<Mpc, Zk> Deap<Mpc, Zk> {
role,
mpc: Arc::new(Mutex::new(mpc)),
zk: Arc::new(Mutex::new(zk)),
follower_inputs: RangeSet::default(),
memory_map: map::MemoryMap::default(),
follower_input_ranges: RangeSet::default(),
follower_inputs: Vec::default(),
outputs: Vec::default(),
desync: AtomicBool::new(false),
}
}
@@ -68,34 +69,28 @@ impl<Mpc, Zk> Deap<Mpc, Zk> {
/// Returns a mutable reference to the ZK VM.
///
/// # Note
///
/// After calling this method, allocations will no longer be allowed in the
/// DEAP VM as the memory will potentially be desynchronized.
///
/// # Panics
///
/// Panics if the mutex is locked by another thread.
pub fn zk(&self) -> MutexGuard<'_, Zk> {
self.desync.store(true, Ordering::Relaxed);
self.zk.try_lock().unwrap()
}
/// Returns an owned mutex guard to the ZK VM.
///
/// # Note
///
/// After calling this method, allocations will no longer be allowed in the
/// DEAP VM as the memory will potentially be desynchronized.
///
/// # Panics
///
/// Panics if the mutex is locked by another thread.
pub fn zk_owned(&self) -> OwnedMutexGuard<Zk> {
self.desync.store(true, Ordering::Relaxed);
self.zk.clone().try_lock_owned().unwrap()
}
/// Translates a slice from the MPC VM address space to the ZK VM address
/// space.
pub fn translate_slice(&self, slice: Slice) -> Result<Slice, VmError> {
self.memory_map.try_get(slice)
}
#[cfg(test)]
fn mpc(&self) -> MutexGuard<'_, Mpc> {
self.mpc.try_lock().unwrap()
@@ -124,18 +119,15 @@ where
// MACs.
let input_futs = self
.follower_inputs
.iter_ranges()
.map(|input| mpc.decode_raw(Slice::from_range_unchecked(input)))
.iter()
.map(|&input| mpc.decode_raw(input))
.collect::<Result<Vec<_>, _>>()?;
mpc.execute_all(ctx).await?;
// Assign inputs to the ZK VM.
for (mut decode, input) in input_futs
.into_iter()
.zip(self.follower_inputs.iter_ranges())
{
let input = Slice::from_range_unchecked(input);
for (mut decode, &input) in input_futs.into_iter().zip(&self.follower_inputs) {
let input = self.memory_map.try_get(input)?;
// Follower has already assigned the inputs.
if let Role::Leader = self.role {
@@ -189,14 +181,12 @@ where
}
fn alloc_raw(&mut self, size: usize) -> Result<Slice, VmError> {
if self.desync.load(Ordering::Relaxed) {
return Err(VmError::memory(
"DEAP VM memories are potentially desynchronized",
));
}
let mpc_slice = self.mpc.try_lock().unwrap().alloc_raw(size)?;
let zk_slice = self.zk.try_lock().unwrap().alloc_raw(size)?;
self.zk.try_lock().unwrap().alloc_raw(size)?;
self.mpc.try_lock().unwrap().alloc_raw(size)
self.memory_map.insert(mpc_slice, zk_slice);
Ok(mpc_slice)
}
fn is_assigned_raw(&self, slice: Slice) -> bool {
@@ -204,11 +194,15 @@ where
}
fn assign_raw(&mut self, slice: Slice, data: BitVec) -> Result<(), VmError> {
self.zk
self.mpc
.try_lock()
.unwrap()
.assign_raw(slice, data.clone())?;
self.mpc.try_lock().unwrap().assign_raw(slice, data)
self.zk
.try_lock()
.unwrap()
.assign_raw(self.memory_map.try_get(slice)?, data)
}
fn is_committed_raw(&self, slice: Slice) -> bool {
@@ -217,10 +211,13 @@ where
fn commit_raw(&mut self, slice: Slice) -> Result<(), VmError> {
// Follower's private inputs are not committed in the ZK VM until finalization.
let input_minus_follower = slice.to_range().difference(&self.follower_inputs);
let input_minus_follower = slice.to_range().difference(&self.follower_input_ranges);
let mut zk = self.zk.try_lock().unwrap();
for input in input_minus_follower.iter_ranges() {
zk.commit_raw(Slice::from_range_unchecked(input))?;
zk.commit_raw(
self.memory_map
.try_get(Slice::from_range_unchecked(input))?,
)?;
}
self.mpc.try_lock().unwrap().commit_raw(slice)
@@ -231,7 +228,11 @@ where
}
fn decode_raw(&mut self, slice: Slice) -> Result<DecodeFuture<BitVec>, VmError> {
let fut = self.zk.try_lock().unwrap().decode_raw(slice)?;
let fut = self
.zk
.try_lock()
.unwrap()
.decode_raw(self.memory_map.try_get(slice)?)?;
self.outputs.push((slice, fut));
self.mpc.try_lock().unwrap().decode_raw(slice)
@@ -246,8 +247,11 @@ where
type Error = VmError;
fn mark_public_raw(&mut self, slice: Slice) -> Result<(), VmError> {
self.zk.try_lock().unwrap().mark_public_raw(slice)?;
self.mpc.try_lock().unwrap().mark_public_raw(slice)
self.mpc.try_lock().unwrap().mark_public_raw(slice)?;
self.zk
.try_lock()
.unwrap()
.mark_public_raw(self.memory_map.try_get(slice)?)
}
fn mark_private_raw(&mut self, slice: Slice) -> Result<(), VmError> {
@@ -255,14 +259,15 @@ where
let mut mpc = self.mpc.try_lock().unwrap();
match self.role {
Role::Leader => {
zk.mark_private_raw(slice)?;
mpc.mark_private_raw(slice)?;
zk.mark_private_raw(self.memory_map.try_get(slice)?)?;
}
Role::Follower => {
// Follower's private inputs will become public during finalization.
zk.mark_public_raw(slice)?;
mpc.mark_private_raw(slice)?;
self.follower_inputs.union_mut(&slice.to_range());
// Follower's private inputs will become public during finalization.
zk.mark_public_raw(self.memory_map.try_get(slice)?)?;
self.follower_input_ranges.union_mut(&slice.to_range());
self.follower_inputs.push(slice);
}
}
@@ -274,14 +279,15 @@ where
let mut mpc = self.mpc.try_lock().unwrap();
match self.role {
Role::Leader => {
// Follower's private inputs will become public during finalization.
zk.mark_public_raw(slice)?;
mpc.mark_blind_raw(slice)?;
self.follower_inputs.union_mut(&slice.to_range());
// Follower's private inputs will become public during finalization.
zk.mark_public_raw(self.memory_map.try_get(slice)?)?;
self.follower_input_ranges.union_mut(&slice.to_range());
self.follower_inputs.push(slice);
}
Role::Follower => {
zk.mark_blind_raw(slice)?;
mpc.mark_blind_raw(slice)?;
zk.mark_blind_raw(self.memory_map.try_get(slice)?)?;
}
}
@@ -295,14 +301,21 @@ where
Zk: Vm<Binary>,
{
fn call_raw(&mut self, call: Call) -> Result<Slice, VmError> {
if self.desync.load(Ordering::Relaxed) {
return Err(VmError::memory(
"DEAP VM memories are potentially desynchronized",
));
let (circ, inputs) = call.clone().into_parts();
let mut builder = Call::builder(circ);
for input in inputs {
builder = builder.arg(self.memory_map.try_get(input)?);
}
self.zk.try_lock().unwrap().call_raw(call.clone())?;
self.mpc.try_lock().unwrap().call_raw(call)
let zk_call = builder.build().expect("call should be valid");
let output = self.mpc.try_lock().unwrap().call_raw(call)?;
let zk_output = self.zk.try_lock().unwrap().call_raw(zk_call)?;
self.memory_map.insert(output, zk_output);
Ok(output)
}
}
@@ -463,6 +476,90 @@ mod tests {
assert_eq!(ct_leader, ct_follower);
}
#[tokio::test]
async fn test_deap_desync_memory() {
    // Exercises the DEAP VM after its two inner memories have been
    // deliberately desynchronized: allocating directly through `zk()` shifts
    // the ZK VM's address space relative to the MPC VM's, so subsequent
    // operations must translate addresses between the two. Both parties
    // should still agree on the AES-128 ciphertext.
    let mut rng = StdRng::seed_from_u64(0);
    let delta_mpc = Delta::random(&mut rng);
    let delta_zk = Delta::random(&mut rng);
    let (mut ctx_a, mut ctx_b) = test_st_context(8);
    let (rcot_send, rcot_recv) = ideal_rcot(Block::ZERO, delta_zk.into_inner());
    let (cot_send, cot_recv) = ideal_cot(delta_mpc.into_inner());

    let gb = Garbler::new(cot_send, [0u8; 16], delta_mpc);
    let ev = Evaluator::new(cot_recv);
    let prover = Prover::new(rcot_recv);
    let verifier = Verifier::new(delta_zk, rcot_send);

    let mut leader = Deap::new(Role::Leader, gb, prover);
    let mut follower = Deap::new(Role::Follower, ev, verifier);

    // Desynchronize the memories.
    let _ = leader.zk().alloc_raw(1).unwrap();
    let _ = follower.zk().alloc_raw(1).unwrap();

    let (ct_leader, ct_follower) = futures::join!(
        async {
            // Leader provides the key privately; the message is blind to it.
            let key: Array<U8, 16> = leader.alloc().unwrap();
            let msg: Array<U8, 16> = leader.alloc().unwrap();

            leader.mark_private(key).unwrap();
            leader.mark_blind(msg).unwrap();
            leader.assign(key, [42u8; 16]).unwrap();
            leader.commit(key).unwrap();
            leader.commit(msg).unwrap();

            let ct: Array<U8, 16> = leader
                .call(
                    Call::builder(AES128.clone())
                        .arg(key)
                        .arg(msg)
                        .build()
                        .unwrap(),
                )
                .unwrap();

            let ct = leader.decode(ct).unwrap();

            leader.flush(&mut ctx_a).await.unwrap();
            leader.execute(&mut ctx_a).await.unwrap();
            leader.flush(&mut ctx_a).await.unwrap();
            leader.finalize(&mut ctx_a).await.unwrap();

            ct.await.unwrap()
        },
        async {
            // Follower provides the message privately; the key is blind to it.
            let key: Array<U8, 16> = follower.alloc().unwrap();
            let msg: Array<U8, 16> = follower.alloc().unwrap();

            follower.mark_blind(key).unwrap();
            follower.mark_private(msg).unwrap();
            follower.assign(msg, [69u8; 16]).unwrap();
            follower.commit(key).unwrap();
            follower.commit(msg).unwrap();

            let ct: Array<U8, 16> = follower
                .call(
                    Call::builder(AES128.clone())
                        .arg(key)
                        .arg(msg)
                        .build()
                        .unwrap(),
                )
                .unwrap();

            let ct = follower.decode(ct).unwrap();

            follower.flush(&mut ctx_b).await.unwrap();
            follower.execute(&mut ctx_b).await.unwrap();
            follower.flush(&mut ctx_b).await.unwrap();
            follower.finalize(&mut ctx_b).await.unwrap();

            ct.await.unwrap()
        }
    );

    // Both parties must arrive at the same ciphertext.
    assert_eq!(ct_leader, ct_follower);
}
// Tests that the leader cannot use different inputs in each VM without
// detection by the follower.
#[tokio::test]

View File

@@ -0,0 +1,111 @@
use std::ops::Range;
use mpz_vm_core::{memory::Slice, VmError};
use rangeset::Subset;
/// A mapping between the memories of the MPC and ZK VMs.
///
/// Allocations are recorded pairwise: the i-th MPC range corresponds to the
/// i-th ZK range, and both vectors always have the same length. MPC ranges
/// are kept in ascending, non-overlapping order (enforced by `insert`), which
/// lets `try_get` locate a range by binary search.
#[derive(Debug, Default)]
pub(crate) struct MemoryMap {
    /// Address ranges allocated in the MPC VM, in ascending order.
    mpc: Vec<Range<usize>>,
    /// Corresponding address ranges in the ZK VM, paired index-wise with
    /// `mpc`.
    zk: Vec<Range<usize>>,
}
impl MemoryMap {
    /// Records a new pair of corresponding allocations.
    ///
    /// # Panics
    ///
    /// - If the slices are not inserted in the order they are allocated.
    /// - If the slices are not the same length.
    pub(crate) fn insert(&mut self, mpc: Slice, zk: Slice) {
        let mpc_range = mpc.to_range();
        let zk_range = zk.to_range();

        assert_eq!(
            mpc_range.len(),
            zk_range.len(),
            "slices must be the same length"
        );

        // Keep `self.mpc` sorted and non-overlapping so `try_get` can binary
        // search it.
        if self
            .mpc
            .last()
            .is_some_and(|prev| prev.end > mpc_range.start)
        {
            panic!("slices must be provided in ascending order");
        }

        self.mpc.push(mpc_range);
        self.zk.push(zk_range);
    }

    /// Translates an MPC VM slice into the corresponding ZK VM slice.
    ///
    /// Fails if `mpc` is not contained within a single previously inserted
    /// allocation.
    pub(crate) fn try_get(&self, mpc: Slice) -> Result<Slice, VmError> {
        let wanted = mpc.to_range();

        // The only candidate that could contain `wanted` is the last recorded
        // MPC range whose start is <= `wanted.start`.
        let idx = match self
            .mpc
            .binary_search_by_key(&wanted.start, |range| range.start)
        {
            Ok(idx) => idx,
            Err(idx) => idx
                .checked_sub(1)
                .ok_or_else(|| VmError::memory(format!("invalid memory slice: {mpc}")))?,
        };

        let candidate = &self.mpc[idx];
        if !wanted.is_subset(candidate) {
            return Err(VmError::memory(format!("invalid memory slice: {mpc}")));
        }

        // Apply the same offset within the paired ZK allocation.
        let start = self.zk[idx].start + (wanted.start - candidate.start);
        Ok(Slice::from_range_unchecked(start..start + wanted.len()))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a [`Slice`] from a raw range.
    fn slice(range: std::ops::Range<usize>) -> Slice {
        Slice::from_range_unchecked(range)
    }

    #[test]
    fn test_map() {
        let mut map = MemoryMap::default();
        map.insert(slice(0..10), slice(10..20));

        // An exact match of an inserted range translates fully.
        assert_eq!(map.try_get(slice(0..10)).unwrap(), slice(10..20));
        // A proper subset translates with the same offset.
        assert_eq!(map.try_get(slice(1..9)).unwrap(), slice(11..19));
        // A range extending past the allocation is rejected.
        assert!(map.try_get(slice(0..11)).is_err());

        // A second allocation is looked up independently.
        map.insert(slice(20..30), slice(30..40));

        assert_eq!(map.try_get(slice(20..30)).unwrap(), slice(30..40));
        assert_eq!(map.try_get(slice(21..29)).unwrap(), slice(31..39));
        // A range straddling two allocations is rejected.
        assert!(map.try_get(slice(19..21)).is_err());
    }

    #[test]
    #[should_panic]
    fn test_map_length_mismatch() {
        let mut map = MemoryMap::default();
        // 5-element vs 10-element slices: `insert` must panic on the length
        // check.
        map.insert(slice(5..10), slice(20..30));
    }
}