diff --git a/mpc-aio/examples/ot_ws.rs b/mpc-aio/examples/ot_ws.rs index 75edc1dad..eed224c81 100644 --- a/mpc-aio/examples/ot_ws.rs +++ b/mpc-aio/examples/ot_ws.rs @@ -15,11 +15,11 @@ async fn ot_receive(stream: UnixStream) { println!("Receiver: Websocket connected"); - let choice = [false, false, true]; + let choice = vec![false, false, true]; println!("Receiver: Choices: {:?}", &choice); - let values = receiver.receive(&mut ws_stream, &choice).await.unwrap(); + let values = receiver.receive(&mut ws_stream, choice).await.unwrap(); println!("Receiver: Received: {:?}", values); } diff --git a/mpc-aio/src/garble/mod.rs b/mpc-aio/src/garble/mod.rs index b5953b967..146ecabaf 100644 --- a/mpc-aio/src/garble/mod.rs +++ b/mpc-aio/src/garble/mod.rs @@ -17,15 +17,21 @@ use rand_chacha::ChaCha12Rng; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_tungstenite::{tungstenite::protocol::Message, WebSocketStream}; -pub struct Generator { +pub struct Generator +where + S: OtSend + Send, +{ ot: OtSender, } -pub struct Evaluator { +pub struct Evaluator +where + S: OtReceive + Send, +{ ot: OtReceiver, } -impl Generator { +impl Generator { pub fn new(ot: OtSender) -> Self { Self { ot } } @@ -61,7 +67,7 @@ impl Generator { } } -impl Evaluator { +impl Evaluator { pub fn new(ot: OtReceiver) -> Self { Self { ot } } diff --git a/mpc-aio/src/ot/errors.rs b/mpc-aio/src/ot/errors.rs index 1c576b378..217956e22 100644 --- a/mpc-aio/src/ot/errors.rs +++ b/mpc-aio/src/ot/errors.rs @@ -1,69 +1,16 @@ use mpc_core::ot::errors::{OtReceiverCoreError, OtSenderCoreError}; -use std::fmt::{self, Display, Formatter}; -use tokio::io::Error as IOError; +use mpc_core::ot::OtMessage; +use thiserror::Error; /// Errors that may occur when using AsyncOTSender -#[derive(Debug)] -pub enum OtSenderError { - /// Error originating from OTSender core component - CoreError(OtSenderCoreError), - /// Error originating from an IO Error - IOError(IOError), - /// Received invalid message - MalformedMessage, -} - -impl Display for OtSenderError { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - match self { - Self::CoreError(e) => write!(f, "{}", e), - Self::IOError(e) => write!(f, "{}", e), - Self::MalformedMessage => "malformed message".fmt(f), - } - } -} - -impl From for OtSenderError { - fn from(e: OtSenderCoreError) -> Self { - Self::CoreError(e) - } -} - -impl From for OtSenderError { - fn from(e: IOError) -> Self { - Self::IOError(e) - } -} - -/// Errors that may occur when using AsyncOtReceiver -#[derive(Debug)] -pub enum OtReceiverError { - /// Error originating from OTSender core component - CoreError(OtReceiverCoreError), - /// Error originating from an IO Error - IOError(IOError), - /// Received invalid message - MalformedMessage, -} - -impl Display for OtReceiverError { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - match self { - Self::CoreError(e) => write!(f, "{}", e), - Self::IOError(e) => write!(f, "{}", e), - Self::MalformedMessage => "invalid message".fmt(f), - } - } -} - -impl From for OtReceiverError { - fn from(e: OtReceiverCoreError) -> Self { - Self::CoreError(e) - } -} - -impl From for OtReceiverError { - fn from(e: IOError) -> Self { - Self::IOError(e) - } +#[derive(Debug, Error)] +pub enum OtError { + #[error("OT sender core error: {0}")] + OtSenderCoreError(#[from] OtSenderCoreError), + #[error("OT receiver core error: {0}")] + OtReceiverCoreError(#[from] OtReceiverCoreError), + #[error("IO error: {0}")] + IOError(#[from] std::io::Error), + #[error("Received unexpected message: {0}")] + Unexpected(OtMessage), } diff --git a/mpc-aio/src/ot/mod.rs b/mpc-aio/src/ot/mod.rs index 2f5bcceb6..852ad87ce 100644 --- a/mpc-aio/src/ot/mod.rs +++ b/mpc-aio/src/ot/mod.rs @@ -1,136 +1,7 @@ pub mod errors; +pub mod receiver; +pub mod sender; -use errors::*; -use futures_util::{SinkExt, StreamExt}; -use mpc_core::ot; -use mpc_core::proto; -use mpc_core::Block; -use prost::Message as ProtoMessage; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio_tungstenite::{tungstenite::protocol::Message, WebSocketStream}; - -pub struct OtSender { - ot: OT, -} - -pub struct OtReceiver { - ot: OT, -} - -impl OtSender { - pub fn new(ot: OT) -> Self { - Self { ot } - } - - pub async fn send( - &mut self, - stream: &mut WebSocketStream, - inputs: &[[Block; 2]], - ) -> Result<(), OtSenderError> { - let base_sender_setup = match stream.next().await { - Some(message) => { - proto::BaseOtSenderSetup::decode(message.unwrap().into_data().as_slice()) - .expect("Expected BaseOtSenderSetup") - } - _ => return Err(OtSenderError::MalformedMessage), - }; - - let base_setup = self.ot.base_setup(base_sender_setup.try_into().unwrap())?; - - stream - .send(Message::Binary( - proto::BaseOtReceiverSetup::from(base_setup).encode_to_vec(), - )) - .await - .unwrap(); - - let base_payload = match stream.next().await { - Some(message) => { - proto::BaseOtSenderPayload::decode(message.unwrap().into_data().as_slice()) - .expect("Expected BaseOtSenderPayload") - } - _ => return Err(OtSenderError::MalformedMessage), - }; - self.ot.base_receive(base_payload.try_into().unwrap())?; - - let extension_receiver_setup = match stream.next().await { - Some(message) => { - proto::OtReceiverSetup::decode(message.unwrap().into_data().as_slice()) - .expect("Expected OtReceiverSetup") - } - _ => return Err(OtSenderError::MalformedMessage), - }; - - self.ot - .extension_setup(extension_receiver_setup.try_into().unwrap())?; - let payload: ot::OtSenderPayload = self.ot.send(inputs)?; - - stream - .send(Message::Binary( - proto::OtSenderPayload::from(payload).encode_to_vec(), - )) - .await - .unwrap(); - - Ok(()) - } -} - -impl OtReceiver { - pub fn new(ot: OT) -> Self { - Self { ot } - } - - pub async fn receive( - &mut self, - stream: &mut WebSocketStream, - choice: &[bool], - ) -> Result, OtReceiverError> { - let base_setup = self.ot.base_setup()?; - - stream - .send(Message::Binary( - proto::BaseOtSenderSetup::from(base_setup).encode_to_vec(), - )) - .await - .unwrap(); - - let base_receiver_setup = match stream.next().await { - Some(message) => { - proto::BaseOtReceiverSetup::decode(message.unwrap().into_data().as_slice()) - .expect("Expected BaseOtReceiverSetup") - } - _ => return Err(OtReceiverError::MalformedMessage), - }; - - let payload = self.ot.base_send(base_receiver_setup.try_into().unwrap())?; - - stream - .send(Message::Binary( - proto::BaseOtSenderPayload::from(payload).encode_to_vec(), - )) - .await - .unwrap(); - - let setup = self.ot.extension_setup(choice)?; - - stream - .send(Message::Binary( - proto::OtReceiverSetup::from(setup).encode_to_vec(), - )) - .await - .unwrap(); - - let payload = match stream.next().await { - Some(message) => { - proto::OtSenderPayload::decode(message.unwrap().into_data().as_slice()) - .expect("Expected OtSenderPayload") - } - _ => return Err(OtReceiverError::MalformedMessage), - }; - - let values = self.ot.receive(choice, payload.try_into().unwrap())?; - - Ok(values) - } -} +pub use errors::OtError; +pub use receiver::OtReceiver; +pub use sender::OtSender; diff --git a/mpc-aio/src/ot/receiver.rs b/mpc-aio/src/ot/receiver.rs new file mode 100644 index 000000000..dd200ea73 --- /dev/null +++ b/mpc-aio/src/ot/receiver.rs @@ -0,0 +1,73 @@ +use super::errors::OtError; +use crate::twopc::TwoPCProtocol; +use async_trait::async_trait; +use futures_util::{Sink, SinkExt, Stream, StreamExt}; +use mpc_core::ot::{OtMessage, OtReceive}; +use mpc_core::Block; +use std::io::Error as IOError; +use std::io::ErrorKind; + +pub struct OtReceiver +where + OT: OtReceive + Send, +{ + ot: OT, +} + +impl OtReceiver { + pub fn new(ot: OT) -> Self { + Self { ot } + } +} + +#[async_trait] +impl TwoPCProtocol for OtReceiver { + type Input = Vec; + type Error = OtError; + type Output = Vec; + + async fn run< + S: Sink + Stream> + Send + Unpin, + E: std::fmt::Debug, + >( + &mut self, + stream: &mut S, + input: Self::Input, + ) -> Result + where + Self::Error: From<>::Error>, + Self::Error: From, + { + let base_setup = self.ot.base_setup()?; + + stream.send(OtMessage::BaseSenderSetup(base_setup)).await?; + + let base_receiver_setup = match stream.next().await { + Some(Ok(OtMessage::BaseReceiverSetup(m))) => m, + Some(Ok(m)) => return Err(OtError::Unexpected(m)), + Some(Err(e)) => return Err(e)?, + None => return Err(IOError::new(ErrorKind::UnexpectedEof, ""))?, + }; + + let payload = self.ot.base_send(base_receiver_setup.try_into().unwrap())?; + + stream.send(OtMessage::BaseSenderPayload(payload)).await?; + + let setup = self.ot.extension_setup(input.as_slice())?; + + stream.send(OtMessage::ReceiverSetup(setup)).await?; + + let payload = match stream.next().await { + Some(Ok(OtMessage::SenderPayload(m))) => m, + Some(Ok(m)) => return Err(OtError::Unexpected(m)), + Some(Err(e)) => return Err(e)?, + None => return Err(IOError::new(ErrorKind::UnexpectedEof, ""))?, + }; + + let values = self + .ot + .receive(input.as_slice(), payload.try_into().unwrap())?; + + Ok(values) + } +} diff --git a/mpc-aio/src/ot/sender.rs b/mpc-aio/src/ot/sender.rs new file mode 100644 index 000000000..21ff03501 --- /dev/null +++ b/mpc-aio/src/ot/sender.rs @@ -0,0 +1,77 @@ +use super::errors::OtError; +use crate::twopc::TwoPCProtocol; +use async_trait::async_trait; +use futures_util::{Sink, SinkExt, Stream, StreamExt}; +use mpc_core::ot::{OtMessage, OtSend}; +use mpc_core::Block; +use std::io::Error as IOError; +use std::io::ErrorKind; + +pub struct OtSender +where + OT: OtSend + Send, +{ + ot: OT, +} + +impl OtSender { + pub fn new(ot: OT) -> Self { + Self { ot } + } +} + +#[async_trait] +impl TwoPCProtocol for OtSender { + type Input = Vec<[Block; 2]>; + type Error = OtError; + type Output = (); + + async fn run< + S: Sink + Stream> + Send + Unpin, + E: std::fmt::Debug, + >( + &mut self, + stream: &mut S, + input: Self::Input, + ) -> Result + where + Self::Error: From<>::Error>, + Self::Error: From, + { + let base_sender_setup = match stream.next().await { + Some(Ok(OtMessage::BaseSenderSetup(m))) => m, + Some(Ok(m)) => return Err(OtError::Unexpected(m)), + Some(Err(e)) => return Err(e)?, + None => return Err(IOError::new(ErrorKind::UnexpectedEof, ""))?, + }; + + let base_setup = self.ot.base_setup(base_sender_setup.try_into().unwrap())?; + + stream + .send(OtMessage::BaseReceiverSetup(base_setup)) + .await?; + + let base_payload = match stream.next().await { + Some(Ok(OtMessage::BaseSenderPayload(m))) => m, + Some(Ok(m)) => return Err(OtError::Unexpected(m)), + Some(Err(e)) => return Err(e)?, + None => return Err(IOError::new(ErrorKind::UnexpectedEof, ""))?, + }; + self.ot.base_receive(base_payload.try_into().unwrap())?; + + let extension_receiver_setup = match stream.next().await { + Some(Ok(OtMessage::ReceiverSetup(m))) => m, + Some(Ok(m)) => return Err(OtError::Unexpected(m)), + Some(Err(e)) => return Err(e)?, + None => return Err(IOError::new(ErrorKind::UnexpectedEof, ""))?, + }; + + self.ot + .extension_setup(extension_receiver_setup.try_into().unwrap())?; + let payload = self.ot.send(&input)?; + + stream.send(OtMessage::SenderPayload(payload)).await?; + + Ok(()) + } +}