refactor ot

This commit is contained in:
sinuio
2022-03-31 21:19:40 -07:00
parent 615b8ed46e
commit 005187fac5
6 changed files with 179 additions and 205 deletions

View File

@@ -15,11 +15,11 @@ async fn ot_receive(stream: UnixStream) {
println!("Receiver: Websocket connected");
let choice = [false, false, true];
let choice = vec![false, false, true];
println!("Receiver: Choices: {:?}", &choice);
let values = receiver.receive(&mut ws_stream, &choice).await.unwrap();
let values = receiver.receive(&mut ws_stream, choice).await.unwrap();
println!("Receiver: Received: {:?}", values);
}

View File

@@ -17,15 +17,21 @@ use rand_chacha::ChaCha12Rng;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_tungstenite::{tungstenite::protocol::Message, WebSocketStream};
pub struct Generator<S> {
pub struct Generator<S>
where
S: OtSend + Send,
{
ot: OtSender<S>,
}
pub struct Evaluator<S> {
pub struct Evaluator<S>
where
S: OtReceive + Send,
{
ot: OtReceiver<S>,
}
impl<OT: OtSend> Generator<OT> {
impl<OT: OtSend + Send> Generator<OT> {
pub fn new(ot: OtSender<OT>) -> Self {
Self { ot }
}
@@ -61,7 +67,7 @@ impl<OT: OtSend> Generator<OT> {
}
}
impl<OT: OtReceive> Evaluator<OT> {
impl<OT: OtReceive + Send> Evaluator<OT> {
pub fn new(ot: OtReceiver<OT>) -> Self {
Self { ot }
}

View File

@@ -1,69 +1,16 @@
use mpc_core::ot::errors::{OtReceiverCoreError, OtSenderCoreError};
use std::fmt::{self, Display, Formatter};
use tokio::io::Error as IOError;
use mpc_core::ot::OtMessage;
use thiserror::Error;
/// Errors that may occur when using AsyncOTSender
#[derive(Debug)]
pub enum OtSenderError {
/// Error originating from OTSender core component
CoreError(OtSenderCoreError),
/// Error originating from an IO Error
IOError(IOError),
/// Received invalid message
MalformedMessage,
}
impl Display for OtSenderError {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
match self {
Self::CoreError(e) => write!(f, "{}", e),
Self::IOError(e) => write!(f, "{}", e),
Self::MalformedMessage => "malformed message".fmt(f),
}
}
}
impl From<OtSenderCoreError> for OtSenderError {
fn from(e: OtSenderCoreError) -> Self {
Self::CoreError(e)
}
}
impl From<IOError> for OtSenderError {
fn from(e: IOError) -> Self {
Self::IOError(e)
}
}
/// Errors that may occur when using AsyncOtReceiver
#[derive(Debug)]
pub enum OtReceiverError {
/// Error originating from OTSender core component
CoreError(OtReceiverCoreError),
/// Error originating from an IO Error
IOError(IOError),
/// Received invalid message
MalformedMessage,
}
impl Display for OtReceiverError {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
match self {
Self::CoreError(e) => write!(f, "{}", e),
Self::IOError(e) => write!(f, "{}", e),
Self::MalformedMessage => "invalid message".fmt(f),
}
}
}
impl From<OtReceiverCoreError> for OtReceiverError {
fn from(e: OtReceiverCoreError) -> Self {
Self::CoreError(e)
}
}
impl From<IOError> for OtReceiverError {
fn from(e: IOError) -> Self {
Self::IOError(e)
}
#[derive(Debug, Error)]
pub enum OtError {
#[error("OT sender core error: {0}")]
OtSenderCoreError(#[from] OtSenderCoreError),
#[error("OT receiver core error: {0}")]
OtReceiverCoreError(#[from] OtReceiverCoreError),
#[error("IO error: {0}")]
IOError(#[from] std::io::Error),
#[error("Received unexpected message: {0}")]
Unexpected(OtMessage),
}

View File

@@ -1,136 +1,7 @@
pub mod errors;
pub mod receiver;
pub mod sender;
use errors::*;
use futures_util::{SinkExt, StreamExt};
use mpc_core::ot;
use mpc_core::proto;
use mpc_core::Block;
use prost::Message as ProtoMessage;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_tungstenite::{tungstenite::protocol::Message, WebSocketStream};
pub struct OtSender<OT> {
ot: OT,
}
pub struct OtReceiver<OT> {
ot: OT,
}
impl<OT: ot::OtSend> OtSender<OT> {
pub fn new(ot: OT) -> Self {
Self { ot }
}
pub async fn send<S: AsyncWrite + AsyncRead + Unpin>(
&mut self,
stream: &mut WebSocketStream<S>,
inputs: &[[Block; 2]],
) -> Result<(), OtSenderError> {
let base_sender_setup = match stream.next().await {
Some(message) => {
proto::BaseOtSenderSetup::decode(message.unwrap().into_data().as_slice())
.expect("Expected BaseOtSenderSetup")
}
_ => return Err(OtSenderError::MalformedMessage),
};
let base_setup = self.ot.base_setup(base_sender_setup.try_into().unwrap())?;
stream
.send(Message::Binary(
proto::BaseOtReceiverSetup::from(base_setup).encode_to_vec(),
))
.await
.unwrap();
let base_payload = match stream.next().await {
Some(message) => {
proto::BaseOtSenderPayload::decode(message.unwrap().into_data().as_slice())
.expect("Expected BaseOtSenderPayload")
}
_ => return Err(OtSenderError::MalformedMessage),
};
self.ot.base_receive(base_payload.try_into().unwrap())?;
let extension_receiver_setup = match stream.next().await {
Some(message) => {
proto::OtReceiverSetup::decode(message.unwrap().into_data().as_slice())
.expect("Expected OtReceiverSetup")
}
_ => return Err(OtSenderError::MalformedMessage),
};
self.ot
.extension_setup(extension_receiver_setup.try_into().unwrap())?;
let payload: ot::OtSenderPayload = self.ot.send(inputs)?;
stream
.send(Message::Binary(
proto::OtSenderPayload::from(payload).encode_to_vec(),
))
.await
.unwrap();
Ok(())
}
}
impl<OT: ot::OtReceive> OtReceiver<OT> {
pub fn new(ot: OT) -> Self {
Self { ot }
}
pub async fn receive<S: AsyncWrite + AsyncRead + Unpin>(
&mut self,
stream: &mut WebSocketStream<S>,
choice: &[bool],
) -> Result<Vec<Block>, OtReceiverError> {
let base_setup = self.ot.base_setup()?;
stream
.send(Message::Binary(
proto::BaseOtSenderSetup::from(base_setup).encode_to_vec(),
))
.await
.unwrap();
let base_receiver_setup = match stream.next().await {
Some(message) => {
proto::BaseOtReceiverSetup::decode(message.unwrap().into_data().as_slice())
.expect("Expected BaseOtReceiverSetup")
}
_ => return Err(OtReceiverError::MalformedMessage),
};
let payload = self.ot.base_send(base_receiver_setup.try_into().unwrap())?;
stream
.send(Message::Binary(
proto::BaseOtSenderPayload::from(payload).encode_to_vec(),
))
.await
.unwrap();
let setup = self.ot.extension_setup(choice)?;
stream
.send(Message::Binary(
proto::OtReceiverSetup::from(setup).encode_to_vec(),
))
.await
.unwrap();
let payload = match stream.next().await {
Some(message) => {
proto::OtSenderPayload::decode(message.unwrap().into_data().as_slice())
.expect("Expected OtSenderPayload")
}
_ => return Err(OtReceiverError::MalformedMessage),
};
let values = self.ot.receive(choice, payload.try_into().unwrap())?;
Ok(values)
}
}
pub use errors::OtError;
pub use receiver::OtReceiver;
pub use sender::OtSender;

View File

@@ -0,0 +1,73 @@
use super::errors::OtError;
use crate::twopc::TwoPCProtocol;
use async_trait::async_trait;
use futures_util::{Sink, SinkExt, Stream, StreamExt};
use mpc_core::ot::{OtMessage, OtReceive};
use mpc_core::Block;
use std::io::Error as IOError;
use std::io::ErrorKind;
pub struct OtReceiver<OT>
where
OT: OtReceive + Send,
{
ot: OT,
}
impl<OT: OtReceive + Send> OtReceiver<OT> {
pub fn new(ot: OT) -> Self {
Self { ot }
}
}
#[async_trait]
impl<OT: OtReceive + Send> TwoPCProtocol<OtMessage> for OtReceiver<OT> {
type Input = Vec<bool>;
type Error = OtError;
type Output = Vec<Block>;
async fn run<
S: Sink<OtMessage> + Stream<Item = Result<OtMessage, E>> + Send + Unpin,
E: std::fmt::Debug,
>(
&mut self,
stream: &mut S,
input: Self::Input,
) -> Result<Self::Output, Self::Error>
where
Self::Error: From<<S as Sink<OtMessage>>::Error>,
Self::Error: From<E>,
{
let base_setup = self.ot.base_setup()?;
stream.send(OtMessage::BaseSenderSetup(base_setup)).await?;
let base_receiver_setup = match stream.next().await {
Some(Ok(OtMessage::BaseReceiverSetup(m))) => m,
Some(Ok(m)) => return Err(OtError::Unexpected(m)),
Some(Err(e)) => return Err(e)?,
None => return Err(IOError::new(ErrorKind::UnexpectedEof, ""))?,
};
let payload = self.ot.base_send(base_receiver_setup.try_into().unwrap())?;
stream.send(OtMessage::BaseSenderPayload(payload)).await?;
let setup = self.ot.extension_setup(input.as_slice())?;
stream.send(OtMessage::ReceiverSetup(setup)).await?;
let payload = match stream.next().await {
Some(Ok(OtMessage::SenderPayload(m))) => m,
Some(Ok(m)) => return Err(OtError::Unexpected(m)),
Some(Err(e)) => return Err(e)?,
None => return Err(IOError::new(ErrorKind::UnexpectedEof, ""))?,
};
let values = self
.ot
.receive(input.as_slice(), payload.try_into().unwrap())?;
Ok(values)
}
}

77
mpc-aio/src/ot/sender.rs Normal file
View File

@@ -0,0 +1,77 @@
use super::errors::OtError;
use crate::twopc::TwoPCProtocol;
use async_trait::async_trait;
use futures_util::{Sink, SinkExt, Stream, StreamExt};
use mpc_core::ot::{OtMessage, OtSend};
use mpc_core::Block;
use std::io::Error as IOError;
use std::io::ErrorKind;
pub struct OtSender<OT>
where
OT: OtSend + Send,
{
ot: OT,
}
impl<OT: OtSend + Send> OtSender<OT> {
pub fn new(ot: OT) -> Self {
Self { ot }
}
}
#[async_trait]
impl<OT: OtSend + Send> TwoPCProtocol<OtMessage> for OtSender<OT> {
type Input = Vec<[Block; 2]>;
type Error = OtError;
type Output = ();
async fn run<
S: Sink<OtMessage> + Stream<Item = Result<OtMessage, E>> + Send + Unpin,
E: std::fmt::Debug,
>(
&mut self,
stream: &mut S,
input: Self::Input,
) -> Result<Self::Output, Self::Error>
where
Self::Error: From<<S as Sink<OtMessage>>::Error>,
Self::Error: From<E>,
{
let base_sender_setup = match stream.next().await {
Some(Ok(OtMessage::BaseSenderSetup(m))) => m,
Some(Ok(m)) => return Err(OtError::Unexpected(m)),
Some(Err(e)) => return Err(e)?,
None => return Err(IOError::new(ErrorKind::UnexpectedEof, ""))?,
};
let base_setup = self.ot.base_setup(base_sender_setup.try_into().unwrap())?;
stream
.send(OtMessage::BaseReceiverSetup(base_setup))
.await?;
let base_payload = match stream.next().await {
Some(Ok(OtMessage::BaseSenderPayload(m))) => m,
Some(Ok(m)) => return Err(OtError::Unexpected(m)),
Some(Err(e)) => return Err(e)?,
None => return Err(IOError::new(ErrorKind::UnexpectedEof, ""))?,
};
self.ot.base_receive(base_payload.try_into().unwrap())?;
let extension_receiver_setup = match stream.next().await {
Some(Ok(OtMessage::ReceiverSetup(m))) => m,
Some(Ok(m)) => return Err(OtError::Unexpected(m)),
Some(Err(e)) => return Err(e)?,
None => return Err(IOError::new(ErrorKind::UnexpectedEof, ""))?,
};
self.ot
.extension_setup(extension_receiver_setup.try_into().unwrap())?;
let payload = self.ot.send(&input)?;
stream.send(OtMessage::SenderPayload(payload)).await?;
Ok(())
}
}