DualEx use OT Factory (#177)

* refactor dualex to use ot factory

* remove unnecessary clone
This commit is contained in:
sinu.eth
2023-01-30 10:15:31 -08:00
committed by GitHub
parent a92342cd8b
commit f4b80fdcc9
11 changed files with 400 additions and 193 deletions

View File

@@ -1,14 +1,18 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use mpc_aio::protocol::garble::exec::deap::mock_deap_pair;
use mpc_circuits::{Circuit, WireGroup, AES_128_REVERSE};
use mpc_core::garble::FullInputSet;
use mpc_core::garble::{exec::deap::DEAPConfigBuilder, FullInputSet};
use rand::SeedableRng;
use rand_chacha::ChaCha12Rng;
use std::sync::Arc;
async fn bench_deap(circ: Arc<Circuit>) {
let mut rng = ChaCha12Rng::seed_from_u64(0);
let (leader, follower) = mock_deap_pair(circ.clone());
let config = DEAPConfigBuilder::default()
.id("bench".to_string())
.build()
.unwrap();
let (leader, follower) = mock_deap_pair(config, circ.clone());
let leader_input = circ.input(0).unwrap().to_value(vec![0u8; 16]).unwrap();
let follower_input = circ.input(1).unwrap().to_value(vec![0u8; 16]).unwrap();

View File

@@ -1,14 +1,18 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use mpc_aio::protocol::garble::exec::dual::mock_dualex_pair;
use mpc_circuits::{Circuit, WireGroup, AES_128_REVERSE};
use mpc_core::garble::FullInputSet;
use mpc_core::garble::{exec::dual::DualExConfigBuilder, FullInputSet};
use rand::SeedableRng;
use rand_chacha::ChaCha12Rng;
use std::sync::Arc;
async fn bench_dualex(circ: Arc<Circuit>) {
let mut rng = ChaCha12Rng::seed_from_u64(0);
let (leader, follower) = mock_dualex_pair(circ.clone());
let config = DualExConfigBuilder::default()
.id("bench".to_string())
.build()
.unwrap();
let (leader, follower) = mock_dualex_pair(config, circ.clone());
let leader_input = circ.input(0).unwrap().to_value(vec![0u8; 16]).unwrap();
let follower_input = circ.input(1).unwrap().to_value(vec![0u8; 16]).unwrap();

View File

@@ -1,16 +1,20 @@
use std::sync::Arc;
use std::{marker::PhantomData, sync::Arc};
use crate::protocol::{
garble::{Evaluator, GCError, GarbleChannel, GarbleMessage, Generator},
ot::{ObliviousReceive, ObliviousReveal, ObliviousSend},
ot::{OTFactoryError, ObliviousReceive, ObliviousReveal, ObliviousSend},
};
use futures::{SinkExt, StreamExt};
use mpc_circuits::{Circuit, Input, InputValue, OutputValue};
use mpc_core::garble::{
exec::deap as core, gc_state, ActiveEncodedInput, ActiveInputSet, FullEncodedInput,
FullInputSet, GarbledCircuit,
use mpc_core::{
garble::{
exec::deap::{self as core, DEAPConfig},
gc_state, ActiveEncodedInput, ActiveInputSet, FullEncodedInput, FullInputSet,
GarbledCircuit,
},
ot::config::{OTReceiverConfig, OTSenderConfig},
};
use utils_aio::expect_msg_or_err;
use utils_aio::{expect_msg_or_err, factory::AsyncFactory};
use super::setup_inputs_with;
@@ -21,65 +25,75 @@ pub mod state {
pub trait Sealed {}
impl Sealed for super::Initialized {}
impl Sealed for super::LabelSetup {}
impl Sealed for super::Executed {}
impl<LS> Sealed for super::LabelSetup<LS> {}
impl<LS> Sealed for super::Executed<LS> {}
}
pub trait State: sealed::Sealed {}
pub struct Initialized;
pub struct LabelSetup {
pub struct LabelSetup<LS> {
pub(crate) gen_labels: FullInputSet,
pub(crate) ev_labels: ActiveInputSet,
pub(crate) label_sender: Option<LS>,
}
pub struct Executed {
pub struct Executed<LS> {
pub(super) core: core::DEAPFollower<core::follower_state::Open>,
pub(crate) label_sender: Option<LS>,
}
impl State for Initialized {}
impl State for LabelSetup {}
impl State for Executed {}
impl<LS> State for LabelSetup<LS> {}
impl<LS> State for Executed<LS> {}
}
use state::*;
pub struct DEAPFollower<S, B, LS, LR>
pub struct DEAPFollower<S, B, LSF, LRF, LS, LR>
where
S: State,
B: Generator + Evaluator,
LS: ObliviousSend<FullEncodedInput> + ObliviousReveal,
LR: ObliviousReceive<InputValue, ActiveEncodedInput>,
{
config: DEAPConfig,
state: S,
circ: Arc<Circuit>,
channel: GarbleChannel,
backend: B,
label_sender: Option<LS>,
label_receiver: Option<LR>,
label_sender_factory: LSF,
label_receiver_factory: LRF,
_label_sender: PhantomData<LS>,
_label_receiver: PhantomData<LR>,
}
impl<B, LS, LR> DEAPFollower<Initialized, B, LS, LR>
impl<B, LSF, LRF, LS, LR> DEAPFollower<Initialized, B, LSF, LRF, LS, LR>
where
B: Generator + Evaluator + Send,
LSF: AsyncFactory<LS, Config = OTSenderConfig, Error = OTFactoryError> + Send,
LRF: AsyncFactory<LR, Config = OTReceiverConfig, Error = OTFactoryError> + Send,
LS: ObliviousSend<FullEncodedInput> + ObliviousReveal + Send,
LR: ObliviousReceive<InputValue, ActiveEncodedInput> + Send,
{
pub fn new(
config: DEAPConfig,
circ: Arc<Circuit>,
channel: GarbleChannel,
backend: B,
label_sender: Option<LS>,
label_receiver: Option<LR>,
label_sender_factory: LSF,
label_receiver_factory: LRF,
) -> Self {
DEAPFollower {
config,
state: Initialized,
circ,
channel,
backend,
label_sender,
label_receiver,
label_sender_factory,
label_receiver_factory,
_label_sender: PhantomData,
_label_receiver: PhantomData,
}
}
@@ -98,11 +112,16 @@ where
ot_send_inputs: Vec<Input>,
ot_receive_inputs: Vec<InputValue>,
cached_labels: Vec<ActiveEncodedInput>,
) -> Result<DEAPFollower<LabelSetup, B, LS, LR>, GCError> {
let (gen_labels, ev_labels) = setup_inputs_with(
) -> Result<DEAPFollower<LabelSetup<LS>, B, LSF, LRF, LS, LR>, GCError> {
let label_sender_id = format!("{}/ot/1", self.config.id());
let label_receiver_id = format!("{}/ot/0", self.config.id());
let ((gen_labels, ev_labels), (label_sender, _)) = setup_inputs_with(
label_sender_id,
label_receiver_id,
&mut self.channel,
self.label_sender.as_mut(),
self.label_receiver.as_mut(),
&mut self.label_sender_factory,
&mut self.label_receiver_factory,
gen_labels,
gen_inputs,
ot_send_inputs,
@@ -112,24 +131,28 @@ where
.await?;
Ok(DEAPFollower {
config: self.config,
state: LabelSetup {
gen_labels,
ev_labels,
label_sender,
},
circ: self.circ,
channel: self.channel,
backend: self.backend,
label_sender: self.label_sender,
label_receiver: self.label_receiver,
label_sender_factory: self.label_sender_factory,
label_receiver_factory: self.label_receiver_factory,
_label_sender: PhantomData,
_label_receiver: PhantomData,
})
}
}
impl<B, LS, LR> DEAPFollower<LabelSetup, B, LS, LR>
impl<B, LSF, LRF, LS, LR> DEAPFollower<LabelSetup<LS>, B, LSF, LRF, LS, LR>
where
B: Generator + Evaluator + Send,
LS: ObliviousSend<FullEncodedInput> + ObliviousReveal + Send,
LR: ObliviousReceive<InputValue, ActiveEncodedInput> + Send,
{
/// Execute first phase of the protocol, returning the _purported_ circuit output.
///
@@ -137,7 +160,13 @@ where
/// validated in the next phase.
pub async fn execute(
self,
) -> Result<(Vec<OutputValue>, DEAPFollower<Executed, B, LS, LR>), GCError> {
) -> Result<
(
Vec<OutputValue>,
DEAPFollower<Executed<LS>, B, LSF, LRF, LS, LR>,
),
GCError,
> {
// Discard the summary
let (output, _, follower) = self.execute_and_summarize().await?;
Ok((output, follower))
@@ -154,7 +183,7 @@ where
(
Vec<OutputValue>,
GarbledCircuit<gc_state::EvaluatedSummary>,
DEAPFollower<Executed, B, LS, LR>,
DEAPFollower<Executed<LS>, B, LSF, LRF, LS, LR>,
),
GCError,
> {
@@ -213,18 +242,25 @@ where
purported_output,
evaluated_summary,
DEAPFollower {
state: Executed { core: follower },
config: self.config,
state: Executed {
core: follower,
label_sender: self.state.label_sender,
},
circ: self.circ,
channel: self.channel,
backend: self.backend,
label_sender: self.label_sender,
label_receiver: self.label_receiver,
label_sender_factory: self.label_sender_factory,
label_receiver_factory: self.label_receiver_factory,
_label_sender: PhantomData,
_label_receiver: PhantomData,
},
))
}
}
impl<B, LS, LR> DEAPFollower<Executed, B, LS, LR>
impl<B, LSF, LRF, LS, LR> DEAPFollower<Executed<LS>, B, LSF, LRF, LS, LR>
where
B: Generator + Evaluator + Send,
LS: ObliviousSend<FullEncodedInput> + ObliviousReveal + Send,
@@ -248,7 +284,7 @@ where
.await?;
// Open our OTs to leader
if let Some(label_sender) = self.label_sender.take() {
if let Some(label_sender) = self.state.label_sender.take() {
label_sender.reveal().await?;
}

View File

@@ -1,16 +1,20 @@
use std::sync::Arc;
use std::{marker::PhantomData, sync::Arc};
use crate::protocol::{
garble::{Compressor, Evaluator, GCError, GarbleChannel, GarbleMessage, Generator, Validator},
ot::{ObliviousReceive, ObliviousSend, ObliviousVerify},
ot::{OTFactoryError, ObliviousReceive, ObliviousSend, ObliviousVerify},
};
use futures::{future::ready, SinkExt, StreamExt};
use futures::{SinkExt, StreamExt};
use mpc_circuits::{Circuit, Input, InputValue, OutputValue, WireGroup};
use mpc_core::garble::{
exec::deap as core, gc_state, ActiveEncodedInput, ActiveInputSet, FullEncodedInput,
FullInputSet, GarbledCircuit,
use mpc_core::{
garble::{
exec::deap::{self as core, DEAPConfig},
gc_state, ActiveEncodedInput, ActiveInputSet, FullEncodedInput, FullInputSet,
GarbledCircuit,
},
ot::config::{OTReceiverConfig, OTSenderConfig},
};
use utils_aio::expect_msg_or_err;
use utils_aio::{expect_msg_or_err, factory::AsyncFactory};
use super::setup_inputs_with;
@@ -21,28 +25,30 @@ pub mod state {
pub trait Sealed {}
impl Sealed for super::Initialized {}
impl Sealed for super::LabelSetup {}
impl Sealed for super::Executed {}
impl<LR> Sealed for super::LabelSetup<LR> {}
impl<LR> Sealed for super::Executed<LR> {}
}
pub trait State: sealed::Sealed {}
pub struct Initialized;
pub struct LabelSetup {
pub struct LabelSetup<LR> {
pub(crate) gen_labels: FullInputSet,
pub(crate) ev_labels: ActiveInputSet,
pub(crate) input_state: InputState,
pub(crate) label_receiver: Option<LR>,
}
pub struct Executed {
pub struct Executed<LR> {
pub(super) core: core::DEAPLeader<core::leader_state::Validate>,
pub(crate) input_state: InputState,
pub(crate) label_receiver: Option<LR>,
}
impl State for Initialized {}
impl State for LabelSetup {}
impl State for Executed {}
impl<LR> State for LabelSetup<LR> {}
impl<LR> State for Executed<LR> {}
pub(crate) struct InputState {
pub(crate) ot_receive_inputs: Vec<Input>,
@@ -51,41 +57,48 @@ pub mod state {
use state::*;
pub struct DEAPLeader<S, B, LS, LR>
pub struct DEAPLeader<S, B, LSF, LRF, LS, LR>
where
S: State,
B: Generator + Evaluator,
LS: ObliviousSend<FullEncodedInput>,
LR: ObliviousReceive<InputValue, ActiveEncodedInput> + ObliviousVerify<FullEncodedInput>,
{
config: DEAPConfig,
state: S,
circ: Arc<Circuit>,
channel: GarbleChannel,
backend: B,
label_sender: Option<LS>,
label_receiver: Option<LR>,
label_sender_factory: LSF,
label_receiver_factory: LRF,
_label_sender: PhantomData<LS>,
_label_receiver: PhantomData<LR>,
}
impl<B, LS, LR> DEAPLeader<Initialized, B, LS, LR>
impl<B, LSF, LRF, LS, LR> DEAPLeader<Initialized, B, LSF, LRF, LS, LR>
where
B: Generator + Evaluator + Compressor + Validator + Send,
LSF: AsyncFactory<LS, Config = OTSenderConfig, Error = OTFactoryError> + Send,
LRF: AsyncFactory<LR, Config = OTReceiverConfig, Error = OTFactoryError> + Send,
LS: ObliviousSend<FullEncodedInput> + Send,
LR: ObliviousReceive<InputValue, ActiveEncodedInput> + ObliviousVerify<FullEncodedInput> + Send,
{
pub fn new(
config: DEAPConfig,
circ: Arc<Circuit>,
channel: GarbleChannel,
backend: B,
label_sender: Option<LS>,
label_receiver: Option<LR>,
) -> DEAPLeader<Initialized, B, LS, LR> {
label_sender_factory: LSF,
label_receiver_factory: LRF,
) -> DEAPLeader<Initialized, B, LSF, LRF, LS, LR> {
DEAPLeader {
config,
state: Initialized,
circ,
channel,
backend,
label_sender,
label_receiver,
label_sender_factory,
label_receiver_factory,
_label_sender: PhantomData,
_label_receiver: PhantomData,
}
}
@@ -104,11 +117,16 @@ where
ot_send_inputs: Vec<Input>,
ot_receive_inputs: Vec<InputValue>,
cached_labels: Vec<ActiveEncodedInput>,
) -> Result<DEAPLeader<LabelSetup, B, LS, LR>, GCError> {
let (gen_labels, ev_labels) = setup_inputs_with(
) -> Result<DEAPLeader<LabelSetup<LR>, B, LSF, LRF, LS, LR>, GCError> {
let label_sender_id = format!("{}/ot/0", self.config.id());
let label_receiver_id = format!("{}/ot/1", self.config.id());
let ((gen_labels, ev_labels), (_, label_receiver)) = setup_inputs_with(
label_sender_id,
label_receiver_id,
&mut self.channel,
self.label_sender.as_mut(),
self.label_receiver.as_mut(),
&mut self.label_sender_factory,
&mut self.label_receiver_factory,
gen_labels,
gen_inputs,
ot_send_inputs,
@@ -118,6 +136,7 @@ where
.await?;
Ok(DEAPLeader {
config: self.config,
state: LabelSetup {
gen_labels,
ev_labels,
@@ -127,26 +146,34 @@ where
.map(|v| v.group().clone())
.collect(),
},
label_receiver,
},
circ: self.circ,
channel: self.channel,
backend: self.backend,
label_sender: self.label_sender,
label_receiver: self.label_receiver,
label_sender_factory: self.label_sender_factory,
label_receiver_factory: self.label_receiver_factory,
_label_sender: PhantomData,
_label_receiver: PhantomData,
})
}
}
impl<B, LS, LR> DEAPLeader<LabelSetup, B, LS, LR>
impl<B, LSF, LRF, LS, LR> DEAPLeader<LabelSetup<LR>, B, LSF, LRF, LS, LR>
where
B: Generator + Evaluator + Compressor + Validator + Send,
LS: ObliviousSend<FullEncodedInput> + Send,
LR: ObliviousReceive<InputValue, ActiveEncodedInput> + ObliviousVerify<FullEncodedInput> + Send,
{
/// Execute first phase of the protocol, returning the authenticated output.
pub async fn execute(
self,
) -> Result<(Vec<OutputValue>, DEAPLeader<Executed, B, LS, LR>), GCError> {
) -> Result<
(
Vec<OutputValue>,
DEAPLeader<Executed<LR>, B, LSF, LRF, LS, LR>,
),
GCError,
> {
// Discard summary
let (output, _, leader) = self.execute_and_summarize().await?;
Ok((output, leader))
@@ -162,7 +189,7 @@ where
(
Vec<OutputValue>,
GarbledCircuit<gc_state::EvaluatedSummary>,
DEAPLeader<Executed, B, LS, LR>,
DEAPLeader<Executed<LR>, B, LSF, LRF, LS, LR>,
),
GCError,
> {
@@ -222,21 +249,25 @@ where
output,
gc_evaluated_summary,
DEAPLeader {
config: self.config,
state: Executed {
core: leader,
input_state: self.state.input_state,
label_receiver: self.state.label_receiver,
},
circ: self.circ,
channel: self.channel,
backend: self.backend,
label_sender: self.label_sender,
label_receiver: self.label_receiver,
label_sender_factory: self.label_sender_factory,
label_receiver_factory: self.label_receiver_factory,
_label_sender: PhantomData,
_label_receiver: PhantomData,
},
))
}
}
impl<B, LS, LR> DEAPLeader<Executed, B, LS, LR>
impl<B, LSF, LRF, LS, LR> DEAPLeader<Executed<LR>, B, LSF, LRF, LS, LR>
where
B: Generator + Evaluator + Compressor + Validator + Send,
LS: ObliviousSend<FullEncodedInput> + Send,
@@ -268,27 +299,32 @@ where
let gc_validate_fut = self.backend.validate_compressed(gc_cmp, opening);
// If we did not receive any inputs via OT, we can skip the OT validation
let labels_validate_fut = if self.state.input_state.ot_receive_inputs.is_empty() {
Box::pin(ready(Ok(())))
} else {
let Some(label_receiver) = self.label_receiver.take() else {
return Err(GCError::MissingOTReceiver);
};
let labels_validate_fut = async move {
if self.state.input_state.ot_receive_inputs.is_empty() {
Ok(())
} else {
let Some(label_receiver) = self.state.label_receiver.take() else {
return Err(GCError::MissingOTReceiver);
};
let ot_received = self
.state
.input_state
.ot_receive_inputs
.iter()
.map(|input| {
input_labels
.get(input.index())
.expect("Input id should be valid")
})
.cloned()
.collect::<Vec<_>>();
let ot_received = self
.state
.input_state
.ot_receive_inputs
.iter()
.map(|input| {
input_labels
.get(input.index())
.expect("Input id should be valid")
})
.cloned()
.collect::<Vec<_>>();
label_receiver.verify(ot_received)
label_receiver
.verify(ot_received)
.await
.map_err(GCError::from)
}
};
let (gc_validate_result, labels_validate_result) =

View File

@@ -14,41 +14,55 @@ mod mock {
use super::*;
use crate::protocol::{
garble::backend::RayonBackend,
ot::mock::{mock_ot_pair, MockOTReceiver, MockOTSender},
ot::mock::{MockOTFactory, MockOTReceiver, MockOTSender},
};
use mpc_circuits::Circuit;
use mpc_core::{msgs::garble::GarbleMessage, Block};
use mpc_core::{garble::exec::deap::DEAPConfig, msgs::garble::GarbleMessage, Block};
use utils_aio::duplex::DuplexChannel;
pub type MockDEAPLeader<S> =
DEAPLeader<S, RayonBackend, MockOTSender<Block>, MockOTReceiver<Block>>;
pub type MockDEAPFollower<S> =
DEAPFollower<S, RayonBackend, MockOTSender<Block>, MockOTReceiver<Block>>;
pub type MockDEAPLeader<S> = DEAPLeader<
S,
RayonBackend,
MockOTFactory<Block>,
MockOTFactory<Block>,
MockOTSender<Block>,
MockOTReceiver<Block>,
>;
pub type MockDEAPFollower<S> = DEAPFollower<
S,
RayonBackend,
MockOTFactory<Block>,
MockOTFactory<Block>,
MockOTSender<Block>,
MockOTReceiver<Block>,
>;
pub fn mock_deap_pair(
config: DEAPConfig,
circ: Arc<Circuit>,
) -> (
MockDEAPLeader<leader_state::Initialized>,
MockDEAPFollower<follower_state::Initialized>,
) {
let (leader_channel, follower_channel) = DuplexChannel::<GarbleMessage>::new();
let (leader_sender, follower_receiver) = mock_ot_pair();
let (follower_sender, leader_receiver) = mock_ot_pair();
let ot_factory = MockOTFactory::new();
let leader = DEAPLeader::new(
config.clone(),
circ.clone(),
Box::new(leader_channel),
RayonBackend,
Some(leader_sender),
Some(leader_receiver),
ot_factory.clone(),
ot_factory.clone(),
);
let follower = DEAPFollower::new(
config,
circ,
Box::new(follower_channel),
RayonBackend,
Some(follower_sender),
Some(follower_receiver),
ot_factory.clone(),
ot_factory,
);
(leader, follower)
@@ -62,15 +76,19 @@ pub use mock::mock_deap_pair;
mod tests {
use super::*;
use mpc_circuits::{Circuit, WireGroup, ADDER_64};
use mpc_core::garble::FullInputSet;
use mpc_core::garble::{exec::deap::DEAPConfigBuilder, FullInputSet};
use rand_chacha::ChaCha12Rng;
use rand_core::SeedableRng;
#[tokio::test]
async fn test_deap() {
let config = DEAPConfigBuilder::default()
.id("test".to_string())
.build()
.unwrap();
let mut rng = ChaCha12Rng::seed_from_u64(0);
let circ = Circuit::load_bytes(ADDER_64).unwrap();
let (leader, follower) = mock_deap_pair(circ.clone());
let (leader, follower) = mock_deap_pair(config, circ.clone());
let leader_input = circ.input(0).unwrap().to_value(1u64).unwrap();
let follower_input = circ.input(1).unwrap().to_value(2u64).unwrap();

View File

@@ -6,19 +6,25 @@
//! malicious. Such leakage, however, will be detected by the [`DualExFollower`] during the
//! equality check.
use std::sync::Arc;
use std::{marker::PhantomData, sync::Arc};
use crate::protocol::{
garble::{Evaluator, GCError, GarbleChannel, GarbleMessage, Generator},
ot::{ObliviousReceive, ObliviousSend},
ot::{OTFactoryError, ObliviousReceive, ObliviousSend},
};
use futures::{future::ready, SinkExt, StreamExt};
use futures::{SinkExt, StreamExt};
use mpc_circuits::{Circuit, Input, InputValue, OutputValue, WireGroup};
use mpc_core::garble::{
exec::dual as core, gc_state, ActiveEncodedInput, ActiveInputSet, Error as CoreError,
FullEncodedInput, FullInputSet, GarbledCircuit,
use mpc_core::{
garble::{
exec::dual::{self as core, DualExConfig},
gc_state, ActiveEncodedInput, ActiveInputSet, Error as CoreError, FullEncodedInput,
FullInputSet, GarbledCircuit,
},
ot::config::{
OTReceiverConfig, OTReceiverConfigBuilder, OTSenderConfig, OTSenderConfigBuilder,
},
};
use utils_aio::expect_msg_or_err;
use utils_aio::{expect_msg_or_err, factory::AsyncFactory};
mod state {
use super::*;
@@ -45,42 +51,49 @@ mod state {
use state::*;
pub struct DualExLeader<S, B, LS, LR>
pub struct DualExLeader<S, B, LSF, LRF, LS, LR>
where
S: State,
B: Generator + Evaluator,
LS: ObliviousSend<FullEncodedInput>,
LR: ObliviousReceive<InputValue, ActiveEncodedInput>,
{
config: DualExConfig,
state: S,
circ: Arc<Circuit>,
channel: GarbleChannel,
backend: B,
label_sender: Option<LS>,
label_receiver: Option<LR>,
label_sender_factory: LSF,
label_receiver_factory: LRF,
_label_sender: PhantomData<LS>,
_label_receiver: PhantomData<LR>,
}
impl<B, LS, LR> DualExLeader<Initialized, B, LS, LR>
impl<B, LSF, LRF, LS, LR> DualExLeader<Initialized, B, LSF, LRF, LS, LR>
where
B: Generator + Evaluator + Send,
LSF: AsyncFactory<LS, Config = OTSenderConfig, Error = OTFactoryError> + Send,
LRF: AsyncFactory<LR, Config = OTReceiverConfig, Error = OTFactoryError> + Send,
LS: ObliviousSend<FullEncodedInput> + Send,
LR: ObliviousReceive<InputValue, ActiveEncodedInput> + Send,
{
/// Create a new DualExLeader
pub fn new(
config: DualExConfig,
circ: Arc<Circuit>,
channel: GarbleChannel,
backend: B,
label_sender: Option<LS>,
label_receiver: Option<LR>,
) -> DualExLeader<Initialized, B, LS, LR> {
label_sender_factory: LSF,
label_receiver_factory: LRF,
) -> DualExLeader<Initialized, B, LSF, LRF, LS, LR> {
DualExLeader {
config,
state: Initialized,
circ,
channel,
backend,
label_sender,
label_receiver,
label_sender_factory,
label_receiver_factory,
_label_sender: PhantomData,
_label_receiver: PhantomData,
}
}
@@ -99,11 +112,16 @@ where
ot_send_inputs: Vec<Input>,
ot_receive_inputs: Vec<InputValue>,
cached_labels: Vec<ActiveEncodedInput>,
) -> Result<DualExLeader<LabelSetup, B, LS, LR>, GCError> {
let (gen_labels, ev_labels) = setup_inputs_with(
) -> Result<DualExLeader<LabelSetup, B, LSF, LRF, LS, LR>, GCError> {
let label_sender_id = format!("{}/ot/0", self.config.id());
let label_receiver_id = format!("{}/ot/1", self.config.id());
let ((gen_labels, ev_labels), _) = setup_inputs_with(
label_sender_id,
label_receiver_id,
&mut self.channel,
self.label_sender.as_mut(),
self.label_receiver.as_mut(),
&mut self.label_sender_factory,
&mut self.label_receiver_factory,
gen_labels,
gen_inputs,
ot_send_inputs,
@@ -113,6 +131,7 @@ where
.await?;
Ok(DualExLeader {
config: self.config,
state: LabelSetup {
gen_labels,
ev_labels,
@@ -120,17 +139,17 @@ where
circ: self.circ,
channel: self.channel,
backend: self.backend,
label_sender: self.label_sender,
label_receiver: self.label_receiver,
label_sender_factory: self.label_sender_factory,
label_receiver_factory: self.label_receiver_factory,
_label_sender: PhantomData,
_label_receiver: PhantomData,
})
}
}
impl<B, LS, LR> DualExLeader<LabelSetup, B, LS, LR>
impl<B, LSF, LRF, LS, LR> DualExLeader<LabelSetup, B, LSF, LRF, LS, LR>
where
B: Generator + Evaluator + Send,
LS: ObliviousSend<FullEncodedInput> + Send,
LR: ObliviousReceive<InputValue, ActiveEncodedInput> + Send,
{
/// Execute dual execution protocol
///
@@ -261,42 +280,49 @@ where
}
}
pub struct DualExFollower<S, B, LS, LR>
pub struct DualExFollower<S, B, LSF, LRF, LS, LR>
where
S: State,
B: Generator + Evaluator,
LS: ObliviousSend<FullEncodedInput>,
LR: ObliviousReceive<InputValue, ActiveEncodedInput>,
{
config: DualExConfig,
state: S,
circ: Arc<Circuit>,
channel: GarbleChannel,
backend: B,
label_sender: Option<LS>,
label_receiver: Option<LR>,
label_sender_factory: LSF,
label_receiver_factory: LRF,
_label_sender: PhantomData<LS>,
_label_receiver: PhantomData<LR>,
}
impl<B, LS, LR> DualExFollower<Initialized, B, LS, LR>
impl<B, LSF, LRF, LS, LR> DualExFollower<Initialized, B, LSF, LRF, LS, LR>
where
B: Generator + Evaluator + Send,
LSF: AsyncFactory<LS, Config = OTSenderConfig, Error = OTFactoryError> + Send,
LRF: AsyncFactory<LR, Config = OTReceiverConfig, Error = OTFactoryError> + Send,
LS: ObliviousSend<FullEncodedInput> + Send,
LR: ObliviousReceive<InputValue, ActiveEncodedInput> + Send,
{
/// Create a new DualExFollower
pub fn new(
config: DualExConfig,
circ: Arc<Circuit>,
channel: GarbleChannel,
backend: B,
label_sender: Option<LS>,
label_receiver: Option<LR>,
) -> DualExFollower<Initialized, B, LS, LR> {
label_sender_factory: LSF,
label_receiver_factory: LRF,
) -> DualExFollower<Initialized, B, LSF, LRF, LS, LR> {
DualExFollower {
config,
state: Initialized,
circ,
channel,
backend,
label_sender,
label_receiver,
label_sender_factory,
label_receiver_factory,
_label_sender: PhantomData,
_label_receiver: PhantomData,
}
}
@@ -315,11 +341,16 @@ where
ot_send_inputs: Vec<Input>,
ot_receive_inputs: Vec<InputValue>,
cached_labels: Vec<ActiveEncodedInput>,
) -> Result<DualExFollower<LabelSetup, B, LS, LR>, GCError> {
let (gen_labels, ev_labels) = setup_inputs_with(
) -> Result<DualExFollower<LabelSetup, B, LSF, LRF, LS, LR>, GCError> {
let label_sender_id = format!("{}/ot/1", self.config.id());
let label_receiver_id = format!("{}/ot/0", self.config.id());
let ((gen_labels, ev_labels), _) = setup_inputs_with(
label_sender_id,
label_receiver_id,
&mut self.channel,
self.label_sender.as_mut(),
self.label_receiver.as_mut(),
&mut self.label_sender_factory,
&mut self.label_receiver_factory,
gen_labels,
gen_inputs,
ot_send_inputs,
@@ -329,6 +360,7 @@ where
.await?;
Ok(DualExFollower {
config: self.config,
state: LabelSetup {
gen_labels,
ev_labels,
@@ -336,17 +368,17 @@ where
circ: self.circ,
channel: self.channel,
backend: self.backend,
label_sender: self.label_sender,
label_receiver: self.label_receiver,
label_sender_factory: self.label_sender_factory,
label_receiver_factory: self.label_receiver_factory,
_label_sender: PhantomData,
_label_receiver: PhantomData,
})
}
}
impl<B, LS, LR> DualExFollower<LabelSetup, B, LS, LR>
impl<B, LSF, LRF, LS, LR> DualExFollower<LabelSetup, B, LSF, LRF, LS, LR>
where
B: Generator + Evaluator + Send,
LS: ObliviousSend<FullEncodedInput> + Send,
LR: ObliviousReceive<InputValue, ActiveEncodedInput> + Send,
{
/// Execute dual execution protocol
///
@@ -479,17 +511,21 @@ where
}
/// Set up input labels by exchanging directly and via oblivious transfer.
pub async fn setup_inputs_with<LS, LR>(
pub async fn setup_inputs_with<LSF, LRF, LS, LR>(
label_sender_id: String,
label_receiver_id: String,
channel: &mut GarbleChannel,
label_sender: Option<&mut LS>,
label_receiver: Option<&mut LR>,
label_sender_factory: &mut LSF,
label_receiver_factory: &mut LRF,
gen_labels: FullInputSet,
gen_inputs: Vec<InputValue>,
ot_send_inputs: Vec<Input>,
ot_receive_inputs: Vec<InputValue>,
cached_labels: Vec<ActiveEncodedInput>,
) -> Result<(FullInputSet, ActiveInputSet), GCError>
) -> Result<((FullInputSet, ActiveInputSet), (Option<LS>, Option<LR>)), GCError>
where
LSF: AsyncFactory<LS, Config = OTSenderConfig, Error = OTFactoryError> + Send,
LRF: AsyncFactory<LR, Config = OTReceiverConfig, Error = OTFactoryError> + Send,
LS: ObliviousSend<FullEncodedInput> + Send,
LR: ObliviousReceive<InputValue, ActiveEncodedInput> + Send,
{
@@ -514,12 +550,25 @@ where
// Concurrently execute oblivious transfers and direct label sending
// If there are no labels to be sent via OT, we can skip the OT protocol
let ot_send_fut = match label_sender {
Some(label_sender) if ot_send_labels.len() > 0 => label_sender.send(ot_send_labels),
None if ot_send_labels.len() > 0 => {
return Err(GCError::MissingOTSender);
let ot_send_fut = async move {
if ot_send_labels.len() > 0 {
let count = ot_send_labels.iter().map(|labels| labels.len()).sum();
let sender_config = OTSenderConfigBuilder::default()
.count(count)
.build()
.expect("OTSenderConfig should be valid");
let mut label_sender = label_sender_factory
.create(label_sender_id, sender_config)
.await?;
let _ = label_sender.send(ot_send_labels).await?;
Result::<_, GCError>::Ok(Some(label_sender))
} else {
Result::<_, GCError>::Ok(None)
}
_ => Box::pin(ready(Ok(()))),
};
let direct_send_fut = channel.send(GarbleMessage::InputLabels(
@@ -530,22 +579,33 @@ where
));
// If there are no labels to be received via OT, we can skip the OT protocol
let ot_receive_fut = match label_receiver {
Some(label_receiver) if ot_receive_inputs.len() > 0 => {
label_receiver.receive(ot_receive_inputs)
let ot_receive_fut = async move {
if ot_receive_inputs.len() > 0 {
let count = ot_receive_inputs.iter().map(|input| input.len()).sum();
let receiver_config = OTReceiverConfigBuilder::default()
.count(count)
.build()
.expect("OTReceiverConfig should be valid");
let mut label_receiver = label_receiver_factory
.create(label_receiver_id, receiver_config)
.await?;
let ot_receive_labels = label_receiver.receive(ot_receive_inputs).await?;
Result::<_, GCError>::Ok((ot_receive_labels, Some(label_receiver)))
} else {
Result::<_, GCError>::Ok((vec![], None))
}
None if ot_receive_inputs.len() > 0 => {
return Err(GCError::MissingOTReceiver);
}
_ => Box::pin(ready(Ok(vec![]))),
};
let (ot_send_result, direct_send_result, ot_receive_result) =
futures::join!(ot_send_fut, direct_send_fut, ot_receive_fut);
ot_send_result?;
let label_sender = ot_send_result?;
direct_send_result?;
let ot_receive_labels = ot_receive_result?;
let (ot_receive_labels, label_receiver) = ot_receive_result?;
// Expect direct labels from peer
let msg = expect_msg_or_err!(
@@ -563,7 +623,7 @@ where
let ev_labels =
ActiveInputSet::new([ot_receive_labels, direct_received_labels, cached_labels].concat())?;
Ok((gen_labels, ev_labels))
Ok(((gen_labels, ev_labels), (label_sender, label_receiver)))
}
#[cfg(feature = "mock")]
@@ -571,40 +631,54 @@ mod mock {
use super::*;
use crate::protocol::{
garble::backend::RayonBackend,
ot::mock::{mock_ot_pair, MockOTReceiver, MockOTSender},
ot::mock::{MockOTFactory, MockOTReceiver, MockOTSender},
};
use mpc_core::Block;
use utils_aio::duplex::DuplexChannel;
pub type MockDualExLeader<S> =
DualExLeader<S, RayonBackend, MockOTSender<Block>, MockOTReceiver<Block>>;
pub type MockDualExFollower<S> =
DualExFollower<S, RayonBackend, MockOTSender<Block>, MockOTReceiver<Block>>;
pub type MockDualExLeader<S> = DualExLeader<
S,
RayonBackend,
MockOTFactory<Block>,
MockOTFactory<Block>,
MockOTSender<Block>,
MockOTReceiver<Block>,
>;
pub type MockDualExFollower<S> = DualExFollower<
S,
RayonBackend,
MockOTFactory<Block>,
MockOTFactory<Block>,
MockOTSender<Block>,
MockOTReceiver<Block>,
>;
pub fn mock_dualex_pair(
config: DualExConfig,
circ: Arc<Circuit>,
) -> (
MockDualExLeader<Initialized>,
MockDualExFollower<Initialized>,
) {
let (leader_channel, follower_channel) = DuplexChannel::<GarbleMessage>::new();
let (leader_sender, follower_receiver) = mock_ot_pair();
let (follower_sender, leader_receiver) = mock_ot_pair();
let ot_factory = MockOTFactory::new();
let leader = DualExLeader::new(
config.clone(),
circ.clone(),
Box::new(leader_channel),
RayonBackend,
Some(leader_sender),
Some(leader_receiver),
ot_factory.clone(),
ot_factory.clone(),
);
let follower = DualExFollower::new(
config,
circ,
Box::new(follower_channel),
RayonBackend,
Some(follower_sender),
Some(follower_receiver),
ot_factory.clone(),
ot_factory.clone(),
);
(leader, follower)
@@ -618,6 +692,7 @@ pub use mock::mock_dualex_pair;
mod tests {
use super::*;
use mpc_circuits::ADDER_64;
use mpc_core::garble::exec::dual::DualExConfigBuilder;
use rand::SeedableRng;
use rand_chacha::ChaCha12Rng;
@@ -625,7 +700,11 @@ mod tests {
async fn test_dualex() {
let mut rng = ChaCha12Rng::seed_from_u64(0);
let circ = Circuit::load_bytes(ADDER_64).unwrap();
let (leader, follower) = mock_dualex_pair(circ.clone());
let config = DualExConfigBuilder::default()
.id("test".to_string())
.build()
.unwrap();
let (leader, follower) = mock_dualex_pair(config, circ.clone());
let leader_input = circ.input(0).unwrap().to_value(1u64).unwrap();
let follower_input = circ.input(1).unwrap().to_value(2u64).unwrap();

View File

@@ -28,6 +28,8 @@ pub enum GCError {
IOError(#[from] std::io::Error),
#[error("ot error")]
OTError(#[from] OTError),
#[error("OTFactoryError: {0:?}")]
OTFactoryError(#[from] crate::protocol::ot::OTFactoryError),
#[error("Received unexpected message: {0:?}")]
Unexpected(GarbleMessage),
#[error("backend error")]

View File

@@ -0,0 +1,12 @@
use derive_builder::Builder;
#[derive(Debug, Clone, Builder)]
pub struct DEAPConfig {
id: String,
}
impl DEAPConfig {
pub fn id(&self) -> &str {
&self.id
}
}

View File

@@ -1,6 +1,8 @@
mod config;
mod follower;
mod leader;
pub use config::{DEAPConfig, DEAPConfigBuilder, DEAPConfigBuilderError};
pub use follower::{state as follower_state, DEAPFollower};
pub use leader::{state as leader_state, DEAPLeader};

View File

@@ -0,0 +1,12 @@
use derive_builder::Builder;
#[derive(Debug, Clone, Builder)]
pub struct DualExConfig {
id: String,
}
impl DualExConfig {
pub fn id(&self) -> &str {
&self.id
}
}

View File

@@ -6,9 +6,11 @@
//! malicious. Such leakage, however, will be detected by the [`DualExFollower`] during the
//! equality check.
mod config;
mod follower;
mod leader;
pub use config::{DualExConfig, DualExConfigBuilder, DualExConfigBuilderError};
pub use follower::{state as follower_state, DualExFollower};
pub use leader::{state as leader_state, DualExLeader};