From 34f7e867c976c4bb64639275600a2154bf0916ea Mon Sep 17 00:00:00 2001 From: "sinu.eth" <65924192+sinui0@users.noreply.github.com> Date: Wed, 16 Nov 2022 10:02:07 -0800 Subject: [PATCH] Rayon Garbler (#115) * rayon garbler * comment RayonGarbler * rename some things, fix import * remove redundant ref * move and rename mock backend * rename error --- mpc-aio/Cargo.toml | 1 + mpc-aio/src/protocol/garble/backend/mod.rs | 52 ++++++++++++ mpc-aio/src/protocol/garble/backend/rayon.rs | 89 ++++++++++++++++++++ mpc-aio/src/protocol/garble/exec/dual.rs | 81 ++++++------------ mpc-aio/src/protocol/garble/mod.rs | 45 +--------- 5 files changed, 172 insertions(+), 96 deletions(-) create mode 100644 mpc-aio/src/protocol/garble/backend/mod.rs create mode 100644 mpc-aio/src/protocol/garble/backend/rayon.rs diff --git a/mpc-aio/Cargo.toml b/mpc-aio/Cargo.toml index 62494c04e..d391f1806 100644 --- a/mpc-aio/Cargo.toml +++ b/mpc-aio/Cargo.toml @@ -37,6 +37,7 @@ tokio = { workspace = true, features = [ mockall = "0.11" async-stream = "0.3" aes = { workspace = true, optional = true } +rayon = { workspace = true } [target.'cfg(target_arch = "wasm32")'.dependencies] getrandom = { workspace = true, features = ["js"] } diff --git a/mpc-aio/src/protocol/garble/backend/mod.rs b/mpc-aio/src/protocol/garble/backend/mod.rs new file mode 100644 index 000000000..30d5c2e86 --- /dev/null +++ b/mpc-aio/src/protocol/garble/backend/mod.rs @@ -0,0 +1,52 @@ +mod rayon; + +pub use self::rayon::RayonBackend; + +#[cfg(feature = "mock")] +mod mock { + use std::sync::Arc; + + use aes::{Aes128, NewBlockCipher}; + use async_trait::async_trait; + + use crate::protocol::garble::{Evaluator, GCError, Generator}; + use mpc_circuits::Circuit; + use mpc_core::garble::{ + Delta, Evaluated, Full, GarbledCircuit, InputLabels, Partial, WireLabel, WireLabelPair, + }; + + pub struct MockBackend; + + #[async_trait] + impl Generator for MockBackend { + async fn generate( + &mut self, + circ: Arc, + delta: Delta, + input_labels: &[InputLabels], + ) -> Result, GCError> { + let cipher = Aes128::new_from_slice(&[0u8; 16]).unwrap(); + Ok(GarbledCircuit::generate( + &cipher, + circ, + delta, + input_labels, + )?) + } + } + + #[async_trait] + impl Evaluator for MockBackend { + async fn evaluate( + &mut self, + circ: GarbledCircuit, + input_labels: &[InputLabels], + ) -> Result, GCError> { + let cipher = Aes128::new_from_slice(&[0u8; 16]).unwrap(); + Ok(circ.evaluate(&cipher, input_labels)?) + } + } +} + +#[cfg(feature = "mock")] +pub use mock::MockBackend; diff --git a/mpc-aio/src/protocol/garble/backend/rayon.rs b/mpc-aio/src/protocol/garble/backend/rayon.rs new file mode 100644 index 000000000..ba44cade7 --- /dev/null +++ b/mpc-aio/src/protocol/garble/backend/rayon.rs @@ -0,0 +1,89 @@ +use std::sync::Arc; + +use aes::{Aes128, NewBlockCipher}; +use async_trait::async_trait; +use futures::channel::oneshot; + +use mpc_circuits::Circuit; +use mpc_core::garble::{ + Delta, Evaluated, Full, GarbledCircuit, InputLabels, Partial, WireLabel, WireLabelPair, +}; + +use crate::protocol::garble::{Evaluator, GCError, Generator}; + +/// Garbler backend using Rayon to garble and evaluate circuits asynchronously and in parallel +pub struct RayonBackend; + +#[async_trait] +impl Generator for RayonBackend { + async fn generate( + &mut self, + circ: Arc, + delta: Delta, + input_labels: &[InputLabels], + ) -> Result, GCError> { + let (sender, receiver) = oneshot::channel(); + let input_labels = input_labels.to_vec(); + rayon::spawn(move || { + let cipher = Aes128::new_from_slice(&[0u8; 16]).unwrap(); + let gc = GarbledCircuit::generate(&cipher, circ, delta, &input_labels) + .map_err(GCError::from); + _ = sender.send(gc); + }); + receiver + .await + .map_err(|_| GCError::BackendError("channel error".to_string()))? + } +} + +#[async_trait] +impl Evaluator for RayonBackend { + async fn evaluate( + &mut self, + circ: GarbledCircuit, + input_labels: &[InputLabels], + ) -> Result, GCError> { + let (sender, receiver) = oneshot::channel(); + let input_labels = input_labels.to_vec(); + rayon::spawn(move || { + let cipher = Aes128::new_from_slice(&[0u8; 16]).unwrap(); + let ev = circ.evaluate(&cipher, &input_labels).map_err(GCError::from); + _ = sender.send(ev); + }); + receiver + .await + .map_err(|_| GCError::BackendError("channel error".to_string()))? + } +} + +#[cfg(test)] +mod test { + use mpc_circuits::ADDER_64; + use rand::thread_rng; + + use super::*; + + #[tokio::test] + async fn test_rayon_garbler() { + let circ = Arc::new(Circuit::load_bytes(ADDER_64).unwrap()); + let (input_labels, delta) = InputLabels::generate(&mut thread_rng(), &circ, None); + let gc = RayonBackend + .generate(circ.clone(), delta, &input_labels) + .await + .unwrap(); + + let input_labels = vec![ + input_labels[0] + .select(&circ.input(0).unwrap().to_value(0u64).unwrap()) + .unwrap(), + input_labels[1] + .select(&circ.input(1).unwrap().to_value(0u64).unwrap()) + .unwrap(), + ]; + + let _ = RayonBackend + .evaluate(gc.to_evaluator(&[], true, false), &input_labels) + .await + .unwrap(); + } +} diff --git a/mpc-aio/src/protocol/garble/exec/dual.rs b/mpc-aio/src/protocol/garble/exec/dual.rs index 9b63854c5..0a56742ef 100644 --- a/mpc-aio/src/protocol/garble/exec/dual.rs +++ b/mpc-aio/src/protocol/garble/exec/dual.rs @@ -20,38 +20,28 @@ use mpc_core::garble::{ }; use utils_aio::expect_msg_or_err; -pub struct DualExLeader +pub struct DualExLeader where - G: Generator, - E: Evaluator, + B: Generator + Evaluator, S: WireLabelOTSend, R: WireLabelOTReceive, { channel: GarbleChannel, - generator: G, - evaluator: E, + backend: B, label_sender: S, label_receiver: R, } -impl DualExLeader +impl DualExLeader where - G: Generator + Send, - E: Evaluator + Send, + B: Generator + Evaluator + Send, S: WireLabelOTSend + Send, R: WireLabelOTReceive + Send, { - pub fn new( - channel: GarbleChannel, - generator: G, - evaluator: E, - label_sender: S, - label_receiver: R, - ) -> Self { + pub fn new(channel: GarbleChannel, backend: B, label_sender: S, label_receiver: R) -> Self { Self { channel, - generator, - evaluator, + backend, label_sender, label_receiver, } @@ -59,10 +49,9 @@ where } #[async_trait] -impl ExecuteWithLabels for DualExLeader +impl ExecuteWithLabels for DualExLeader where - G: Generator + Send, - E: Evaluator + Send, + B: Generator + Evaluator + Send, S: WireLabelOTSend + Send, R: WireLabelOTReceive + Send, { @@ -75,8 +64,8 @@ where ) -> Result, GCError> { let leader = core::DualExLeader::new(circ.clone()); let full_gc = self - .generator - .generate(circ.clone(), delta, input_labels) + .backend + .generate(circ.clone(), delta, &input_labels) .await?; let (partial_gc, leader) = leader.from_full_circuit(inputs, full_gc)?; @@ -106,7 +95,7 @@ where let gc_ev = GarbledCircuit::::from_msg(circ, msg)?; let labels_ev = self.label_receiver.receive_labels(inputs.to_vec()).await?; - let evaluated_gc = self.evaluator.evaluate(gc_ev, &labels_ev).await?; + let evaluated_gc = self.backend.evaluate(gc_ev, &labels_ev).await?; let leader = leader.from_evaluated_circuit(evaluated_gc)?; let (commit, leader) = leader.commit(); @@ -132,38 +121,28 @@ where } } -pub struct DualExFollower +pub struct DualExFollower where - G: Generator, - E: Evaluator, + B: Generator + Evaluator, S: WireLabelOTSend, R: WireLabelOTReceive, { channel: GarbleChannel, - generator: G, - evaluator: E, + backend: B, label_sender: S, label_receiver: R, } -impl DualExFollower +impl DualExFollower where - G: Generator + Send, - E: Evaluator + Send, + B: Generator + Evaluator + Send, S: WireLabelOTSend + Send, R: WireLabelOTReceive + Send, { - pub fn new( - channel: GarbleChannel, - generator: G, - evaluator: E, - label_sender: S, - label_receiver: R, - ) -> Self { + pub fn new(channel: GarbleChannel, backend: B, label_sender: S, label_receiver: R) -> Self { Self { channel, - generator, - evaluator, + backend, label_sender, label_receiver, } @@ -171,10 +150,9 @@ where } #[async_trait] -impl ExecuteWithLabels for DualExFollower +impl ExecuteWithLabels for DualExFollower where - G: Generator + Send, - E: Evaluator + Send, + B: Generator + Evaluator + Send, S: WireLabelOTSend + Send, R: WireLabelOTReceive + Send, { @@ -187,8 +165,8 @@ where ) -> Result, GCError> { let follower = core::DualExFollower::new(circ.clone()); let full_gc = self - .generator - .generate(circ.clone(), delta, input_labels) + .backend + .generate(circ.clone(), delta, &input_labels) .await?; let (partial_gc, follower) = follower.from_full_circuit(inputs, full_gc)?; @@ -217,7 +195,7 @@ where let gc_ev = GarbledCircuit::::from_msg(circ, msg)?; let labels_ev = self.label_receiver.receive_labels(inputs.to_vec()).await?; - let evaluated_gc = self.evaluator.evaluate(gc_ev, &labels_ev).await?; + let evaluated_gc = self.backend.evaluate(gc_ev, &labels_ev).await?; let follower = follower.from_evaluated_circuit(evaluated_gc)?; let msg = expect_msg_or_err!( @@ -247,10 +225,7 @@ where #[cfg(feature = "mock")] mod mock { use super::*; - use crate::protocol::{ - garble::mock::{MockEvaluator, MockGenerator}, - ot::mock::mock_ot_pair, - }; + use crate::protocol::{garble::backend::MockBackend, ot::mock::mock_ot_pair}; use utils_aio::duplex::DuplexChannel; pub fn mock_dualex_pair() -> (impl ExecuteWithLabels, impl ExecuteWithLabels) { @@ -260,15 +235,13 @@ mod mock { let leader = DualExLeader::new( Box::new(leader_channel), - MockGenerator, - MockEvaluator, + MockBackend, leader_sender, leader_receiver, ); let follower = DualExFollower::new( Box::new(follower_channel), - MockGenerator, - MockEvaluator, + MockBackend, follower_sender, follower_receiver, ); diff --git a/mpc-aio/src/protocol/garble/mod.rs b/mpc-aio/src/protocol/garble/mod.rs index a9c5ab61f..757fd6753 100644 --- a/mpc-aio/src/protocol/garble/mod.rs +++ b/mpc-aio/src/protocol/garble/mod.rs @@ -1,3 +1,4 @@ +pub mod backend; pub mod exec; mod label; @@ -28,6 +29,8 @@ pub enum GCError { LabelOTError(#[from] label::WireLabelError), #[error("Received unexpected message: {0:?}")] Unexpected(GarbleMessage), + #[error("backend error")] + BackendError(String), } #[async_trait] @@ -78,45 +81,3 @@ pub trait Execute: ExecuteWithLabels { } impl Execute for T where T: ExecuteWithLabels {} - -#[cfg(feature = "mock")] -mod mock { - use super::*; - use aes::{Aes128, NewBlockCipher}; - - pub struct MockGenerator; - pub struct MockEvaluator; - - #[async_trait] - impl Generator for MockGenerator { - async fn generate( - &mut self, - circ: Arc, - delta: Delta, - input_labels: &[InputLabels], - ) -> Result, GCError> { - let cipher = Aes128::new_from_slice(&[0u8; 16]).unwrap(); - Ok(GarbledCircuit::generate( - &cipher, - circ, - delta, - input_labels, - )?) - } - } - - #[async_trait] - impl Evaluator for MockEvaluator { - async fn evaluate( - &mut self, - circ: GarbledCircuit, - input_labels: &[InputLabels], - ) -> Result, GCError> { - let cipher = Aes128::new_from_slice(&[0u8; 16]).unwrap(); - Ok(circ.evaluate(&cipher, input_labels)?) - } - } -} - -#[cfg(feature = "mock")] -pub use mock::*;