Rayon Garbler (#115)

* rayon garbler

* comment RayonGarbler

* rename some things, fix import

* remove redundant ref

* move and rename mock backend

* rename error
This commit is contained in:
sinu.eth
2022-11-16 10:02:07 -08:00
committed by GitHub
parent c559b6083e
commit 34f7e867c9
5 changed files with 172 additions and 96 deletions

View File

@@ -37,6 +37,7 @@ tokio = { workspace = true, features = [
mockall = "0.11"
async-stream = "0.3"
aes = { workspace = true, optional = true }
rayon = { workspace = true }
[target.'cfg(target_arch = "wasm32")'.dependencies]
getrandom = { workspace = true, features = ["js"] }

View File

@@ -0,0 +1,52 @@
mod rayon;
pub use self::rayon::RayonBackend;
#[cfg(feature = "mock")]
mod mock {
use std::sync::Arc;
use aes::{Aes128, NewBlockCipher};
use async_trait::async_trait;
use crate::protocol::garble::{Evaluator, GCError, Generator};
use mpc_circuits::Circuit;
use mpc_core::garble::{
Delta, Evaluated, Full, GarbledCircuit, InputLabels, Partial, WireLabel, WireLabelPair,
};
pub struct MockBackend;
#[async_trait]
impl Generator for MockBackend {
async fn generate(
&mut self,
circ: Arc<Circuit>,
delta: Delta,
input_labels: &[InputLabels<WireLabelPair>],
) -> Result<GarbledCircuit<Full>, GCError> {
let cipher = Aes128::new_from_slice(&[0u8; 16]).unwrap();
Ok(GarbledCircuit::generate(
&cipher,
circ,
delta,
input_labels,
)?)
}
}
#[async_trait]
impl Evaluator for MockBackend {
async fn evaluate(
&mut self,
circ: GarbledCircuit<Partial>,
input_labels: &[InputLabels<WireLabel>],
) -> Result<GarbledCircuit<Evaluated>, GCError> {
let cipher = Aes128::new_from_slice(&[0u8; 16]).unwrap();
Ok(circ.evaluate(&cipher, input_labels)?)
}
}
}
#[cfg(feature = "mock")]
pub use mock::MockBackend;

View File

@@ -0,0 +1,89 @@
use std::sync::Arc;
use aes::{Aes128, NewBlockCipher};
use async_trait::async_trait;
use futures::channel::oneshot;
use mpc_circuits::Circuit;
use mpc_core::garble::{
Delta, Evaluated, Full, GarbledCircuit, InputLabels, Partial, WireLabel, WireLabelPair,
};
use crate::protocol::garble::{Evaluator, GCError, Generator};
/// Garbler backend using Rayon to garble and evaluate circuits asynchronously and in parallel
pub struct RayonBackend;
#[async_trait]
impl Generator for RayonBackend {
async fn generate(
&mut self,
circ: Arc<Circuit>,
delta: Delta,
input_labels: &[InputLabels<WireLabelPair>],
) -> Result<GarbledCircuit<Full>, GCError> {
let (sender, receiver) = oneshot::channel();
let input_labels = input_labels.to_vec();
rayon::spawn(move || {
let cipher = Aes128::new_from_slice(&[0u8; 16]).unwrap();
let gc = GarbledCircuit::generate(&cipher, circ, delta, &input_labels)
.map_err(GCError::from);
_ = sender.send(gc);
});
receiver
.await
.map_err(|_| GCError::BackendError("channel error".to_string()))?
}
}
#[async_trait]
impl Evaluator for RayonBackend {
async fn evaluate(
&mut self,
circ: GarbledCircuit<Partial>,
input_labels: &[InputLabels<WireLabel>],
) -> Result<GarbledCircuit<Evaluated>, GCError> {
let (sender, receiver) = oneshot::channel();
let input_labels = input_labels.to_vec();
rayon::spawn(move || {
let cipher = Aes128::new_from_slice(&[0u8; 16]).unwrap();
let ev = circ.evaluate(&cipher, &input_labels).map_err(GCError::from);
_ = sender.send(ev);
});
receiver
.await
.map_err(|_| GCError::BackendError("channel error".to_string()))?
}
}
#[cfg(test)]
mod test {
use mpc_circuits::ADDER_64;
use rand::thread_rng;
use super::*;
#[tokio::test]
async fn test_rayon_garbler() {
let circ = Arc::new(Circuit::load_bytes(ADDER_64).unwrap());
let (input_labels, delta) = InputLabels::generate(&mut thread_rng(), &circ, None);
let gc = RayonBackend
.generate(circ.clone(), delta, &input_labels)
.await
.unwrap();
let input_labels = vec![
input_labels[0]
.select(&circ.input(0).unwrap().to_value(0u64).unwrap())
.unwrap(),
input_labels[1]
.select(&circ.input(1).unwrap().to_value(0u64).unwrap())
.unwrap(),
];
let _ = RayonBackend
.evaluate(gc.to_evaluator(&[], true, false), &input_labels)
.await
.unwrap();
}
}

View File

@@ -20,38 +20,28 @@ use mpc_core::garble::{
};
use utils_aio::expect_msg_or_err;
pub struct DualExLeader<G, E, S, R>
pub struct DualExLeader<B, S, R>
where
G: Generator,
E: Evaluator,
B: Generator + Evaluator,
S: WireLabelOTSend,
R: WireLabelOTReceive,
{
channel: GarbleChannel,
generator: G,
evaluator: E,
backend: B,
label_sender: S,
label_receiver: R,
}
impl<G, E, S, R> DualExLeader<G, E, S, R>
impl<B, S, R> DualExLeader<B, S, R>
where
G: Generator + Send,
E: Evaluator + Send,
B: Generator + Evaluator + Send,
S: WireLabelOTSend + Send,
R: WireLabelOTReceive + Send,
{
pub fn new(
channel: GarbleChannel,
generator: G,
evaluator: E,
label_sender: S,
label_receiver: R,
) -> Self {
pub fn new(channel: GarbleChannel, backend: B, label_sender: S, label_receiver: R) -> Self {
Self {
channel,
generator,
evaluator,
backend,
label_sender,
label_receiver,
}
@@ -59,10 +49,9 @@ where
}
#[async_trait]
impl<G, E, S, R> ExecuteWithLabels for DualExLeader<G, E, S, R>
impl<B, S, R> ExecuteWithLabels for DualExLeader<B, S, R>
where
G: Generator + Send,
E: Evaluator + Send,
B: Generator + Evaluator + Send,
S: WireLabelOTSend + Send,
R: WireLabelOTReceive + Send,
{
@@ -75,8 +64,8 @@ where
) -> Result<GarbledCircuit<Evaluated>, GCError> {
let leader = core::DualExLeader::new(circ.clone());
let full_gc = self
.generator
.generate(circ.clone(), delta, input_labels)
.backend
.generate(circ.clone(), delta, &input_labels)
.await?;
let (partial_gc, leader) = leader.from_full_circuit(inputs, full_gc)?;
@@ -106,7 +95,7 @@ where
let gc_ev = GarbledCircuit::<Partial>::from_msg(circ, msg)?;
let labels_ev = self.label_receiver.receive_labels(inputs.to_vec()).await?;
let evaluated_gc = self.evaluator.evaluate(gc_ev, &labels_ev).await?;
let evaluated_gc = self.backend.evaluate(gc_ev, &labels_ev).await?;
let leader = leader.from_evaluated_circuit(evaluated_gc)?;
let (commit, leader) = leader.commit();
@@ -132,38 +121,28 @@ where
}
}
pub struct DualExFollower<G, E, S, R>
pub struct DualExFollower<B, S, R>
where
G: Generator,
E: Evaluator,
B: Generator + Evaluator,
S: WireLabelOTSend,
R: WireLabelOTReceive,
{
channel: GarbleChannel,
generator: G,
evaluator: E,
backend: B,
label_sender: S,
label_receiver: R,
}
impl<G, E, S, R> DualExFollower<G, E, S, R>
impl<B, S, R> DualExFollower<B, S, R>
where
G: Generator + Send,
E: Evaluator + Send,
B: Generator + Evaluator + Send,
S: WireLabelOTSend + Send,
R: WireLabelOTReceive + Send,
{
pub fn new(
channel: GarbleChannel,
generator: G,
evaluator: E,
label_sender: S,
label_receiver: R,
) -> Self {
pub fn new(channel: GarbleChannel, backend: B, label_sender: S, label_receiver: R) -> Self {
Self {
channel,
generator,
evaluator,
backend,
label_sender,
label_receiver,
}
@@ -171,10 +150,9 @@ where
}
#[async_trait]
impl<G, E, S, R> ExecuteWithLabels for DualExFollower<G, E, S, R>
impl<B, S, R> ExecuteWithLabels for DualExFollower<B, S, R>
where
G: Generator + Send,
E: Evaluator + Send,
B: Generator + Evaluator + Send,
S: WireLabelOTSend + Send,
R: WireLabelOTReceive + Send,
{
@@ -187,8 +165,8 @@ where
) -> Result<GarbledCircuit<Evaluated>, GCError> {
let follower = core::DualExFollower::new(circ.clone());
let full_gc = self
.generator
.generate(circ.clone(), delta, input_labels)
.backend
.generate(circ.clone(), delta, &input_labels)
.await?;
let (partial_gc, follower) = follower.from_full_circuit(inputs, full_gc)?;
@@ -217,7 +195,7 @@ where
let gc_ev = GarbledCircuit::<Partial>::from_msg(circ, msg)?;
let labels_ev = self.label_receiver.receive_labels(inputs.to_vec()).await?;
let evaluated_gc = self.evaluator.evaluate(gc_ev, &labels_ev).await?;
let evaluated_gc = self.backend.evaluate(gc_ev, &labels_ev).await?;
let follower = follower.from_evaluated_circuit(evaluated_gc)?;
let msg = expect_msg_or_err!(
@@ -247,10 +225,7 @@ where
#[cfg(feature = "mock")]
mod mock {
use super::*;
use crate::protocol::{
garble::mock::{MockEvaluator, MockGenerator},
ot::mock::mock_ot_pair,
};
use crate::protocol::{garble::backend::MockBackend, ot::mock::mock_ot_pair};
use utils_aio::duplex::DuplexChannel;
pub fn mock_dualex_pair() -> (impl ExecuteWithLabels, impl ExecuteWithLabels) {
@@ -260,15 +235,13 @@ mod mock {
let leader = DualExLeader::new(
Box::new(leader_channel),
MockGenerator,
MockEvaluator,
MockBackend,
leader_sender,
leader_receiver,
);
let follower = DualExFollower::new(
Box::new(follower_channel),
MockGenerator,
MockEvaluator,
MockBackend,
follower_sender,
follower_receiver,
);

View File

@@ -1,3 +1,4 @@
pub mod backend;
pub mod exec;
mod label;
@@ -28,6 +29,8 @@ pub enum GCError {
LabelOTError(#[from] label::WireLabelError),
#[error("Received unexpected message: {0:?}")]
Unexpected(GarbleMessage),
#[error("backend error")]
BackendError(String),
}
#[async_trait]
@@ -78,45 +81,3 @@ pub trait Execute: ExecuteWithLabels {
}
impl<T> Execute for T where T: ExecuteWithLabels {}
#[cfg(feature = "mock")]
mod mock {
use super::*;
use aes::{Aes128, NewBlockCipher};
pub struct MockGenerator;
pub struct MockEvaluator;
#[async_trait]
impl Generator for MockGenerator {
async fn generate(
&mut self,
circ: Arc<Circuit>,
delta: Delta,
input_labels: &[InputLabels<WireLabelPair>],
) -> Result<GarbledCircuit<Full>, GCError> {
let cipher = Aes128::new_from_slice(&[0u8; 16]).unwrap();
Ok(GarbledCircuit::generate(
&cipher,
circ,
delta,
input_labels,
)?)
}
}
#[async_trait]
impl Evaluator for MockEvaluator {
async fn evaluate(
&mut self,
circ: GarbledCircuit<Partial>,
input_labels: &[InputLabels<WireLabel>],
) -> Result<GarbledCircuit<Evaluated>, GCError> {
let cipher = Aes128::new_from_slice(&[0u8; 16]).unwrap();
Ok(circ.evaluate(&cipher, input_labels)?)
}
}
}
#[cfg(feature = "mock")]
pub use mock::*;