diff --git a/mpc/mpc-aio/benches/deap.rs b/mpc/mpc-aio/benches/deap.rs index efc48ff53..9ff428fb4 100644 --- a/mpc/mpc-aio/benches/deap.rs +++ b/mpc/mpc-aio/benches/deap.rs @@ -1,14 +1,18 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion}; use mpc_aio::protocol::garble::exec::deap::mock_deap_pair; use mpc_circuits::{Circuit, WireGroup, AES_128_REVERSE}; -use mpc_core::garble::FullInputSet; +use mpc_core::garble::{exec::deap::DEAPConfigBuilder, FullInputSet}; use rand::SeedableRng; use rand_chacha::ChaCha12Rng; use std::sync::Arc; async fn bench_deap(circ: Arc) { let mut rng = ChaCha12Rng::seed_from_u64(0); - let (leader, follower) = mock_deap_pair(circ.clone()); + let config = DEAPConfigBuilder::default() + .id("bench".to_string()) + .build() + .unwrap(); + let (leader, follower) = mock_deap_pair(config, circ.clone()); let leader_input = circ.input(0).unwrap().to_value(vec![0u8; 16]).unwrap(); let follower_input = circ.input(1).unwrap().to_value(vec![0u8; 16]).unwrap(); diff --git a/mpc/mpc-aio/benches/dualex.rs b/mpc/mpc-aio/benches/dualex.rs index 2a96184e2..8d20a7602 100644 --- a/mpc/mpc-aio/benches/dualex.rs +++ b/mpc/mpc-aio/benches/dualex.rs @@ -1,14 +1,18 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion}; use mpc_aio::protocol::garble::exec::dual::mock_dualex_pair; use mpc_circuits::{Circuit, WireGroup, AES_128_REVERSE}; -use mpc_core::garble::FullInputSet; +use mpc_core::garble::{exec::dual::DualExConfigBuilder, FullInputSet}; use rand::SeedableRng; use rand_chacha::ChaCha12Rng; use std::sync::Arc; async fn bench_dualex(circ: Arc) { let mut rng = ChaCha12Rng::seed_from_u64(0); - let (leader, follower) = mock_dualex_pair(circ.clone()); + let config = DualExConfigBuilder::default() + .id("bench".to_string()) + .build() + .unwrap(); + let (leader, follower) = mock_dualex_pair(config, circ.clone()); let leader_input = circ.input(0).unwrap().to_value(vec![0u8; 16]).unwrap(); let follower_input = circ.input(1).unwrap().to_value(vec![0u8; 16]).unwrap(); diff --git a/mpc/mpc-aio/src/protocol/garble/exec/deap/follower.rs b/mpc/mpc-aio/src/protocol/garble/exec/deap/follower.rs index b1deddecc..39fab2834 100644 --- a/mpc/mpc-aio/src/protocol/garble/exec/deap/follower.rs +++ b/mpc/mpc-aio/src/protocol/garble/exec/deap/follower.rs @@ -1,16 +1,20 @@ -use std::sync::Arc; +use std::{marker::PhantomData, sync::Arc}; use crate::protocol::{ garble::{Evaluator, GCError, GarbleChannel, GarbleMessage, Generator}, - ot::{ObliviousReceive, ObliviousReveal, ObliviousSend}, + ot::{OTFactoryError, ObliviousReceive, ObliviousReveal, ObliviousSend}, }; use futures::{SinkExt, StreamExt}; use mpc_circuits::{Circuit, Input, InputValue, OutputValue}; -use mpc_core::garble::{ - exec::deap as core, gc_state, ActiveEncodedInput, ActiveInputSet, FullEncodedInput, - FullInputSet, GarbledCircuit, +use mpc_core::{ + garble::{ + exec::deap::{self as core, DEAPConfig}, + gc_state, ActiveEncodedInput, ActiveInputSet, FullEncodedInput, FullInputSet, + GarbledCircuit, + }, + ot::config::{OTReceiverConfig, OTSenderConfig}, }; -use utils_aio::expect_msg_or_err; +use utils_aio::{expect_msg_or_err, factory::AsyncFactory}; use super::setup_inputs_with; @@ -21,65 +25,75 @@ pub mod state { pub trait Sealed {} impl Sealed for super::Initialized {} - impl Sealed for super::LabelSetup {} - impl Sealed for super::Executed {} + impl Sealed for super::LabelSetup {} + impl Sealed for super::Executed {} } pub trait State: sealed::Sealed {} pub struct Initialized; - pub struct LabelSetup { + pub struct LabelSetup { pub(crate) gen_labels: FullInputSet, pub(crate) ev_labels: ActiveInputSet, + pub(crate) label_sender: Option, } - pub struct Executed { + pub struct Executed { pub(super) core: core::DEAPFollower, + pub(crate) label_sender: Option, } impl State for Initialized {} - impl State for LabelSetup {} - impl State for Executed {} + impl State for LabelSetup {} + impl State for Executed {} } use state::*; -pub struct DEAPFollower +pub struct DEAPFollower where S: State, - B: Generator + Evaluator, - LS: ObliviousSend + ObliviousReveal, - LR: ObliviousReceive, { + config: DEAPConfig, state: S, circ: Arc, channel: GarbleChannel, backend: B, - label_sender: Option, - label_receiver: Option, + label_sender_factory: LSF, + label_receiver_factory: LRF, + + _label_sender: PhantomData, + _label_receiver: PhantomData, } -impl DEAPFollower +impl DEAPFollower where B: Generator + Evaluator + Send, + LSF: AsyncFactory + Send, + LRF: AsyncFactory + Send, LS: ObliviousSend + ObliviousReveal + Send, LR: ObliviousReceive + Send, { pub fn new( + config: DEAPConfig, circ: Arc, channel: GarbleChannel, backend: B, - label_sender: Option, - label_receiver: Option, + label_sender_factory: LSF, + label_receiver_factory: LRF, ) -> Self { DEAPFollower { + config, state: Initialized, circ, channel, backend, - label_sender, - label_receiver, + label_sender_factory, + label_receiver_factory, + + _label_sender: PhantomData, + _label_receiver: PhantomData, } } @@ -98,11 +112,16 @@ where ot_send_inputs: Vec, ot_receive_inputs: Vec, cached_labels: Vec, - ) -> Result, GCError> { - let (gen_labels, ev_labels) = setup_inputs_with( + ) -> Result, B, LSF, LRF, LS, LR>, GCError> { + let label_sender_id = format!("{}/ot/1", self.config.id()); + let label_receiver_id = format!("{}/ot/0", self.config.id()); + + let ((gen_labels, ev_labels), (label_sender, _)) = setup_inputs_with( + label_sender_id, + label_receiver_id, &mut self.channel, - self.label_sender.as_mut(), - self.label_receiver.as_mut(), + &mut self.label_sender_factory, + &mut self.label_receiver_factory, gen_labels, gen_inputs, ot_send_inputs, @@ -112,24 +131,28 @@ where .await?; Ok(DEAPFollower { + config: self.config, state: LabelSetup { gen_labels, ev_labels, + label_sender, }, circ: self.circ, channel: self.channel, backend: self.backend, - label_sender: self.label_sender, - label_receiver: self.label_receiver, + label_sender_factory: self.label_sender_factory, + label_receiver_factory: self.label_receiver_factory, + + _label_sender: PhantomData, + _label_receiver: PhantomData, }) } } -impl DEAPFollower +impl DEAPFollower, B, LSF, LRF, LS, LR> where B: Generator + Evaluator + Send, LS: ObliviousSend + ObliviousReveal + Send, - LR: ObliviousReceive + Send, { /// Execute first phase of the protocol, returning the _purported_ circuit output. /// @@ -137,7 +160,13 @@ where /// validated in the next phase. pub async fn execute( self, - ) -> Result<(Vec, DEAPFollower), GCError> { + ) -> Result< + ( + Vec, + DEAPFollower, B, LSF, LRF, LS, LR>, + ), + GCError, + > { // Discard the summary let (output, _, follower) = self.execute_and_summarize().await?; Ok((output, follower)) @@ -154,7 +183,7 @@ where ( Vec, GarbledCircuit, - DEAPFollower, + DEAPFollower, B, LSF, LRF, LS, LR>, ), GCError, > { @@ -213,18 +242,25 @@ where purported_output, evaluated_summary, DEAPFollower { - state: Executed { core: follower }, + config: self.config, + state: Executed { + core: follower, + label_sender: self.state.label_sender, + }, circ: self.circ, channel: self.channel, backend: self.backend, - label_sender: self.label_sender, - label_receiver: self.label_receiver, + label_sender_factory: self.label_sender_factory, + label_receiver_factory: self.label_receiver_factory, + + _label_sender: PhantomData, + _label_receiver: PhantomData, }, )) } } -impl DEAPFollower +impl DEAPFollower, B, LSF, LRF, LS, LR> where B: Generator + Evaluator + Send, LS: ObliviousSend + ObliviousReveal + Send, @@ -248,7 +284,7 @@ where .await?; // Open our OTs to leader - if let Some(label_sender) = self.label_sender.take() { + if let Some(label_sender) = self.state.label_sender.take() { label_sender.reveal().await?; } diff --git a/mpc/mpc-aio/src/protocol/garble/exec/deap/leader.rs b/mpc/mpc-aio/src/protocol/garble/exec/deap/leader.rs index 266179c27..494ee4c5d 100644 --- a/mpc/mpc-aio/src/protocol/garble/exec/deap/leader.rs +++ b/mpc/mpc-aio/src/protocol/garble/exec/deap/leader.rs @@ -1,16 +1,20 @@ -use std::sync::Arc; +use std::{marker::PhantomData, sync::Arc}; use crate::protocol::{ garble::{Compressor, Evaluator, GCError, GarbleChannel, GarbleMessage, Generator, Validator}, - ot::{ObliviousReceive, ObliviousSend, ObliviousVerify}, + ot::{OTFactoryError, ObliviousReceive, ObliviousSend, ObliviousVerify}, }; -use futures::{future::ready, SinkExt, StreamExt}; +use futures::{SinkExt, StreamExt}; use mpc_circuits::{Circuit, Input, InputValue, OutputValue, WireGroup}; -use mpc_core::garble::{ - exec::deap as core, gc_state, ActiveEncodedInput, ActiveInputSet, FullEncodedInput, - FullInputSet, GarbledCircuit, +use mpc_core::{ + garble::{ + exec::deap::{self as core, DEAPConfig}, + gc_state, ActiveEncodedInput, ActiveInputSet, FullEncodedInput, FullInputSet, + GarbledCircuit, + }, + ot::config::{OTReceiverConfig, OTSenderConfig}, }; -use utils_aio::expect_msg_or_err; +use utils_aio::{expect_msg_or_err, factory::AsyncFactory}; use super::setup_inputs_with; @@ -21,28 +25,30 @@ pub mod state { pub trait Sealed {} impl Sealed for super::Initialized {} - impl Sealed for super::LabelSetup {} - impl Sealed for super::Executed {} + impl Sealed for super::LabelSetup {} + impl Sealed for super::Executed {} } pub trait State: sealed::Sealed {} pub struct Initialized; - pub struct LabelSetup { + pub struct LabelSetup { pub(crate) gen_labels: FullInputSet, pub(crate) ev_labels: ActiveInputSet, pub(crate) input_state: InputState, + pub(crate) label_receiver: Option, } - pub struct Executed { + pub struct Executed { pub(super) core: core::DEAPLeader, pub(crate) input_state: InputState, + pub(crate) label_receiver: Option, } impl State for Initialized {} - impl State for LabelSetup {} - impl State for Executed {} + impl State for LabelSetup {} + impl State for Executed {} pub(crate) struct InputState { pub(crate) ot_receive_inputs: Vec, @@ -51,41 +57,48 @@ pub mod state { use state::*; -pub struct DEAPLeader +pub struct DEAPLeader where S: State, - B: Generator + Evaluator, - LS: ObliviousSend, - LR: ObliviousReceive + ObliviousVerify, { + config: DEAPConfig, state: S, circ: Arc, channel: GarbleChannel, backend: B, - label_sender: Option, - label_receiver: Option, + label_sender_factory: LSF, + label_receiver_factory: LRF, + + _label_sender: PhantomData, + _label_receiver: PhantomData, } -impl DEAPLeader +impl DEAPLeader where B: Generator + Evaluator + Compressor + Validator + Send, + LSF: AsyncFactory + Send, + LRF: AsyncFactory + Send, LS: ObliviousSend + Send, LR: ObliviousReceive + ObliviousVerify + Send, { pub fn new( + config: DEAPConfig, circ: Arc, channel: GarbleChannel, backend: B, - label_sender: Option, - label_receiver: Option, - ) -> DEAPLeader { + label_sender_factory: LSF, + label_receiver_factory: LRF, + ) -> DEAPLeader { DEAPLeader { + config, state: Initialized, circ, channel, backend, - label_sender, - label_receiver, + label_sender_factory, + label_receiver_factory, + _label_sender: PhantomData, + _label_receiver: PhantomData, } } @@ -104,11 +117,16 @@ where ot_send_inputs: Vec, ot_receive_inputs: Vec, cached_labels: Vec, - ) -> Result, GCError> { - let (gen_labels, ev_labels) = setup_inputs_with( + ) -> Result, B, LSF, LRF, LS, LR>, GCError> { + let label_sender_id = format!("{}/ot/0", self.config.id()); + let label_receiver_id = format!("{}/ot/1", self.config.id()); + + let ((gen_labels, ev_labels), (_, label_receiver)) = setup_inputs_with( + label_sender_id, + label_receiver_id, &mut self.channel, - self.label_sender.as_mut(), - self.label_receiver.as_mut(), + &mut self.label_sender_factory, + &mut self.label_receiver_factory, gen_labels, gen_inputs, ot_send_inputs, @@ -118,6 +136,7 @@ where .await?; Ok(DEAPLeader { + config: self.config, state: LabelSetup { gen_labels, ev_labels, @@ -127,26 +146,34 @@ where .map(|v| v.group().clone()) .collect(), }, + label_receiver, }, circ: self.circ, channel: self.channel, backend: self.backend, - label_sender: self.label_sender, - label_receiver: self.label_receiver, + label_sender_factory: self.label_sender_factory, + label_receiver_factory: self.label_receiver_factory, + _label_sender: PhantomData, + _label_receiver: PhantomData, }) } } -impl DEAPLeader +impl DEAPLeader, B, LSF, LRF, LS, LR> where B: Generator + Evaluator + Compressor + Validator + Send, - LS: ObliviousSend + Send, LR: ObliviousReceive + ObliviousVerify + Send, { /// Execute first phase of the protocol, returning the authenticated output. pub async fn execute( self, - ) -> Result<(Vec, DEAPLeader), GCError> { + ) -> Result< + ( + Vec, + DEAPLeader, B, LSF, LRF, LS, LR>, + ), + GCError, + > { // Discard summary let (output, _, leader) = self.execute_and_summarize().await?; Ok((output, leader)) @@ -162,7 +189,7 @@ where ( Vec, GarbledCircuit, - DEAPLeader, + DEAPLeader, B, LSF, LRF, LS, LR>, ), GCError, > { @@ -222,21 +249,25 @@ where output, gc_evaluated_summary, DEAPLeader { + config: self.config, state: Executed { core: leader, input_state: self.state.input_state, + label_receiver: self.state.label_receiver, }, circ: self.circ, channel: self.channel, backend: self.backend, - label_sender: self.label_sender, - label_receiver: self.label_receiver, + label_sender_factory: self.label_sender_factory, + label_receiver_factory: self.label_receiver_factory, + _label_sender: PhantomData, + _label_receiver: PhantomData, }, )) } } -impl DEAPLeader +impl DEAPLeader, B, LSF, LRF, LS, LR> where B: Generator + Evaluator + Compressor + Validator + Send, LS: ObliviousSend + Send, @@ -268,27 +299,32 @@ where let gc_validate_fut = self.backend.validate_compressed(gc_cmp, opening); // If we did not receive any inputs via OT, we can skip the OT validation - let labels_validate_fut = if self.state.input_state.ot_receive_inputs.is_empty() { - Box::pin(ready(Ok(()))) - } else { - let Some(label_receiver) = self.label_receiver.take() else { - return Err(GCError::MissingOTReceiver); - }; + let labels_validate_fut = async move { + if self.state.input_state.ot_receive_inputs.is_empty() { + Ok(()) + } else { + let Some(label_receiver) = self.state.label_receiver.take() else { + return Err(GCError::MissingOTReceiver); + }; - let ot_received = self - .state - .input_state - .ot_receive_inputs - .iter() - .map(|input| { - input_labels - .get(input.index()) - .expect("Input id should be valid") - }) - .cloned() - .collect::>(); + let ot_received = self + .state + .input_state + .ot_receive_inputs + .iter() + .map(|input| { + input_labels + .get(input.index()) + .expect("Input id should be valid") + }) + .cloned() + .collect::>(); - label_receiver.verify(ot_received) + label_receiver + .verify(ot_received) + .await + .map_err(GCError::from) + } }; let (gc_validate_result, labels_validate_result) = diff --git a/mpc/mpc-aio/src/protocol/garble/exec/deap/mod.rs b/mpc/mpc-aio/src/protocol/garble/exec/deap/mod.rs index b00d96827..011186874 100644 --- a/mpc/mpc-aio/src/protocol/garble/exec/deap/mod.rs +++ b/mpc/mpc-aio/src/protocol/garble/exec/deap/mod.rs @@ -14,41 +14,55 @@ mod mock { use super::*; use crate::protocol::{ garble::backend::RayonBackend, - ot::mock::{mock_ot_pair, MockOTReceiver, MockOTSender}, + ot::mock::{MockOTFactory, MockOTReceiver, MockOTSender}, }; use mpc_circuits::Circuit; - use mpc_core::{msgs::garble::GarbleMessage, Block}; + use mpc_core::{garble::exec::deap::DEAPConfig, msgs::garble::GarbleMessage, Block}; use utils_aio::duplex::DuplexChannel; - pub type MockDEAPLeader = - DEAPLeader, MockOTReceiver>; - pub type MockDEAPFollower = - DEAPFollower, MockOTReceiver>; + pub type MockDEAPLeader = DEAPLeader< + S, + RayonBackend, + MockOTFactory, + MockOTFactory, + MockOTSender, + MockOTReceiver, + >; + pub type MockDEAPFollower = DEAPFollower< + S, + RayonBackend, + MockOTFactory, + MockOTFactory, + MockOTSender, + MockOTReceiver, + >; pub fn mock_deap_pair( + config: DEAPConfig, circ: Arc, ) -> ( MockDEAPLeader, MockDEAPFollower, ) { let (leader_channel, follower_channel) = DuplexChannel::::new(); - let (leader_sender, follower_receiver) = mock_ot_pair(); - let (follower_sender, leader_receiver) = mock_ot_pair(); + let ot_factory = MockOTFactory::new(); let leader = DEAPLeader::new( + config.clone(), circ.clone(), Box::new(leader_channel), RayonBackend, - Some(leader_sender), - Some(leader_receiver), + ot_factory.clone(), + ot_factory.clone(), ); let follower = DEAPFollower::new( + config, circ, Box::new(follower_channel), RayonBackend, - Some(follower_sender), - Some(follower_receiver), + ot_factory.clone(), + ot_factory, ); (leader, follower) @@ -62,15 +76,19 @@ pub use mock::mock_deap_pair; mod tests { use super::*; use mpc_circuits::{Circuit, WireGroup, ADDER_64}; - use mpc_core::garble::FullInputSet; + use mpc_core::garble::{exec::deap::DEAPConfigBuilder, FullInputSet}; use rand_chacha::ChaCha12Rng; use rand_core::SeedableRng; #[tokio::test] async fn test_deap() { + let config = DEAPConfigBuilder::default() + .id("test".to_string()) + .build() + .unwrap(); let mut rng = ChaCha12Rng::seed_from_u64(0); let circ = Circuit::load_bytes(ADDER_64).unwrap(); - let (leader, follower) = mock_deap_pair(circ.clone()); + let (leader, follower) = mock_deap_pair(config, circ.clone()); let leader_input = circ.input(0).unwrap().to_value(1u64).unwrap(); let follower_input = circ.input(1).unwrap().to_value(2u64).unwrap(); diff --git a/mpc/mpc-aio/src/protocol/garble/exec/dual.rs b/mpc/mpc-aio/src/protocol/garble/exec/dual/mod.rs similarity index 73% rename from mpc/mpc-aio/src/protocol/garble/exec/dual.rs rename to mpc/mpc-aio/src/protocol/garble/exec/dual/mod.rs index 8a4cf5eb6..94a9978c0 100644 --- a/mpc/mpc-aio/src/protocol/garble/exec/dual.rs +++ b/mpc/mpc-aio/src/protocol/garble/exec/dual/mod.rs @@ -6,19 +6,25 @@ //! malicious. Such leakage, however, will be detected by the [`DualExFollower`] during the //! equality check. -use std::sync::Arc; +use std::{marker::PhantomData, sync::Arc}; use crate::protocol::{ garble::{Evaluator, GCError, GarbleChannel, GarbleMessage, Generator}, - ot::{ObliviousReceive, ObliviousSend}, + ot::{OTFactoryError, ObliviousReceive, ObliviousSend}, }; -use futures::{future::ready, SinkExt, StreamExt}; +use futures::{SinkExt, StreamExt}; use mpc_circuits::{Circuit, Input, InputValue, OutputValue, WireGroup}; -use mpc_core::garble::{ - exec::dual as core, gc_state, ActiveEncodedInput, ActiveInputSet, Error as CoreError, - FullEncodedInput, FullInputSet, GarbledCircuit, +use mpc_core::{ + garble::{ + exec::dual::{self as core, DualExConfig}, + gc_state, ActiveEncodedInput, ActiveInputSet, Error as CoreError, FullEncodedInput, + FullInputSet, GarbledCircuit, + }, + ot::config::{ + OTReceiverConfig, OTReceiverConfigBuilder, OTSenderConfig, OTSenderConfigBuilder, + }, }; -use utils_aio::expect_msg_or_err; +use utils_aio::{expect_msg_or_err, factory::AsyncFactory}; mod state { use super::*; @@ -45,42 +51,49 @@ mod state { use state::*; -pub struct DualExLeader +pub struct DualExLeader where S: State, - B: Generator + Evaluator, - LS: ObliviousSend, - LR: ObliviousReceive, { + config: DualExConfig, state: S, circ: Arc, channel: GarbleChannel, backend: B, - label_sender: Option, - label_receiver: Option, + label_sender_factory: LSF, + label_receiver_factory: LRF, + + _label_sender: PhantomData, + _label_receiver: PhantomData, } -impl DualExLeader +impl DualExLeader where B: Generator + Evaluator + Send, + LSF: AsyncFactory + Send, + LRF: AsyncFactory + Send, LS: ObliviousSend + Send, LR: ObliviousReceive + Send, { /// Create a new DualExLeader pub fn new( + config: DualExConfig, circ: Arc, channel: GarbleChannel, backend: B, - label_sender: Option, - label_receiver: Option, - ) -> DualExLeader { + label_sender_factory: LSF, + label_receiver_factory: LRF, + ) -> DualExLeader { DualExLeader { + config, state: Initialized, circ, channel, backend, - label_sender, - label_receiver, + label_sender_factory, + label_receiver_factory, + _label_sender: PhantomData, + _label_receiver: PhantomData, } } @@ -99,11 +112,16 @@ where ot_send_inputs: Vec, ot_receive_inputs: Vec, cached_labels: Vec, - ) -> Result, GCError> { - let (gen_labels, ev_labels) = setup_inputs_with( + ) -> Result, GCError> { + let label_sender_id = format!("{}/ot/0", self.config.id()); + let label_receiver_id = format!("{}/ot/1", self.config.id()); + + let ((gen_labels, ev_labels), _) = setup_inputs_with( + label_sender_id, + label_receiver_id, &mut self.channel, - self.label_sender.as_mut(), - self.label_receiver.as_mut(), + &mut self.label_sender_factory, + &mut self.label_receiver_factory, gen_labels, gen_inputs, ot_send_inputs, @@ -113,6 +131,7 @@ where .await?; Ok(DualExLeader { + config: self.config, state: LabelSetup { gen_labels, ev_labels, @@ -120,17 +139,17 @@ where circ: self.circ, channel: self.channel, backend: self.backend, - label_sender: self.label_sender, - label_receiver: self.label_receiver, + label_sender_factory: self.label_sender_factory, + label_receiver_factory: self.label_receiver_factory, + _label_sender: PhantomData, + _label_receiver: PhantomData, }) } } -impl DualExLeader +impl DualExLeader where B: Generator + Evaluator + Send, - LS: ObliviousSend + Send, - LR: ObliviousReceive + Send, { /// Execute dual execution protocol /// @@ -261,42 +280,49 @@ where } } -pub struct DualExFollower +pub struct DualExFollower where S: State, - B: Generator + Evaluator, - LS: ObliviousSend, - LR: ObliviousReceive, { + config: DualExConfig, state: S, circ: Arc, channel: GarbleChannel, backend: B, - label_sender: Option, - label_receiver: Option, + label_sender_factory: LSF, + label_receiver_factory: LRF, + + _label_sender: PhantomData, + _label_receiver: PhantomData, } -impl DualExFollower +impl DualExFollower where B: Generator + Evaluator + Send, + LSF: AsyncFactory + Send, + LRF: AsyncFactory + Send, LS: ObliviousSend + Send, LR: ObliviousReceive + Send, { /// Create a new DualExFollower pub fn new( + config: DualExConfig, circ: Arc, channel: GarbleChannel, backend: B, - label_sender: Option, - label_receiver: Option, - ) -> DualExFollower { + label_sender_factory: LSF, + label_receiver_factory: LRF, + ) -> DualExFollower { DualExFollower { + config, state: Initialized, circ, channel, backend, - label_sender, - label_receiver, + label_sender_factory, + label_receiver_factory, + _label_sender: PhantomData, + _label_receiver: PhantomData, } } @@ -315,11 +341,16 @@ where ot_send_inputs: Vec, ot_receive_inputs: Vec, cached_labels: Vec, - ) -> Result, GCError> { - let (gen_labels, ev_labels) = setup_inputs_with( + ) -> Result, GCError> { + let label_sender_id = format!("{}/ot/1", self.config.id()); + let label_receiver_id = format!("{}/ot/0", self.config.id()); + + let ((gen_labels, ev_labels), _) = setup_inputs_with( + label_sender_id, + label_receiver_id, &mut self.channel, - self.label_sender.as_mut(), - self.label_receiver.as_mut(), + &mut self.label_sender_factory, + &mut self.label_receiver_factory, gen_labels, gen_inputs, ot_send_inputs, @@ -329,6 +360,7 @@ where .await?; Ok(DualExFollower { + config: self.config, state: LabelSetup { gen_labels, ev_labels, @@ -336,17 +368,17 @@ where circ: self.circ, channel: self.channel, backend: self.backend, - label_sender: self.label_sender, - label_receiver: self.label_receiver, + label_sender_factory: self.label_sender_factory, + label_receiver_factory: self.label_receiver_factory, + _label_sender: PhantomData, + _label_receiver: PhantomData, }) } } -impl DualExFollower +impl DualExFollower where B: Generator + Evaluator + Send, - LS: ObliviousSend + Send, - LR: ObliviousReceive + Send, { /// Execute dual execution protocol /// @@ -479,17 +511,21 @@ where } /// Set up input labels by exchanging directly and via oblivious transfer. -pub async fn setup_inputs_with( +pub async fn setup_inputs_with( + label_sender_id: String, + label_receiver_id: String, channel: &mut GarbleChannel, - label_sender: Option<&mut LS>, - label_receiver: Option<&mut LR>, + label_sender_factory: &mut LSF, + label_receiver_factory: &mut LRF, gen_labels: FullInputSet, gen_inputs: Vec, ot_send_inputs: Vec, ot_receive_inputs: Vec, cached_labels: Vec, -) -> Result<(FullInputSet, ActiveInputSet), GCError> +) -> Result<((FullInputSet, ActiveInputSet), (Option, Option)), GCError> where + LSF: AsyncFactory + Send, + LRF: AsyncFactory + Send, LS: ObliviousSend + Send, LR: ObliviousReceive + Send, { @@ -514,12 +550,25 @@ where // Concurrently execute oblivious transfers and direct label sending // If there are no labels to be sent via OT, we can skip the OT protocol - let ot_send_fut = match label_sender { - Some(label_sender) if ot_send_labels.len() > 0 => label_sender.send(ot_send_labels), - None if ot_send_labels.len() > 0 => { - return Err(GCError::MissingOTSender); + let ot_send_fut = async move { + if ot_send_labels.len() > 0 { + let count = ot_send_labels.iter().map(|labels| labels.len()).sum(); + + let sender_config = OTSenderConfigBuilder::default() + .count(count) + .build() + .expect("OTSenderConfig should be valid"); + + let mut label_sender = label_sender_factory + .create(label_sender_id, sender_config) + .await?; + + let _ = label_sender.send(ot_send_labels).await?; + + Result::<_, GCError>::Ok(Some(label_sender)) + } else { + Result::<_, GCError>::Ok(None) } - _ => Box::pin(ready(Ok(()))), }; let direct_send_fut = channel.send(GarbleMessage::InputLabels( @@ -530,22 +579,33 @@ where )); // If there are no labels to be received via OT, we can skip the OT protocol - let ot_receive_fut = match label_receiver { - Some(label_receiver) if ot_receive_inputs.len() > 0 => { - label_receiver.receive(ot_receive_inputs) + let ot_receive_fut = async move { + if ot_receive_inputs.len() > 0 { + let count = ot_receive_inputs.iter().map(|input| input.len()).sum(); + + let receiver_config = OTReceiverConfigBuilder::default() + .count(count) + .build() + .expect("OTReceiverConfig should be valid"); + + let mut label_receiver = label_receiver_factory + .create(label_receiver_id, receiver_config) + .await?; + + let ot_receive_labels = label_receiver.receive(ot_receive_inputs).await?; + + Result::<_, GCError>::Ok((ot_receive_labels, Some(label_receiver))) + } else { + Result::<_, GCError>::Ok((vec![], None)) } - None if ot_receive_inputs.len() > 0 => { - return Err(GCError::MissingOTReceiver); - } - _ => Box::pin(ready(Ok(vec![]))), }; let (ot_send_result, direct_send_result, ot_receive_result) = futures::join!(ot_send_fut, direct_send_fut, ot_receive_fut); - ot_send_result?; + let label_sender = ot_send_result?; direct_send_result?; - let ot_receive_labels = ot_receive_result?; + let (ot_receive_labels, label_receiver) = ot_receive_result?; // Expect direct labels from peer let msg = expect_msg_or_err!( @@ -563,7 +623,7 @@ where let ev_labels = ActiveInputSet::new([ot_receive_labels, direct_received_labels, cached_labels].concat())?; - Ok((gen_labels, ev_labels)) + Ok(((gen_labels, ev_labels), (label_sender, label_receiver))) } #[cfg(feature = "mock")] @@ -571,40 +631,54 @@ mod mock { use super::*; use crate::protocol::{ garble::backend::RayonBackend, - ot::mock::{mock_ot_pair, MockOTReceiver, MockOTSender}, + ot::mock::{MockOTFactory, MockOTReceiver, MockOTSender}, }; use mpc_core::Block; use utils_aio::duplex::DuplexChannel; - pub type MockDualExLeader = - DualExLeader, MockOTReceiver>; - pub type MockDualExFollower = - DualExFollower, MockOTReceiver>; + pub type MockDualExLeader = DualExLeader< + S, + RayonBackend, + MockOTFactory, + MockOTFactory, + MockOTSender, + MockOTReceiver, + >; + pub type MockDualExFollower = DualExFollower< + S, + RayonBackend, + MockOTFactory, + MockOTFactory, + MockOTSender, + MockOTReceiver, + >; pub fn mock_dualex_pair( + config: DualExConfig, circ: Arc, ) -> ( MockDualExLeader, MockDualExFollower, ) { let (leader_channel, follower_channel) = DuplexChannel::::new(); - let (leader_sender, follower_receiver) = mock_ot_pair(); - let (follower_sender, leader_receiver) = mock_ot_pair(); + let ot_factory = MockOTFactory::new(); let leader = DualExLeader::new( + config.clone(), circ.clone(), Box::new(leader_channel), RayonBackend, - Some(leader_sender), - Some(leader_receiver), + ot_factory.clone(), + ot_factory.clone(), ); let follower = DualExFollower::new( + config, circ, Box::new(follower_channel), RayonBackend, - Some(follower_sender), - Some(follower_receiver), + ot_factory.clone(), + ot_factory.clone(), ); (leader, follower) @@ -618,6 +692,7 @@ pub use mock::mock_dualex_pair; mod tests { use super::*; use mpc_circuits::ADDER_64; + use mpc_core::garble::exec::dual::DualExConfigBuilder; use rand::SeedableRng; use rand_chacha::ChaCha12Rng; @@ -625,7 +700,11 @@ mod tests { async fn test_dualex() { let mut rng = ChaCha12Rng::seed_from_u64(0); let circ = Circuit::load_bytes(ADDER_64).unwrap(); - let (leader, follower) = mock_dualex_pair(circ.clone()); + let config = DualExConfigBuilder::default() + .id("test".to_string()) + .build() + .unwrap(); + let (leader, follower) = mock_dualex_pair(config, circ.clone()); let leader_input = circ.input(0).unwrap().to_value(1u64).unwrap(); let follower_input = circ.input(1).unwrap().to_value(2u64).unwrap(); diff --git a/mpc/mpc-aio/src/protocol/garble/mod.rs b/mpc/mpc-aio/src/protocol/garble/mod.rs index d79cf16de..d0f2fa777 100644 --- a/mpc/mpc-aio/src/protocol/garble/mod.rs +++ b/mpc/mpc-aio/src/protocol/garble/mod.rs @@ -28,6 +28,8 @@ pub enum GCError { IOError(#[from] std::io::Error), #[error("ot error")] OTError(#[from] OTError), + #[error("OTFactoryError: {0:?}")] + OTFactoryError(#[from] crate::protocol::ot::OTFactoryError), #[error("Received unexpected message: {0:?}")] Unexpected(GarbleMessage), #[error("backend error")] diff --git a/mpc/mpc-core/src/garble/exec/deap/config.rs b/mpc/mpc-core/src/garble/exec/deap/config.rs new file mode 100644 index 000000000..32e7e171a --- /dev/null +++ b/mpc/mpc-core/src/garble/exec/deap/config.rs @@ -0,0 +1,12 @@ +use derive_builder::Builder; + +#[derive(Debug, Clone, Builder)] +pub struct DEAPConfig { + id: String, +} + +impl DEAPConfig { + pub fn id(&self) -> &str { + &self.id + } +} diff --git a/mpc/mpc-core/src/garble/exec/deap/mod.rs b/mpc/mpc-core/src/garble/exec/deap/mod.rs index b967a3433..7623e8638 100644 --- a/mpc/mpc-core/src/garble/exec/deap/mod.rs +++ b/mpc/mpc-core/src/garble/exec/deap/mod.rs @@ -1,6 +1,8 @@ +mod config; mod follower; mod leader; +pub use config::{DEAPConfig, DEAPConfigBuilder, DEAPConfigBuilderError}; pub use follower::{state as follower_state, DEAPFollower}; pub use leader::{state as leader_state, DEAPLeader}; diff --git a/mpc/mpc-core/src/garble/exec/dual/config.rs b/mpc/mpc-core/src/garble/exec/dual/config.rs new file mode 100644 index 000000000..b5d20aa0b --- /dev/null +++ b/mpc/mpc-core/src/garble/exec/dual/config.rs @@ -0,0 +1,12 @@ +use derive_builder::Builder; + +#[derive(Debug, Clone, Builder)] +pub struct DualExConfig { + id: String, +} + +impl DualExConfig { + pub fn id(&self) -> &str { + &self.id + } +} diff --git a/mpc/mpc-core/src/garble/exec/dual/mod.rs b/mpc/mpc-core/src/garble/exec/dual/mod.rs index 183175050..07fc9d410 100644 --- a/mpc/mpc-core/src/garble/exec/dual/mod.rs +++ b/mpc/mpc-core/src/garble/exec/dual/mod.rs @@ -6,9 +6,11 @@ //! malicious. Such leakage, however, will be detected by the [`DualExFollower`] during the //! equality check. +mod config; mod follower; mod leader; +pub use config::{DualExConfig, DualExConfigBuilder, DualExConfigBuilderError}; pub use follower::{state as follower_state, DualExFollower}; pub use leader::{state as leader_state, DualExLeader};