diff --git a/mpc-aio/examples/ot_ws.rs b/mpc-aio/examples/ot_ws.rs index f4bf6fdac..274fcd1b3 100644 --- a/mpc-aio/examples/ot_ws.rs +++ b/mpc-aio/examples/ot_ws.rs @@ -4,7 +4,6 @@ use futures::{AsyncRead, AsyncWrite}; use mpc_aio::ot::{ ExtOTReceive, ExtOTSend, ExtReceiver, ExtSender, OTReceive, OTSend, Receiver, Sender, }; -use mpc_core::ot::{DhOtSender, ExtReceiverCore, ExtSenderCore, ReceiverCore}; use mpc_core::Block; use rand::{thread_rng, Rng}; use tokio; @@ -61,10 +60,10 @@ async fn receive( info!("Choosing {:?}", choice); let values = if extended { - let mut receiver = ExtReceiver::new(ExtReceiverCore::new(choice.len()), stream); + let mut receiver = ExtReceiver::new(stream, choice.len()); receiver.receive(&choice).await.unwrap() } else { - let mut receiver = Receiver::new(ReceiverCore::new(choice.len()), stream); + let mut receiver = Receiver::new(stream); receiver.receive(&choice).await.unwrap() }; @@ -80,10 +79,10 @@ async fn send( let stream = WsStream::new(ws); if extended { - let mut sender = ExtSender::new(ExtSenderCore::new(values.len()), stream); + let mut sender = ExtSender::new(stream, values.len()); let _ = sender.send(&values).await; } else { - let mut sender = Sender::new(DhOtSender::new(values.len()), stream); + let mut sender = Sender::new(stream); let _ = sender.send(&values).await; } diff --git a/mpc-aio/src/ot/base/receiver.rs b/mpc-aio/src/ot/base/receiver.rs index 688139ccc..8eb9675bf 100644 --- a/mpc-aio/src/ot/base/receiver.rs +++ b/mpc-aio/src/ot/base/receiver.rs @@ -1,9 +1,10 @@ use super::OTError; use async_trait::async_trait; use futures_util::{SinkExt, StreamExt}; -use mpc_core::ot::{Message, ReceiveCore}; +use mpc_core::ot::{DhOtReceiver, Message}; use mpc_core::proto::ot::Message as ProtoMessage; use mpc_core::Block; +use rand::thread_rng; use std::io::Error as IOError; use std::io::ErrorKind; use tokio::io::{AsyncRead, AsyncWrite}; @@ -11,8 +12,8 @@ use tokio_util::codec::Framed; use tracing::{instrument, trace}; use utils_aio::codec::ProstCodecDelimited; -pub struct Receiver { - ot: OT, +pub struct Receiver { + ot: DhOtReceiver, stream: Framed>, } @@ -21,14 +22,13 @@ pub trait OTReceive { async fn receive(&mut self, choice: &[bool]) -> Result, OTError>; } -impl Receiver +impl Receiver where - OT: ReceiveCore + Send, S: AsyncRead + AsyncWrite + Send + Unpin, { - pub fn new(ot: OT, stream: S) -> Self { + pub fn new(stream: S) -> Self { Self { - ot, + ot: DhOtReceiver::default(), stream: Framed::new( stream, ProstCodecDelimited::::default(), @@ -38,28 +38,27 @@ where } #[async_trait] -impl OTReceive for Receiver +impl OTReceive for Receiver where - OT: ReceiveCore + Send, S: AsyncRead + AsyncWrite + Send + Unpin, { #[instrument(skip(self, choice))] async fn receive(&mut self, choice: &[bool]) -> Result, OTError> { let setup = match self.stream.next().await { - Some(Ok(Message::SenderSetup(m))) => m, + Some(Ok(Message::BaseSenderSetup(m))) => m, Some(Ok(m)) => return Err(OTError::Unexpected(m)), Some(Err(e)) => return Err(e)?, None => return Err(IOError::new(ErrorKind::UnexpectedEof, ""))?, }; trace!("Received SenderSetup: {:?}", &setup); - let setup = self.ot.setup(choice, setup)?; + let setup = self.ot.setup(&mut thread_rng(), choice, setup)?; trace!("Sending ReceiverSetup: {:?}", &setup); - self.stream.send(Message::ReceiverSetup(setup)).await?; + self.stream.send(Message::BaseReceiverSetup(setup)).await?; let payload = match self.stream.next().await { - Some(Ok(Message::SenderOutput(m))) => m, + Some(Ok(Message::BaseSenderPayload(m))) => m, Some(Ok(m)) => return Err(OTError::Unexpected(m)), Some(Err(e)) => return Err(e)?, None => return Err(IOError::new(ErrorKind::UnexpectedEof, ""))?, diff --git a/mpc-aio/src/ot/base/sender.rs b/mpc-aio/src/ot/base/sender.rs index 9d6617fd6..7e06d74b3 100644 --- a/mpc-aio/src/ot/base/sender.rs +++ b/mpc-aio/src/ot/base/sender.rs @@ -1,9 +1,10 @@ use super::OTError; use async_trait::async_trait; use futures_util::{SinkExt, StreamExt}; -use mpc_core::ot::{Message, SendCore}; +use mpc_core::ot::{DhOtSender, Message}; use mpc_core::proto::ot::Message as ProtoMessage; use mpc_core::Block; +use rand::thread_rng; use std::io::Error as IOError; use std::io::ErrorKind; use tokio::io::{AsyncRead, AsyncWrite}; @@ -11,8 +12,8 @@ use tokio_util::codec::Framed; use tracing::{instrument, trace}; use utils_aio::codec::ProstCodecDelimited; -pub struct Sender { - ot: OT, +pub struct Sender { + ot: DhOtSender, stream: Framed>, } @@ -21,14 +22,13 @@ pub trait OTSend { async fn send(&mut self, payload: &[[Block; 2]]) -> Result<(), OTError>; } -impl Sender +impl Sender where - OT: SendCore + Send, S: AsyncRead + AsyncWrite + Send + Unpin, { - pub fn new(ot: OT, stream: S) -> Self { + pub fn new(stream: S) -> Self { Self { - ot, + ot: DhOtSender::default(), stream: Framed::new( stream, ProstCodecDelimited::::default(), @@ -38,20 +38,19 @@ where } #[async_trait] -impl OTSend for Sender +impl OTSend for Sender where - OT: SendCore + Send, S: AsyncRead + AsyncWrite + Send + Unpin, { #[instrument(skip(self, payload))] async fn send(&mut self, payload: &[[Block; 2]]) -> Result<(), OTError> { - let setup = self.ot.setup(); + let setup = self.ot.setup(&mut thread_rng())?; trace!("Sending SenderSetup: {:?}", &setup); - self.stream.send(Message::SenderSetup(setup)).await?; + self.stream.send(Message::BaseSenderSetup(setup)).await?; let setup = match self.stream.next().await { - Some(Ok(Message::ReceiverSetup(m))) => m, + Some(Ok(Message::BaseReceiverSetup(m))) => m, Some(Ok(m)) => return Err(OTError::Unexpected(m)), Some(Err(e)) => return Err(e)?, None => return Err(IOError::new(ErrorKind::UnexpectedEof, ""))?, @@ -61,7 +60,9 @@ where let payload = self.ot.send(payload, setup)?; trace!("Sending SenderPayload: {:?}", &payload); - self.stream.send(Message::SenderOutput(payload)).await?; + self.stream + .send(Message::BaseSenderPayload(payload)) + .await?; Ok(()) } diff --git a/mpc-aio/src/ot/extension/receiver.rs b/mpc-aio/src/ot/extension/receiver.rs index a9b169eae..e6ef53e7c 100644 --- a/mpc-aio/src/ot/extension/receiver.rs +++ b/mpc-aio/src/ot/extension/receiver.rs @@ -1,7 +1,7 @@ use super::OTError; use async_trait::async_trait; use futures_util::{SinkExt, StreamExt}; -use mpc_core::ot::{ExtStandardReceiveCore, Message}; +use mpc_core::ot::{Kos15Receiver, Message}; use mpc_core::proto::ot::Message as ProtoMessage; use mpc_core::Block; use std::io::Error as IOError; @@ -11,8 +11,8 @@ use tokio_util::codec::Framed; use tracing::{instrument, trace}; use utils_aio::codec::ProstCodecDelimited; -pub struct ExtReceiver { - ot: OT, +pub struct ExtReceiver { + ot: Kos15Receiver, stream: Framed>, } @@ -21,14 +21,13 @@ pub trait ExtOTReceive { async fn receive(&mut self, choice: &[bool]) -> Result, OTError>; } -impl ExtReceiver +impl ExtReceiver where - OT: ExtStandardReceiveCore + Send, S: AsyncRead + AsyncWrite + Send + Unpin, { - pub fn new(ot: OT, stream: S) -> Self { + pub fn new(stream: S, count: usize) -> Self { Self { - ot, + ot: Kos15Receiver::new(count), stream: Framed::new( stream, ProstCodecDelimited::::default(), @@ -38,22 +37,21 @@ where } #[async_trait] -impl ExtOTReceive for ExtReceiver +impl ExtOTReceive for ExtReceiver where - OT: ExtStandardReceiveCore + Send, S: AsyncRead + AsyncWrite + Send + Unpin, { #[instrument(skip(self, choice))] async fn receive(&mut self, choice: &[bool]) -> Result, OTError> { let base_setup = self.ot.base_setup()?; - trace!("Sending BaseSenderSetup"); + trace!("Sending BaseSenderSetupWrapper"); self.stream - .send(Message::BaseSenderSetup(base_setup)) + .send(Message::BaseSenderSetupWrapper(base_setup)) .await?; let base_receiver_setup = match self.stream.next().await { - Some(Ok(Message::BaseReceiverSetup(m))) => m, + Some(Ok(Message::BaseReceiverSetupWrapper(m))) => m, Some(Ok(m)) => return Err(OTError::Unexpected(m)), Some(Err(e)) => return Err(e)?, None => return Err(IOError::new(ErrorKind::UnexpectedEof, ""))?, @@ -62,9 +60,9 @@ where let payload = self.ot.base_send(base_receiver_setup.try_into().unwrap())?; - trace!("Sending BaseSenderPayload"); + trace!("Sending BaseSenderPayloadWrapper"); self.stream - .send(Message::BaseSenderPayload(payload)) + .send(Message::BaseSenderPayloadWrapper(payload)) .await?; let setup = self.ot.extension_setup(choice)?; diff --git a/mpc-aio/src/ot/extension/sender.rs b/mpc-aio/src/ot/extension/sender.rs index e441b01ba..301308d11 100644 --- a/mpc-aio/src/ot/extension/sender.rs +++ b/mpc-aio/src/ot/extension/sender.rs @@ -1,7 +1,7 @@ use super::OTError; use async_trait::async_trait; use futures_util::{SinkExt, StreamExt}; -use mpc_core::ot::{ExtStandardSendCore, Message}; +use mpc_core::ot::{Kos15Sender, Message}; use mpc_core::proto::ot::Message as ProtoMessage; use mpc_core::Block; use std::io::Error as IOError; @@ -11,8 +11,8 @@ use tokio_util::codec::Framed; use tracing::{instrument, trace}; use utils_aio::codec::ProstCodecDelimited; -pub struct ExtSender { - ot: OT, +pub struct ExtSender { + ot: Kos15Sender, stream: Framed>, } @@ -21,14 +21,13 @@ pub trait ExtOTSend { async fn send(&mut self, payload: &[[Block; 2]]) -> Result<(), OTError>; } -impl ExtSender +impl ExtSender where - OT: ExtStandardSendCore + Send, S: AsyncRead + AsyncWrite + Send + Unpin, { - pub fn new(ot: OT, stream: S) -> Self { + pub fn new(stream: S, count: usize) -> Self { Self { - ot, + ot: Kos15Sender::new(count), stream: Framed::new( stream, ProstCodecDelimited::::default(), @@ -38,15 +37,14 @@ where } #[async_trait] -impl ExtOTSend for ExtSender +impl ExtOTSend for ExtSender where - OT: ExtStandardSendCore + Send, S: AsyncRead + AsyncWrite + Send + Unpin, { #[instrument(skip(self, payload))] async fn send(&mut self, payload: &[[Block; 2]]) -> Result<(), OTError> { let base_sender_setup = match self.stream.next().await { - Some(Ok(Message::BaseSenderSetup(m))) => m, + Some(Ok(Message::BaseSenderSetupWrapper(m))) => m, Some(Ok(m)) => return Err(OTError::Unexpected(m)), Some(Err(e)) => return Err(e)?, None => return Err(IOError::new(ErrorKind::UnexpectedEof, ""))?, @@ -57,11 +55,11 @@ where trace!("Sending ReceiverSetup"); self.stream - .send(Message::BaseReceiverSetup(base_setup)) + .send(Message::BaseReceiverSetupWrapper(base_setup)) .await?; let base_payload = match self.stream.next().await { - Some(Ok(Message::BaseSenderPayload(m))) => m, + Some(Ok(Message::BaseSenderPayloadWrapper(m))) => m, Some(Ok(m)) => return Err(OTError::Unexpected(m)), Some(Err(e)) => return Err(e)?, None => return Err(IOError::new(ErrorKind::UnexpectedEof, ""))?,