From 9dfbfd2d2db8d1f0b11e2ea388b2e634cb89ea0a Mon Sep 17 00:00:00 2001 From: sinuio <> Date: Mon, 4 Apr 2022 22:50:54 -0700 Subject: [PATCH] refactor ot aio --- mpc-aio/examples/{ot_ws.rs => ote_ws.rs} | 12 ++-- mpc-aio/src/garble/errors.rs | 4 +- mpc-aio/src/garble/evaluator.rs | 4 +- mpc-aio/src/garble/generator.rs | 4 +- mpc-aio/src/ot/base/mod.rs | 6 ++ mpc-aio/src/ot/base/receiver.rs | 71 ++++++++++++++++++++ mpc-aio/src/ot/base/sender.rs | 66 ++++++++++++++++++ mpc-aio/src/ot/errors.rs | 16 +++-- mpc-aio/src/ot/extension/mod.rs | 7 ++ mpc-aio/src/ot/extension/receiver.rs | 81 ++++++++++++++++++++++ mpc-aio/src/ot/{ => extension}/sender.rs | 60 ++++++++--------- mpc-aio/src/ot/mod.rs | 11 +-- mpc-aio/src/ot/receiver.rs | 85 ------------------------ mpc-core/benches/ot.rs | 2 +- mpc-core/examples/ot.rs | 6 +- mpc-core/src/ot/base/mod.rs | 10 +-- mpc-core/src/ot/base/receiver.rs | 15 +++-- mpc-core/src/ot/extension/receiver.rs | 21 +++--- mpc-core/src/ot/extension/sender.rs | 2 +- 19 files changed, 317 insertions(+), 166 deletions(-) rename mpc-aio/examples/{ot_ws.rs => ote_ws.rs} (80%) create mode 100644 mpc-aio/src/ot/base/mod.rs create mode 100644 mpc-aio/src/ot/base/receiver.rs create mode 100644 mpc-aio/src/ot/base/sender.rs create mode 100644 mpc-aio/src/ot/extension/mod.rs create mode 100644 mpc-aio/src/ot/extension/receiver.rs rename mpc-aio/src/ot/{ => extension}/sender.rs (53%) delete mode 100644 mpc-aio/src/ot/receiver.rs diff --git a/mpc-aio/examples/ot_ws.rs b/mpc-aio/examples/ote_ws.rs similarity index 80% rename from mpc-aio/examples/ot_ws.rs rename to mpc-aio/examples/ote_ws.rs index 8e789b62d..a77da7889 100644 --- a/mpc-aio/examples/ot_ws.rs +++ b/mpc-aio/examples/ote_ws.rs @@ -1,5 +1,5 @@ -use mpc_aio::ot::{OtReceive, OtReceiver, OtSend, OtSender}; -use mpc_core::ot::{ChaChaAesOtReceiver, ChaChaAesOtSender, OtMessage}; +use mpc_aio::ot::{ExtOTReceive, ExtOTSend, ExtReceiver, ExtSender, Message}; +use mpc_core::ot::{ExtReceiverCore, ExtSenderCore}; use mpc_core::proto; use mpc_core::Block; use tokio; @@ -24,10 +24,10 @@ async fn ot_receive(stream: UnixStream) { let stream = Framed::new( ws, - ProstCodecDelimited::::default(), + ProstCodecDelimited::::default(), ); - let mut receiver = OtReceiver::new(ChaChaAesOtReceiver::default(), stream); + let mut receiver = ExtReceiver::new(ExtReceiverCore::default(), stream); let choice = vec![false, false, true]; @@ -52,10 +52,10 @@ async fn ot_send(stream: UnixStream) { let stream = Framed::new( ws, - ProstCodecDelimited::::default(), + ProstCodecDelimited::::default(), ); - let mut sender = OtSender::new(ChaChaAesOtSender::default(), stream); + let mut sender = ExtSender::new(ExtSenderCore::default(), stream); let messages = [ [Block::new(0), Block::new(1)], diff --git a/mpc-aio/src/garble/errors.rs b/mpc-aio/src/garble/errors.rs index 9348495af..0617e8e4c 100644 --- a/mpc-aio/src/garble/errors.rs +++ b/mpc-aio/src/garble/errors.rs @@ -4,7 +4,7 @@ use mpc_core::garble::{ }; use thiserror::Error; -use crate::ot::OtError; +use crate::ot::OTError; #[derive(Debug, Error)] pub enum GarbleError { @@ -13,7 +13,7 @@ pub enum GarbleError { #[error("Encountered error during evaluation: {0}")] EvaluatorError(#[from] EvaluatorError), #[error("Encountered OT error: {0}")] - OtError(#[from] OtError), + OTError(#[from] OTError), #[error("Received unexpected message: {0:?}")] Unexpected(GarbleMessage), #[error("Encountered IO error: {0}")] diff --git a/mpc-aio/src/garble/evaluator.rs b/mpc-aio/src/garble/evaluator.rs index fcad7c9fc..552c33790 100644 --- a/mpc-aio/src/garble/evaluator.rs +++ b/mpc-aio/src/garble/evaluator.rs @@ -1,5 +1,5 @@ use super::GarbleError; -use crate::ot::receiver::OtReceive; +use crate::ot::OTReceive; use mpc_core::circuit::{Circuit, CircuitInput}; use mpc_core::garble::circuit::InputLabel; use mpc_core::garble::evaluator::GarbledCircuitEvaluator; @@ -29,7 +29,7 @@ where pub async fn evaluate( &mut self, - ot: &mut impl OtReceive, + ot: &mut impl OTReceive, circ: &Circuit, ev: &V, inputs: &Vec, diff --git a/mpc-aio/src/garble/generator.rs b/mpc-aio/src/garble/generator.rs index 62829f2f1..675c87033 100644 --- a/mpc-aio/src/garble/generator.rs +++ b/mpc-aio/src/garble/generator.rs @@ -1,5 +1,5 @@ use super::GarbleError; -use crate::ot::sender::OtSend; +use crate::ot::OTSend; use mpc_core::circuit::{Circuit, CircuitInput}; use mpc_core::garble::generator::GarbledCircuitGenerator; use mpc_core::garble::GarbleMessage; @@ -29,7 +29,7 @@ where pub async fn garble( &mut self, - ot: &mut impl OtSend, + ot: &mut impl OTSend, circ: &Circuit, gen: &G, inputs: &Vec, diff --git a/mpc-aio/src/ot/base/mod.rs b/mpc-aio/src/ot/base/mod.rs new file mode 100644 index 000000000..b746e048c --- /dev/null +++ b/mpc-aio/src/ot/base/mod.rs @@ -0,0 +1,6 @@ +pub mod receiver; +pub mod sender; + +pub use super::errors::*; +pub use receiver::{OTReceive, Receiver}; +pub use sender::{OTSend, Sender}; diff --git a/mpc-aio/src/ot/base/receiver.rs b/mpc-aio/src/ot/base/receiver.rs new file mode 100644 index 000000000..917ec2cf5 --- /dev/null +++ b/mpc-aio/src/ot/base/receiver.rs @@ -0,0 +1,71 @@ +use super::OTError; +use async_trait::async_trait; +use futures_util::{Sink, SinkExt, Stream, StreamExt}; +use mpc_core::ot::{Message, ReceiveCore}; +use mpc_core::Block; +use std::io::Error as IOError; +use std::io::ErrorKind; +use tracing::{instrument, trace}; + +pub struct Receiver { + ot: OT, + stream: S, +} + +#[async_trait] +pub trait OTReceive { + async fn receive(&mut self, choice: &[bool]) -> Result, OTError>; +} + +impl< + OT: ReceiveCore + Send, + S: Sink + Stream> + Send + Unpin, + E: std::fmt::Debug, + > Receiver +where + OTError: From<>::Error>, + OTError: From, +{ + pub fn new(ot: OT, stream: S) -> Self { + Self { ot, stream } + } +} + +#[async_trait] +impl< + OT: ReceiveCore + Send, + S: Sink + Stream> + Send + Unpin, + E: std::fmt::Debug, + > OTReceive for Receiver +where + OTError: From<>::Error>, + OTError: From, +{ + #[instrument(skip(self, choice))] + async fn receive(&mut self, choice: &[bool]) -> Result, OTError> { + let setup = match self.stream.next().await { + Some(Ok(Message::SenderSetup(m))) => m, + Some(Ok(m)) => return Err(OTError::Unexpected(m)), + Some(Err(e)) => return Err(e)?, + None => return Err(IOError::new(ErrorKind::UnexpectedEof, ""))?, + }; + trace!("Received SenderSetup"); + + let setup = self.ot.setup(choice, setup)?; + + trace!("Sending ReceiverSetup"); + self.stream.send(Message::ReceiverSetup(setup)).await?; + + let payload = match self.stream.next().await { + Some(Ok(Message::SenderPayload(m))) => m, + Some(Ok(m)) => return Err(OTError::Unexpected(m)), + Some(Err(e)) => return Err(e)?, + None => return Err(IOError::new(ErrorKind::UnexpectedEof, ""))?, + }; + trace!("Received SenderPayload"); + + let values = self.ot.receive(payload)?; + + Ok(values) + } +} diff --git a/mpc-aio/src/ot/base/sender.rs b/mpc-aio/src/ot/base/sender.rs new file mode 100644 index 000000000..98091b078 --- /dev/null +++ b/mpc-aio/src/ot/base/sender.rs @@ -0,0 +1,66 @@ +use super::OTError; +use async_trait::async_trait; +use futures_util::{Sink, SinkExt, Stream, StreamExt}; +use mpc_core::ot::{Message, SendCore}; +use mpc_core::Block; +use std::io::Error as IOError; +use std::io::ErrorKind; +use tracing::{instrument, trace}; + +pub struct Sender { + ot: OT, + stream: S, +} + +#[async_trait] +pub trait OTSend { + async fn send(&mut self, payload: &[[Block; 2]]) -> Result<(), OTError>; +} + +impl< + OT: SendCore + Send, + S: Sink + Stream> + Send + Unpin, + E: std::fmt::Debug, + > Sender +where + OTError: From<>::Error>, + OTError: From, +{ + pub fn new(ot: OT, stream: S) -> Self { + Self { ot, stream } + } +} + +#[async_trait] +impl< + OT: SendCore + Send, + S: Sink + Stream> + Send + Unpin, + E: std::fmt::Debug, + > OTSend for Sender +where + OTError: From<>::Error>, + OTError: From, +{ + #[instrument(skip(self, payload))] + async fn send(&mut self, payload: &[[Block; 2]]) -> Result<(), OTError> { + let setup = self.ot.setup(); + + trace!("Sending SenderSetup"); + self.stream.send(Message::SenderSetup(setup)).await?; + + let setup = match self.stream.next().await { + Some(Ok(Message::ReceiverSetup(m))) => m, + Some(Ok(m)) => return Err(OTError::Unexpected(m)), + Some(Err(e)) => return Err(e)?, + None => return Err(IOError::new(ErrorKind::UnexpectedEof, ""))?, + }; + trace!("Received ReceiverSetup"); + + let payload = self.ot.send(payload, setup)?; + + self.stream.send(Message::SenderPayload(payload)).await?; + trace!("Sending SenderPayload"); + + Ok(()) + } +} diff --git a/mpc-aio/src/ot/errors.rs b/mpc-aio/src/ot/errors.rs index e13d302a2..80fb3325b 100644 --- a/mpc-aio/src/ot/errors.rs +++ b/mpc-aio/src/ot/errors.rs @@ -1,16 +1,20 @@ -use mpc_core::ot::errors::{OtReceiverCoreError, OtSenderCoreError}; -use mpc_core::ot::OtMessage; +use mpc_core::ot::Message; +use mpc_core::ot::{ExtReceiverCoreError, ExtSenderCoreError, ReceiverCoreError, SenderCoreError}; use thiserror::Error; /// Errors that may occur when using AsyncOTSender #[derive(Debug, Error)] -pub enum OtError { +pub enum OTError { #[error("OT sender core error: {0}")] - OtSenderCoreError(#[from] OtSenderCoreError), + SenderCoreError(#[from] SenderCoreError), #[error("OT receiver core error: {0}")] - OtReceiverCoreError(#[from] OtReceiverCoreError), + ReceiverCoreError(#[from] ReceiverCoreError), + #[error("OT sender core error: {0}")] + ExtSenderCoreError(#[from] ExtSenderCoreError), + #[error("OT receiver core error: {0}")] + ExtReceiverCoreError(#[from] ExtReceiverCoreError), #[error("IO error: {0}")] IOError(#[from] std::io::Error), #[error("Received unexpected message: {0:?}")] - Unexpected(OtMessage), + Unexpected(Message), } diff --git a/mpc-aio/src/ot/extension/mod.rs b/mpc-aio/src/ot/extension/mod.rs new file mode 100644 index 000000000..e0d6bbefa --- /dev/null +++ b/mpc-aio/src/ot/extension/mod.rs @@ -0,0 +1,7 @@ +pub mod receiver; +pub mod sender; + +use super::errors::*; + +pub use receiver::{ExtOTReceive, ExtReceiver}; +pub use sender::{ExtOTSend, ExtSender}; diff --git a/mpc-aio/src/ot/extension/receiver.rs b/mpc-aio/src/ot/extension/receiver.rs new file mode 100644 index 000000000..3bb7af6a4 --- /dev/null +++ b/mpc-aio/src/ot/extension/receiver.rs @@ -0,0 +1,81 @@ +use super::OTError; +use async_trait::async_trait; +use futures_util::{Sink, SinkExt, Stream, StreamExt}; +use mpc_core::ot::{ExtReceiveCore, Message}; +use mpc_core::Block; +use std::io::Error as IOError; +use std::io::ErrorKind; +use tracing::{instrument, trace}; + +pub struct ExtReceiver { + ot: OT, + stream: S, +} + +#[async_trait] +pub trait ExtOTReceive { + async fn receive(&mut self, choice: &[bool]) -> Result, OTError>; +} + +impl< + OT: ExtReceiveCore + Send, + S: Sink + Stream> + Send + Unpin, + E: std::fmt::Debug, + > ExtReceiver +where + OTError: From<>::Error>, + OTError: From, +{ + pub fn new(ot: OT, stream: S) -> Self { + Self { ot, stream } + } +} + +#[async_trait] +impl< + OT: ExtReceiveCore + Send, + S: Sink + Stream> + Send + Unpin, + E: std::fmt::Debug, + > ExtOTReceive for ExtReceiver +where + OTError: From<>::Error>, + OTError: From, +{ + #[instrument(skip(self, choice))] + async fn receive(&mut self, choice: &[bool]) -> Result, OTError> { + let base_setup = self.ot.base_setup()?; + + trace!("Sending SenderSetup"); + self.stream.send(Message::SenderSetup(base_setup)).await?; + + let base_receiver_setup = match self.stream.next().await { + Some(Ok(Message::ReceiverSetup(m))) => m, + Some(Ok(m)) => return Err(OTError::Unexpected(m)), + Some(Err(e)) => return Err(e)?, + None => return Err(IOError::new(ErrorKind::UnexpectedEof, ""))?, + }; + trace!("Received ReceiverSetup"); + + let payload = self.ot.base_send(base_receiver_setup.try_into().unwrap())?; + + trace!("Sending SenderPayload"); + self.stream.send(Message::SenderPayload(payload)).await?; + + let setup = self.ot.extension_setup(choice)?; + + trace!("Sending ExtReceiverSetup"); + self.stream.send(Message::ExtReceiverSetup(setup)).await?; + + let payload = match self.stream.next().await { + Some(Ok(Message::ExtSenderPayload(m))) => m, + Some(Ok(m)) => return Err(OTError::Unexpected(m)), + Some(Err(e)) => return Err(e)?, + None => return Err(IOError::new(ErrorKind::UnexpectedEof, ""))?, + }; + trace!("Received ExtSenderPayload"); + + let values = self.ot.receive(choice, payload.try_into().unwrap())?; + + Ok(values) + } +} diff --git a/mpc-aio/src/ot/sender.rs b/mpc-aio/src/ot/extension/sender.rs similarity index 53% rename from mpc-aio/src/ot/sender.rs rename to mpc-aio/src/ot/extension/sender.rs index d90f8cbf8..1a89b4928 100644 --- a/mpc-aio/src/ot/sender.rs +++ b/mpc-aio/src/ot/extension/sender.rs @@ -1,30 +1,30 @@ -use super::errors::OtError; +use super::OTError; use async_trait::async_trait; use futures_util::{Sink, SinkExt, Stream, StreamExt}; -use mpc_core::ot::{OtMessage, OtSendCore}; +use mpc_core::ot::{ExtSendCore, Message}; use mpc_core::Block; use std::io::Error as IOError; use std::io::ErrorKind; use tracing::{instrument, trace}; -pub struct OtSender { +pub struct ExtSender { ot: OT, stream: S, } #[async_trait] -pub trait OtSend { - async fn send(&mut self, payload: &[[Block; 2]]) -> Result<(), OtError>; +pub trait ExtOTSend { + async fn send(&mut self, payload: &[[Block; 2]]) -> Result<(), OTError>; } impl< - OT: OtSendCore + Send, - S: Sink + Stream> + Send + Unpin, + OT: ExtSendCore + Send, + S: Sink + Stream> + Send + Unpin, E: std::fmt::Debug, - > OtSender + > ExtSender where - OtError: From<>::Error>, - OtError: From, + OTError: From<>::Error>, + OTError: From, { pub fn new(ot: OT, stream: S) -> Self { Self { ot, stream } @@ -33,55 +33,53 @@ where #[async_trait] impl< - OT: OtSendCore + Send, - S: Sink + Stream> + Send + Unpin, + OT: ExtSendCore + Send, + S: Sink + Stream> + Send + Unpin, E: std::fmt::Debug, - > OtSend for OtSender + > ExtOTSend for ExtSender where - OtError: From<>::Error>, - OtError: From, + OTError: From<>::Error>, + OTError: From, { #[instrument(skip(self, payload))] - async fn send(&mut self, payload: &[[Block; 2]]) -> Result<(), OtError> { + async fn send(&mut self, payload: &[[Block; 2]]) -> Result<(), OTError> { let base_sender_setup = match self.stream.next().await { - Some(Ok(OtMessage::BaseSenderSetup(m))) => m, - Some(Ok(m)) => return Err(OtError::Unexpected(m)), + Some(Ok(Message::SenderSetup(m))) => m, + Some(Ok(m)) => return Err(OTError::Unexpected(m)), Some(Err(e)) => return Err(e)?, None => return Err(IOError::new(ErrorKind::UnexpectedEof, ""))?, }; - trace!("Received BaseOtSenderSetup"); + trace!("Received SenderSetup"); let base_setup = self.ot.base_setup(base_sender_setup.try_into().unwrap())?; - trace!("Sending BaseOtReceiverSetup"); - self.stream - .send(OtMessage::BaseReceiverSetup(base_setup)) - .await?; + trace!("Sending ReceiverSetup"); + self.stream.send(Message::ReceiverSetup(base_setup)).await?; let base_payload = match self.stream.next().await { - Some(Ok(OtMessage::BaseSenderPayload(m))) => m, - Some(Ok(m)) => return Err(OtError::Unexpected(m)), + Some(Ok(Message::SenderPayload(m))) => m, + Some(Ok(m)) => return Err(OTError::Unexpected(m)), Some(Err(e)) => return Err(e)?, None => return Err(IOError::new(ErrorKind::UnexpectedEof, ""))?, }; - trace!("Received BaseOtSenderPayload"); + trace!("Received SenderPayload"); self.ot.base_receive(base_payload.try_into().unwrap())?; let extension_receiver_setup = match self.stream.next().await { - Some(Ok(OtMessage::ReceiverSetup(m))) => m, - Some(Ok(m)) => return Err(OtError::Unexpected(m)), + Some(Ok(Message::ExtReceiverSetup(m))) => m, + Some(Ok(m)) => return Err(OTError::Unexpected(m)), Some(Err(e)) => return Err(e)?, None => return Err(IOError::new(ErrorKind::UnexpectedEof, ""))?, }; - trace!("Received OtReceiverSetup"); + trace!("Received ExtReceiverSetup"); self.ot .extension_setup(extension_receiver_setup.try_into().unwrap())?; let payload = self.ot.send(payload)?; - self.stream.send(OtMessage::SenderPayload(payload)).await?; - trace!("Sending OtSenderPayload"); + self.stream.send(Message::ExtSenderPayload(payload)).await?; + trace!("Sending ExtSenderPayload"); Ok(()) } diff --git a/mpc-aio/src/ot/mod.rs b/mpc-aio/src/ot/mod.rs index d001426c2..282cc10b4 100644 --- a/mpc-aio/src/ot/mod.rs +++ b/mpc-aio/src/ot/mod.rs @@ -1,7 +1,8 @@ +pub mod base; pub mod errors; -pub mod receiver; -pub mod sender; +pub mod extension; -pub use errors::OtError; -pub use receiver::{OtReceive, OtReceiver}; -pub use sender::{OtSend, OtSender}; +pub use base::{OTReceive, OTSend, Receiver, Sender}; +pub use errors::OTError; +pub use extension::{ExtOTReceive, ExtOTSend, ExtReceiver, ExtSender}; +pub use mpc_core::ot::Message; diff --git a/mpc-aio/src/ot/receiver.rs b/mpc-aio/src/ot/receiver.rs deleted file mode 100644 index 62bdf9183..000000000 --- a/mpc-aio/src/ot/receiver.rs +++ /dev/null @@ -1,85 +0,0 @@ -use super::errors::OtError; -use async_trait::async_trait; -use futures_util::{Sink, SinkExt, Stream, StreamExt}; -use mpc_core::ot::{OtMessage, OtReceiveCore}; -use mpc_core::Block; -use std::io::Error as IOError; -use std::io::ErrorKind; -use tracing::{instrument, trace}; - -pub struct OtReceiver { - ot: OT, - stream: S, -} - -#[async_trait] -pub trait OtReceive { - async fn receive(&mut self, choice: &[bool]) -> Result, OtError>; -} - -impl< - OT: OtReceiveCore + Send, - S: Sink + Stream> + Send + Unpin, - E: std::fmt::Debug, - > OtReceiver -where - OtError: From<>::Error>, - OtError: From, -{ - pub fn new(ot: OT, stream: S) -> Self { - Self { ot, stream } - } -} - -#[async_trait] -impl< - OT: OtReceiveCore + Send, - S: Sink + Stream> + Send + Unpin, - E: std::fmt::Debug, - > OtReceive for OtReceiver -where - OtError: From<>::Error>, - OtError: From, -{ - #[instrument(skip(self, choice))] - async fn receive(&mut self, choice: &[bool]) -> Result, OtError> { - let base_setup = self.ot.base_setup()?; - - trace!("Sending BaseOtSenderSetup"); - self.stream - .send(OtMessage::BaseSenderSetup(base_setup)) - .await?; - - let base_receiver_setup = match self.stream.next().await { - Some(Ok(OtMessage::BaseReceiverSetup(m))) => m, - Some(Ok(m)) => return Err(OtError::Unexpected(m)), - Some(Err(e)) => return Err(e)?, - None => return Err(IOError::new(ErrorKind::UnexpectedEof, ""))?, - }; - trace!("Received BaseOtReceiverSetup"); - - let payload = self.ot.base_send(base_receiver_setup.try_into().unwrap())?; - - trace!("Sending BaseOtSenderPayload"); - self.stream - .send(OtMessage::BaseSenderPayload(payload)) - .await?; - - let setup = self.ot.extension_setup(choice)?; - - trace!("Sending OtReceiverSetup"); - self.stream.send(OtMessage::ReceiverSetup(setup)).await?; - - let payload = match self.stream.next().await { - Some(Ok(OtMessage::SenderPayload(m))) => m, - Some(Ok(m)) => return Err(OtError::Unexpected(m)), - Some(Err(e)) => return Err(e)?, - None => return Err(IOError::new(ErrorKind::UnexpectedEof, ""))?, - }; - trace!("Received OtSenderPayload"); - - let values = self.ot.receive(choice, payload.try_into().unwrap())?; - - Ok(values) - } -} diff --git a/mpc-core/benches/ot.rs b/mpc-core/benches/ot.rs index a6de50925..70ca51ffd 100644 --- a/mpc-core/benches/ot.rs +++ b/mpc-core/benches/ot.rs @@ -19,7 +19,7 @@ fn criterion_benchmark(c: &mut Criterion) { let receiver_setup = receiver.setup(&choice, sender_setup).unwrap(); let send = sender.send(&s_inputs, receiver_setup).unwrap(); - let receive = receiver.receive(&choice, send).unwrap(); + let receive = receiver.receive(send).unwrap(); black_box(receive); }); }); diff --git a/mpc-core/examples/ot.rs b/mpc-core/examples/ot.rs index 1d6a2d77f..9a576fb99 100644 --- a/mpc-core/examples/ot.rs +++ b/mpc-core/examples/ot.rs @@ -25,11 +25,11 @@ pub fn main() { println!("Sender inputs: {:?}", &inputs); - // First the receiver creates a setup message and passes it to sender + // First the sender creates a setup message and passes it to sender let mut sender = SenderCore::default(); let setup = sender.setup(); - // Sender takes receiver's setup and creates its own setup message + // Receiver takes sender's setup and creates its own setup message let mut receiver = ReceiverCore::default(); let setup = receiver.setup(&choice, setup).unwrap(); @@ -37,7 +37,7 @@ pub fn main() { let payload = sender.send(&inputs, setup).unwrap(); // Receiver takes the encrypted inputs and is able to decrypt according to their choice bits - let received = receiver.receive(&choice, payload).unwrap(); + let received = receiver.receive(payload).unwrap(); println!("Transferred messages: {:?}", received); } diff --git a/mpc-core/src/ot/base/mod.rs b/mpc-core/src/ot/base/mod.rs index 3aba70eab..20d6bc18b 100644 --- a/mpc-core/src/ot/base/mod.rs +++ b/mpc-core/src/ot/base/mod.rs @@ -28,11 +28,7 @@ pub trait ReceiveCore { sender_setup: SenderSetup, ) -> Result; - fn receive( - &mut self, - choice: &[bool], - payload: SenderPayload, - ) -> Result, ReceiverCoreError>; + fn receive(&mut self, payload: SenderPayload) -> Result, ReceiverCoreError>; } #[cfg(test)] @@ -80,7 +76,7 @@ pub mod tests { let receiver_setup = receiver.setup(choice, sender_setup.clone()).unwrap(); let sender_payload = sender.send(values, receiver_setup.clone()).unwrap(); - let receiver_values = receiver.receive(choice, sender_payload.clone()).unwrap(); + let receiver_values = receiver.receive(sender_payload.clone()).unwrap(); Data { sender_setup, @@ -113,7 +109,7 @@ pub mod tests { let receiver_setup = receiver.setup(&choice, sender_setup).unwrap(); let send = sender.send(&s_inputs, receiver_setup).unwrap(); - let receive = receiver.receive(&choice, send).unwrap(); + let receive = receiver.receive(send).unwrap(); assert_eq!(expected, receive); } } diff --git a/mpc-core/src/ot/base/receiver.rs b/mpc-core/src/ot/base/receiver.rs index b76c89286..bdc535061 100644 --- a/mpc-core/src/ot/base/receiver.rs +++ b/mpc-core/src/ot/base/receiver.rs @@ -19,6 +19,7 @@ pub enum State { pub struct ReceiverCore { rng: R, hashes: Option>, + choice: Option>, state: State, } @@ -38,6 +39,7 @@ impl ReceiverCore { Self { rng, hashes: None, + choice: None, state: State::Initialized, } } @@ -68,23 +70,22 @@ impl ReceiveCore for ReceiverCore { }) .unzip(); self.hashes = Some(hashes); - + self.choice = Some(Vec::from(choice)); self.state = State::Setup; Ok(ReceiverSetup { keys }) } - fn receive( - &mut self, - choice: &[bool], - payload: SenderPayload, - ) -> Result, ReceiverCoreError> { + fn receive(&mut self, payload: SenderPayload) -> Result, ReceiverCoreError> { if self.state < State::Setup { return Err(ReceiverCoreError::NotSetup); } let hashes = self.hashes.as_ref().unwrap(); - let values: Vec = choice + let values: Vec = self + .choice + .as_ref() + .unwrap() .iter() .zip(hashes) .zip(payload.encrypted_values.iter()) diff --git a/mpc-core/src/ot/extension/receiver.rs b/mpc-core/src/ot/extension/receiver.rs index de8350ee3..9ead6818e 100644 --- a/mpc-core/src/ot/extension/receiver.rs +++ b/mpc-core/src/ot/extension/receiver.rs @@ -21,11 +21,11 @@ pub enum State { Complete, } -pub struct ExtReceiverCore { +pub struct ExtReceiverCore { cipher: C, rng: R, state: State, - base: Box, + base: OT, seeds: Option>, rngs: Option>, table: Option>>, @@ -37,18 +37,20 @@ pub struct ExtReceiverSetup { pub table: Vec>, } -impl Default for ExtReceiverCore { +impl Default for ExtReceiverCore> { fn default() -> Self { Self::new( ChaCha12Rng::from_entropy(), Aes128::new(GenericArray::from_slice(&[0u8; 16])), - Box::new(SenderCore::default()), + SenderCore::default(), ) } } -impl + BlockEncrypt> ExtReceiverCore { - pub fn new(rng: R, cipher: C, ot: Box) -> Self { +impl + BlockEncrypt, OT: SendCore> + ExtReceiverCore +{ + pub fn new(rng: R, cipher: C, ot: OT) -> Self { Self { rng, cipher, @@ -82,8 +84,11 @@ impl + BlockEncrypt> ExtRece } } -impl + BlockEncrypt> - ExtReceiveCore for ExtReceiverCore +impl< + R: Rng + CryptoRng + SeedableRng, + C: BlockCipher + BlockEncrypt, + OT: SendCore, + > ExtReceiveCore for ExtReceiverCore { fn state(&self) -> State { self.state diff --git a/mpc-core/src/ot/extension/sender.rs b/mpc-core/src/ot/extension/sender.rs index 0084c3eb7..87f3ba54e 100644 --- a/mpc-core/src/ot/extension/sender.rs +++ b/mpc-core/src/ot/extension/sender.rs @@ -93,7 +93,7 @@ impl + BlockEncrypt, OT: ReceiveCore> ExtSendCor } fn base_receive(&mut self, payload: BaseSenderPayload) -> Result<(), ExtSenderCoreError> { - let receive = self.base.receive(&self.base_choice, payload)?; + let receive = self.base.receive(payload)?; self.set_seeds(receive); self.state = State::BaseSetup; Ok(())