mirror of
https://github.com/tlsnotary/tlsn.git
synced 2026-04-28 03:00:14 -04:00
Rayon Garbler (#115)
* rayon garbler * comment RayonGarbler * rename some things, fix import * remove redundant ref * move and rename mock backend * rename error
This commit is contained in:
@@ -37,6 +37,7 @@ tokio = { workspace = true, features = [
|
||||
mockall = "0.11"
|
||||
async-stream = "0.3"
|
||||
aes = { workspace = true, optional = true }
|
||||
rayon = { workspace = true }
|
||||
|
||||
[target.'cfg(target_arch = "wasm32")'.dependencies]
|
||||
getrandom = { workspace = true, features = ["js"] }
|
||||
|
||||
52
mpc-aio/src/protocol/garble/backend/mod.rs
Normal file
52
mpc-aio/src/protocol/garble/backend/mod.rs
Normal file
@@ -0,0 +1,52 @@
|
||||
mod rayon;
|
||||
|
||||
pub use self::rayon::RayonBackend;
|
||||
|
||||
#[cfg(feature = "mock")]
|
||||
mod mock {
|
||||
use std::sync::Arc;
|
||||
|
||||
use aes::{Aes128, NewBlockCipher};
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::protocol::garble::{Evaluator, GCError, Generator};
|
||||
use mpc_circuits::Circuit;
|
||||
use mpc_core::garble::{
|
||||
Delta, Evaluated, Full, GarbledCircuit, InputLabels, Partial, WireLabel, WireLabelPair,
|
||||
};
|
||||
|
||||
pub struct MockBackend;
|
||||
|
||||
#[async_trait]
|
||||
impl Generator for MockBackend {
|
||||
async fn generate(
|
||||
&mut self,
|
||||
circ: Arc<Circuit>,
|
||||
delta: Delta,
|
||||
input_labels: &[InputLabels<WireLabelPair>],
|
||||
) -> Result<GarbledCircuit<Full>, GCError> {
|
||||
let cipher = Aes128::new_from_slice(&[0u8; 16]).unwrap();
|
||||
Ok(GarbledCircuit::generate(
|
||||
&cipher,
|
||||
circ,
|
||||
delta,
|
||||
input_labels,
|
||||
)?)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Evaluator for MockBackend {
|
||||
async fn evaluate(
|
||||
&mut self,
|
||||
circ: GarbledCircuit<Partial>,
|
||||
input_labels: &[InputLabels<WireLabel>],
|
||||
) -> Result<GarbledCircuit<Evaluated>, GCError> {
|
||||
let cipher = Aes128::new_from_slice(&[0u8; 16]).unwrap();
|
||||
Ok(circ.evaluate(&cipher, input_labels)?)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "mock")]
|
||||
pub use mock::MockBackend;
|
||||
89
mpc-aio/src/protocol/garble/backend/rayon.rs
Normal file
89
mpc-aio/src/protocol/garble/backend/rayon.rs
Normal file
@@ -0,0 +1,89 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use aes::{Aes128, NewBlockCipher};
|
||||
use async_trait::async_trait;
|
||||
use futures::channel::oneshot;
|
||||
|
||||
use mpc_circuits::Circuit;
|
||||
use mpc_core::garble::{
|
||||
Delta, Evaluated, Full, GarbledCircuit, InputLabels, Partial, WireLabel, WireLabelPair,
|
||||
};
|
||||
|
||||
use crate::protocol::garble::{Evaluator, GCError, Generator};
|
||||
|
||||
/// Garbler backend using Rayon to garble and evaluate circuits asynchronously and in parallel
|
||||
pub struct RayonBackend;
|
||||
|
||||
#[async_trait]
|
||||
impl Generator for RayonBackend {
|
||||
async fn generate(
|
||||
&mut self,
|
||||
circ: Arc<Circuit>,
|
||||
delta: Delta,
|
||||
input_labels: &[InputLabels<WireLabelPair>],
|
||||
) -> Result<GarbledCircuit<Full>, GCError> {
|
||||
let (sender, receiver) = oneshot::channel();
|
||||
let input_labels = input_labels.to_vec();
|
||||
rayon::spawn(move || {
|
||||
let cipher = Aes128::new_from_slice(&[0u8; 16]).unwrap();
|
||||
let gc = GarbledCircuit::generate(&cipher, circ, delta, &input_labels)
|
||||
.map_err(GCError::from);
|
||||
_ = sender.send(gc);
|
||||
});
|
||||
receiver
|
||||
.await
|
||||
.map_err(|_| GCError::BackendError("channel error".to_string()))?
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Evaluator for RayonBackend {
|
||||
async fn evaluate(
|
||||
&mut self,
|
||||
circ: GarbledCircuit<Partial>,
|
||||
input_labels: &[InputLabels<WireLabel>],
|
||||
) -> Result<GarbledCircuit<Evaluated>, GCError> {
|
||||
let (sender, receiver) = oneshot::channel();
|
||||
let input_labels = input_labels.to_vec();
|
||||
rayon::spawn(move || {
|
||||
let cipher = Aes128::new_from_slice(&[0u8; 16]).unwrap();
|
||||
let ev = circ.evaluate(&cipher, &input_labels).map_err(GCError::from);
|
||||
_ = sender.send(ev);
|
||||
});
|
||||
receiver
|
||||
.await
|
||||
.map_err(|_| GCError::BackendError("channel error".to_string()))?
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use mpc_circuits::ADDER_64;
|
||||
use rand::thread_rng;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_rayon_garbler() {
|
||||
let circ = Arc::new(Circuit::load_bytes(ADDER_64).unwrap());
|
||||
let (input_labels, delta) = InputLabels::generate(&mut thread_rng(), &circ, None);
|
||||
let gc = RayonBackend
|
||||
.generate(circ.clone(), delta, &input_labels)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let input_labels = vec![
|
||||
input_labels[0]
|
||||
.select(&circ.input(0).unwrap().to_value(0u64).unwrap())
|
||||
.unwrap(),
|
||||
input_labels[1]
|
||||
.select(&circ.input(1).unwrap().to_value(0u64).unwrap())
|
||||
.unwrap(),
|
||||
];
|
||||
|
||||
let _ = RayonBackend
|
||||
.evaluate(gc.to_evaluator(&[], true, false), &input_labels)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
@@ -20,38 +20,28 @@ use mpc_core::garble::{
|
||||
};
|
||||
use utils_aio::expect_msg_or_err;
|
||||
|
||||
pub struct DualExLeader<G, E, S, R>
|
||||
pub struct DualExLeader<B, S, R>
|
||||
where
|
||||
G: Generator,
|
||||
E: Evaluator,
|
||||
B: Generator + Evaluator,
|
||||
S: WireLabelOTSend,
|
||||
R: WireLabelOTReceive,
|
||||
{
|
||||
channel: GarbleChannel,
|
||||
generator: G,
|
||||
evaluator: E,
|
||||
backend: B,
|
||||
label_sender: S,
|
||||
label_receiver: R,
|
||||
}
|
||||
|
||||
impl<G, E, S, R> DualExLeader<G, E, S, R>
|
||||
impl<B, S, R> DualExLeader<B, S, R>
|
||||
where
|
||||
G: Generator + Send,
|
||||
E: Evaluator + Send,
|
||||
B: Generator + Evaluator + Send,
|
||||
S: WireLabelOTSend + Send,
|
||||
R: WireLabelOTReceive + Send,
|
||||
{
|
||||
pub fn new(
|
||||
channel: GarbleChannel,
|
||||
generator: G,
|
||||
evaluator: E,
|
||||
label_sender: S,
|
||||
label_receiver: R,
|
||||
) -> Self {
|
||||
pub fn new(channel: GarbleChannel, backend: B, label_sender: S, label_receiver: R) -> Self {
|
||||
Self {
|
||||
channel,
|
||||
generator,
|
||||
evaluator,
|
||||
backend,
|
||||
label_sender,
|
||||
label_receiver,
|
||||
}
|
||||
@@ -59,10 +49,9 @@ where
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<G, E, S, R> ExecuteWithLabels for DualExLeader<G, E, S, R>
|
||||
impl<B, S, R> ExecuteWithLabels for DualExLeader<B, S, R>
|
||||
where
|
||||
G: Generator + Send,
|
||||
E: Evaluator + Send,
|
||||
B: Generator + Evaluator + Send,
|
||||
S: WireLabelOTSend + Send,
|
||||
R: WireLabelOTReceive + Send,
|
||||
{
|
||||
@@ -75,8 +64,8 @@ where
|
||||
) -> Result<GarbledCircuit<Evaluated>, GCError> {
|
||||
let leader = core::DualExLeader::new(circ.clone());
|
||||
let full_gc = self
|
||||
.generator
|
||||
.generate(circ.clone(), delta, input_labels)
|
||||
.backend
|
||||
.generate(circ.clone(), delta, &input_labels)
|
||||
.await?;
|
||||
|
||||
let (partial_gc, leader) = leader.from_full_circuit(inputs, full_gc)?;
|
||||
@@ -106,7 +95,7 @@ where
|
||||
let gc_ev = GarbledCircuit::<Partial>::from_msg(circ, msg)?;
|
||||
let labels_ev = self.label_receiver.receive_labels(inputs.to_vec()).await?;
|
||||
|
||||
let evaluated_gc = self.evaluator.evaluate(gc_ev, &labels_ev).await?;
|
||||
let evaluated_gc = self.backend.evaluate(gc_ev, &labels_ev).await?;
|
||||
let leader = leader.from_evaluated_circuit(evaluated_gc)?;
|
||||
let (commit, leader) = leader.commit();
|
||||
|
||||
@@ -132,38 +121,28 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DualExFollower<G, E, S, R>
|
||||
pub struct DualExFollower<B, S, R>
|
||||
where
|
||||
G: Generator,
|
||||
E: Evaluator,
|
||||
B: Generator + Evaluator,
|
||||
S: WireLabelOTSend,
|
||||
R: WireLabelOTReceive,
|
||||
{
|
||||
channel: GarbleChannel,
|
||||
generator: G,
|
||||
evaluator: E,
|
||||
backend: B,
|
||||
label_sender: S,
|
||||
label_receiver: R,
|
||||
}
|
||||
|
||||
impl<G, E, S, R> DualExFollower<G, E, S, R>
|
||||
impl<B, S, R> DualExFollower<B, S, R>
|
||||
where
|
||||
G: Generator + Send,
|
||||
E: Evaluator + Send,
|
||||
B: Generator + Evaluator + Send,
|
||||
S: WireLabelOTSend + Send,
|
||||
R: WireLabelOTReceive + Send,
|
||||
{
|
||||
pub fn new(
|
||||
channel: GarbleChannel,
|
||||
generator: G,
|
||||
evaluator: E,
|
||||
label_sender: S,
|
||||
label_receiver: R,
|
||||
) -> Self {
|
||||
pub fn new(channel: GarbleChannel, backend: B, label_sender: S, label_receiver: R) -> Self {
|
||||
Self {
|
||||
channel,
|
||||
generator,
|
||||
evaluator,
|
||||
backend,
|
||||
label_sender,
|
||||
label_receiver,
|
||||
}
|
||||
@@ -171,10 +150,9 @@ where
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<G, E, S, R> ExecuteWithLabels for DualExFollower<G, E, S, R>
|
||||
impl<B, S, R> ExecuteWithLabels for DualExFollower<B, S, R>
|
||||
where
|
||||
G: Generator + Send,
|
||||
E: Evaluator + Send,
|
||||
B: Generator + Evaluator + Send,
|
||||
S: WireLabelOTSend + Send,
|
||||
R: WireLabelOTReceive + Send,
|
||||
{
|
||||
@@ -187,8 +165,8 @@ where
|
||||
) -> Result<GarbledCircuit<Evaluated>, GCError> {
|
||||
let follower = core::DualExFollower::new(circ.clone());
|
||||
let full_gc = self
|
||||
.generator
|
||||
.generate(circ.clone(), delta, input_labels)
|
||||
.backend
|
||||
.generate(circ.clone(), delta, &input_labels)
|
||||
.await?;
|
||||
|
||||
let (partial_gc, follower) = follower.from_full_circuit(inputs, full_gc)?;
|
||||
@@ -217,7 +195,7 @@ where
|
||||
let gc_ev = GarbledCircuit::<Partial>::from_msg(circ, msg)?;
|
||||
let labels_ev = self.label_receiver.receive_labels(inputs.to_vec()).await?;
|
||||
|
||||
let evaluated_gc = self.evaluator.evaluate(gc_ev, &labels_ev).await?;
|
||||
let evaluated_gc = self.backend.evaluate(gc_ev, &labels_ev).await?;
|
||||
let follower = follower.from_evaluated_circuit(evaluated_gc)?;
|
||||
|
||||
let msg = expect_msg_or_err!(
|
||||
@@ -247,10 +225,7 @@ where
|
||||
#[cfg(feature = "mock")]
|
||||
mod mock {
|
||||
use super::*;
|
||||
use crate::protocol::{
|
||||
garble::mock::{MockEvaluator, MockGenerator},
|
||||
ot::mock::mock_ot_pair,
|
||||
};
|
||||
use crate::protocol::{garble::backend::MockBackend, ot::mock::mock_ot_pair};
|
||||
use utils_aio::duplex::DuplexChannel;
|
||||
|
||||
pub fn mock_dualex_pair() -> (impl ExecuteWithLabels, impl ExecuteWithLabels) {
|
||||
@@ -260,15 +235,13 @@ mod mock {
|
||||
|
||||
let leader = DualExLeader::new(
|
||||
Box::new(leader_channel),
|
||||
MockGenerator,
|
||||
MockEvaluator,
|
||||
MockBackend,
|
||||
leader_sender,
|
||||
leader_receiver,
|
||||
);
|
||||
let follower = DualExFollower::new(
|
||||
Box::new(follower_channel),
|
||||
MockGenerator,
|
||||
MockEvaluator,
|
||||
MockBackend,
|
||||
follower_sender,
|
||||
follower_receiver,
|
||||
);
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod backend;
|
||||
pub mod exec;
|
||||
mod label;
|
||||
|
||||
@@ -28,6 +29,8 @@ pub enum GCError {
|
||||
LabelOTError(#[from] label::WireLabelError),
|
||||
#[error("Received unexpected message: {0:?}")]
|
||||
Unexpected(GarbleMessage),
|
||||
#[error("backend error")]
|
||||
BackendError(String),
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -78,45 +81,3 @@ pub trait Execute: ExecuteWithLabels {
|
||||
}
|
||||
|
||||
impl<T> Execute for T where T: ExecuteWithLabels {}
|
||||
|
||||
#[cfg(feature = "mock")]
|
||||
mod mock {
|
||||
use super::*;
|
||||
use aes::{Aes128, NewBlockCipher};
|
||||
|
||||
pub struct MockGenerator;
|
||||
pub struct MockEvaluator;
|
||||
|
||||
#[async_trait]
|
||||
impl Generator for MockGenerator {
|
||||
async fn generate(
|
||||
&mut self,
|
||||
circ: Arc<Circuit>,
|
||||
delta: Delta,
|
||||
input_labels: &[InputLabels<WireLabelPair>],
|
||||
) -> Result<GarbledCircuit<Full>, GCError> {
|
||||
let cipher = Aes128::new_from_slice(&[0u8; 16]).unwrap();
|
||||
Ok(GarbledCircuit::generate(
|
||||
&cipher,
|
||||
circ,
|
||||
delta,
|
||||
input_labels,
|
||||
)?)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Evaluator for MockEvaluator {
|
||||
async fn evaluate(
|
||||
&mut self,
|
||||
circ: GarbledCircuit<Partial>,
|
||||
input_labels: &[InputLabels<WireLabel>],
|
||||
) -> Result<GarbledCircuit<Evaluated>, GCError> {
|
||||
let cipher = Aes128::new_from_slice(&[0u8; 16]).unwrap();
|
||||
Ok(circ.evaluate(&cipher, input_labels)?)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "mock")]
|
||||
pub use mock::*;
|
||||
|
||||
Reference in New Issue
Block a user