mirror of
https://github.com/tlsnotary/tlsn.git
synced 2026-04-28 03:00:14 -04:00
refactor ot
This commit is contained in:
@@ -15,11 +15,11 @@ async fn ot_receive(stream: UnixStream) {
|
||||
|
||||
println!("Receiver: Websocket connected");
|
||||
|
||||
let choice = [false, false, true];
|
||||
let choice = vec![false, false, true];
|
||||
|
||||
println!("Receiver: Choices: {:?}", &choice);
|
||||
|
||||
let values = receiver.receive(&mut ws_stream, &choice).await.unwrap();
|
||||
let values = receiver.receive(&mut ws_stream, choice).await.unwrap();
|
||||
|
||||
println!("Receiver: Received: {:?}", values);
|
||||
}
|
||||
|
||||
@@ -17,15 +17,21 @@ use rand_chacha::ChaCha12Rng;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_tungstenite::{tungstenite::protocol::Message, WebSocketStream};
|
||||
|
||||
pub struct Generator<S> {
|
||||
pub struct Generator<S>
|
||||
where
|
||||
S: OtSend + Send,
|
||||
{
|
||||
ot: OtSender<S>,
|
||||
}
|
||||
|
||||
pub struct Evaluator<S> {
|
||||
pub struct Evaluator<S>
|
||||
where
|
||||
S: OtReceive + Send,
|
||||
{
|
||||
ot: OtReceiver<S>,
|
||||
}
|
||||
|
||||
impl<OT: OtSend> Generator<OT> {
|
||||
impl<OT: OtSend + Send> Generator<OT> {
|
||||
pub fn new(ot: OtSender<OT>) -> Self {
|
||||
Self { ot }
|
||||
}
|
||||
@@ -61,7 +67,7 @@ impl<OT: OtSend> Generator<OT> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<OT: OtReceive> Evaluator<OT> {
|
||||
impl<OT: OtReceive + Send> Evaluator<OT> {
|
||||
pub fn new(ot: OtReceiver<OT>) -> Self {
|
||||
Self { ot }
|
||||
}
|
||||
|
||||
@@ -1,69 +1,16 @@
|
||||
use mpc_core::ot::errors::{OtReceiverCoreError, OtSenderCoreError};
|
||||
use std::fmt::{self, Display, Formatter};
|
||||
use tokio::io::Error as IOError;
|
||||
use mpc_core::ot::OtMessage;
|
||||
use thiserror::Error;
|
||||
|
||||
/// Errors that may occur when using AsyncOTSender
|
||||
#[derive(Debug)]
|
||||
pub enum OtSenderError {
|
||||
/// Error originating from OTSender core component
|
||||
CoreError(OtSenderCoreError),
|
||||
/// Error originating from an IO Error
|
||||
IOError(IOError),
|
||||
/// Received invalid message
|
||||
MalformedMessage,
|
||||
}
|
||||
|
||||
impl Display for OtSenderError {
|
||||
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
|
||||
match self {
|
||||
Self::CoreError(e) => write!(f, "{}", e),
|
||||
Self::IOError(e) => write!(f, "{}", e),
|
||||
Self::MalformedMessage => "malformed message".fmt(f),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<OtSenderCoreError> for OtSenderError {
|
||||
fn from(e: OtSenderCoreError) -> Self {
|
||||
Self::CoreError(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<IOError> for OtSenderError {
|
||||
fn from(e: IOError) -> Self {
|
||||
Self::IOError(e)
|
||||
}
|
||||
}
|
||||
|
||||
/// Errors that may occur when using AsyncOtReceiver
|
||||
#[derive(Debug)]
|
||||
pub enum OtReceiverError {
|
||||
/// Error originating from OTSender core component
|
||||
CoreError(OtReceiverCoreError),
|
||||
/// Error originating from an IO Error
|
||||
IOError(IOError),
|
||||
/// Received invalid message
|
||||
MalformedMessage,
|
||||
}
|
||||
|
||||
impl Display for OtReceiverError {
|
||||
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
|
||||
match self {
|
||||
Self::CoreError(e) => write!(f, "{}", e),
|
||||
Self::IOError(e) => write!(f, "{}", e),
|
||||
Self::MalformedMessage => "invalid message".fmt(f),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<OtReceiverCoreError> for OtReceiverError {
|
||||
fn from(e: OtReceiverCoreError) -> Self {
|
||||
Self::CoreError(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<IOError> for OtReceiverError {
|
||||
fn from(e: IOError) -> Self {
|
||||
Self::IOError(e)
|
||||
}
|
||||
#[derive(Debug, Error)]
|
||||
pub enum OtError {
|
||||
#[error("OT sender core error: {0}")]
|
||||
OtSenderCoreError(#[from] OtSenderCoreError),
|
||||
#[error("OT receiver core error: {0}")]
|
||||
OtReceiverCoreError(#[from] OtReceiverCoreError),
|
||||
#[error("IO error: {0}")]
|
||||
IOError(#[from] std::io::Error),
|
||||
#[error("Received unexpected message: {0}")]
|
||||
Unexpected(OtMessage),
|
||||
}
|
||||
|
||||
@@ -1,136 +1,7 @@
|
||||
pub mod errors;
|
||||
pub mod receiver;
|
||||
pub mod sender;
|
||||
|
||||
use errors::*;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use mpc_core::ot;
|
||||
use mpc_core::proto;
|
||||
use mpc_core::Block;
|
||||
use prost::Message as ProtoMessage;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_tungstenite::{tungstenite::protocol::Message, WebSocketStream};
|
||||
|
||||
pub struct OtSender<OT> {
|
||||
ot: OT,
|
||||
}
|
||||
|
||||
pub struct OtReceiver<OT> {
|
||||
ot: OT,
|
||||
}
|
||||
|
||||
impl<OT: ot::OtSend> OtSender<OT> {
|
||||
pub fn new(ot: OT) -> Self {
|
||||
Self { ot }
|
||||
}
|
||||
|
||||
pub async fn send<S: AsyncWrite + AsyncRead + Unpin>(
|
||||
&mut self,
|
||||
stream: &mut WebSocketStream<S>,
|
||||
inputs: &[[Block; 2]],
|
||||
) -> Result<(), OtSenderError> {
|
||||
let base_sender_setup = match stream.next().await {
|
||||
Some(message) => {
|
||||
proto::BaseOtSenderSetup::decode(message.unwrap().into_data().as_slice())
|
||||
.expect("Expected BaseOtSenderSetup")
|
||||
}
|
||||
_ => return Err(OtSenderError::MalformedMessage),
|
||||
};
|
||||
|
||||
let base_setup = self.ot.base_setup(base_sender_setup.try_into().unwrap())?;
|
||||
|
||||
stream
|
||||
.send(Message::Binary(
|
||||
proto::BaseOtReceiverSetup::from(base_setup).encode_to_vec(),
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let base_payload = match stream.next().await {
|
||||
Some(message) => {
|
||||
proto::BaseOtSenderPayload::decode(message.unwrap().into_data().as_slice())
|
||||
.expect("Expected BaseOtSenderPayload")
|
||||
}
|
||||
_ => return Err(OtSenderError::MalformedMessage),
|
||||
};
|
||||
self.ot.base_receive(base_payload.try_into().unwrap())?;
|
||||
|
||||
let extension_receiver_setup = match stream.next().await {
|
||||
Some(message) => {
|
||||
proto::OtReceiverSetup::decode(message.unwrap().into_data().as_slice())
|
||||
.expect("Expected OtReceiverSetup")
|
||||
}
|
||||
_ => return Err(OtSenderError::MalformedMessage),
|
||||
};
|
||||
|
||||
self.ot
|
||||
.extension_setup(extension_receiver_setup.try_into().unwrap())?;
|
||||
let payload: ot::OtSenderPayload = self.ot.send(inputs)?;
|
||||
|
||||
stream
|
||||
.send(Message::Binary(
|
||||
proto::OtSenderPayload::from(payload).encode_to_vec(),
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<OT: ot::OtReceive> OtReceiver<OT> {
|
||||
pub fn new(ot: OT) -> Self {
|
||||
Self { ot }
|
||||
}
|
||||
|
||||
pub async fn receive<S: AsyncWrite + AsyncRead + Unpin>(
|
||||
&mut self,
|
||||
stream: &mut WebSocketStream<S>,
|
||||
choice: &[bool],
|
||||
) -> Result<Vec<Block>, OtReceiverError> {
|
||||
let base_setup = self.ot.base_setup()?;
|
||||
|
||||
stream
|
||||
.send(Message::Binary(
|
||||
proto::BaseOtSenderSetup::from(base_setup).encode_to_vec(),
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let base_receiver_setup = match stream.next().await {
|
||||
Some(message) => {
|
||||
proto::BaseOtReceiverSetup::decode(message.unwrap().into_data().as_slice())
|
||||
.expect("Expected BaseOtReceiverSetup")
|
||||
}
|
||||
_ => return Err(OtReceiverError::MalformedMessage),
|
||||
};
|
||||
|
||||
let payload = self.ot.base_send(base_receiver_setup.try_into().unwrap())?;
|
||||
|
||||
stream
|
||||
.send(Message::Binary(
|
||||
proto::BaseOtSenderPayload::from(payload).encode_to_vec(),
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let setup = self.ot.extension_setup(choice)?;
|
||||
|
||||
stream
|
||||
.send(Message::Binary(
|
||||
proto::OtReceiverSetup::from(setup).encode_to_vec(),
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let payload = match stream.next().await {
|
||||
Some(message) => {
|
||||
proto::OtSenderPayload::decode(message.unwrap().into_data().as_slice())
|
||||
.expect("Expected OtSenderPayload")
|
||||
}
|
||||
_ => return Err(OtReceiverError::MalformedMessage),
|
||||
};
|
||||
|
||||
let values = self.ot.receive(choice, payload.try_into().unwrap())?;
|
||||
|
||||
Ok(values)
|
||||
}
|
||||
}
|
||||
pub use errors::OtError;
|
||||
pub use receiver::OtReceiver;
|
||||
pub use sender::OtSender;
|
||||
|
||||
73
mpc-aio/src/ot/receiver.rs
Normal file
73
mpc-aio/src/ot/receiver.rs
Normal file
@@ -0,0 +1,73 @@
|
||||
use super::errors::OtError;
|
||||
use crate::twopc::TwoPCProtocol;
|
||||
use async_trait::async_trait;
|
||||
use futures_util::{Sink, SinkExt, Stream, StreamExt};
|
||||
use mpc_core::ot::{OtMessage, OtReceive};
|
||||
use mpc_core::Block;
|
||||
use std::io::Error as IOError;
|
||||
use std::io::ErrorKind;
|
||||
|
||||
pub struct OtReceiver<OT>
|
||||
where
|
||||
OT: OtReceive + Send,
|
||||
{
|
||||
ot: OT,
|
||||
}
|
||||
|
||||
impl<OT: OtReceive + Send> OtReceiver<OT> {
|
||||
pub fn new(ot: OT) -> Self {
|
||||
Self { ot }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<OT: OtReceive + Send> TwoPCProtocol<OtMessage> for OtReceiver<OT> {
|
||||
type Input = Vec<bool>;
|
||||
type Error = OtError;
|
||||
type Output = Vec<Block>;
|
||||
|
||||
async fn run<
|
||||
S: Sink<OtMessage> + Stream<Item = Result<OtMessage, E>> + Send + Unpin,
|
||||
E: std::fmt::Debug,
|
||||
>(
|
||||
&mut self,
|
||||
stream: &mut S,
|
||||
input: Self::Input,
|
||||
) -> Result<Self::Output, Self::Error>
|
||||
where
|
||||
Self::Error: From<<S as Sink<OtMessage>>::Error>,
|
||||
Self::Error: From<E>,
|
||||
{
|
||||
let base_setup = self.ot.base_setup()?;
|
||||
|
||||
stream.send(OtMessage::BaseSenderSetup(base_setup)).await?;
|
||||
|
||||
let base_receiver_setup = match stream.next().await {
|
||||
Some(Ok(OtMessage::BaseReceiverSetup(m))) => m,
|
||||
Some(Ok(m)) => return Err(OtError::Unexpected(m)),
|
||||
Some(Err(e)) => return Err(e)?,
|
||||
None => return Err(IOError::new(ErrorKind::UnexpectedEof, ""))?,
|
||||
};
|
||||
|
||||
let payload = self.ot.base_send(base_receiver_setup.try_into().unwrap())?;
|
||||
|
||||
stream.send(OtMessage::BaseSenderPayload(payload)).await?;
|
||||
|
||||
let setup = self.ot.extension_setup(input.as_slice())?;
|
||||
|
||||
stream.send(OtMessage::ReceiverSetup(setup)).await?;
|
||||
|
||||
let payload = match stream.next().await {
|
||||
Some(Ok(OtMessage::SenderPayload(m))) => m,
|
||||
Some(Ok(m)) => return Err(OtError::Unexpected(m)),
|
||||
Some(Err(e)) => return Err(e)?,
|
||||
None => return Err(IOError::new(ErrorKind::UnexpectedEof, ""))?,
|
||||
};
|
||||
|
||||
let values = self
|
||||
.ot
|
||||
.receive(input.as_slice(), payload.try_into().unwrap())?;
|
||||
|
||||
Ok(values)
|
||||
}
|
||||
}
|
||||
77
mpc-aio/src/ot/sender.rs
Normal file
77
mpc-aio/src/ot/sender.rs
Normal file
@@ -0,0 +1,77 @@
|
||||
use super::errors::OtError;
|
||||
use crate::twopc::TwoPCProtocol;
|
||||
use async_trait::async_trait;
|
||||
use futures_util::{Sink, SinkExt, Stream, StreamExt};
|
||||
use mpc_core::ot::{OtMessage, OtSend};
|
||||
use mpc_core::Block;
|
||||
use std::io::Error as IOError;
|
||||
use std::io::ErrorKind;
|
||||
|
||||
pub struct OtSender<OT>
|
||||
where
|
||||
OT: OtSend + Send,
|
||||
{
|
||||
ot: OT,
|
||||
}
|
||||
|
||||
impl<OT: OtSend + Send> OtSender<OT> {
|
||||
pub fn new(ot: OT) -> Self {
|
||||
Self { ot }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<OT: OtSend + Send> TwoPCProtocol<OtMessage> for OtSender<OT> {
|
||||
type Input = Vec<[Block; 2]>;
|
||||
type Error = OtError;
|
||||
type Output = ();
|
||||
|
||||
async fn run<
|
||||
S: Sink<OtMessage> + Stream<Item = Result<OtMessage, E>> + Send + Unpin,
|
||||
E: std::fmt::Debug,
|
||||
>(
|
||||
&mut self,
|
||||
stream: &mut S,
|
||||
input: Self::Input,
|
||||
) -> Result<Self::Output, Self::Error>
|
||||
where
|
||||
Self::Error: From<<S as Sink<OtMessage>>::Error>,
|
||||
Self::Error: From<E>,
|
||||
{
|
||||
let base_sender_setup = match stream.next().await {
|
||||
Some(Ok(OtMessage::BaseSenderSetup(m))) => m,
|
||||
Some(Ok(m)) => return Err(OtError::Unexpected(m)),
|
||||
Some(Err(e)) => return Err(e)?,
|
||||
None => return Err(IOError::new(ErrorKind::UnexpectedEof, ""))?,
|
||||
};
|
||||
|
||||
let base_setup = self.ot.base_setup(base_sender_setup.try_into().unwrap())?;
|
||||
|
||||
stream
|
||||
.send(OtMessage::BaseReceiverSetup(base_setup))
|
||||
.await?;
|
||||
|
||||
let base_payload = match stream.next().await {
|
||||
Some(Ok(OtMessage::BaseSenderPayload(m))) => m,
|
||||
Some(Ok(m)) => return Err(OtError::Unexpected(m)),
|
||||
Some(Err(e)) => return Err(e)?,
|
||||
None => return Err(IOError::new(ErrorKind::UnexpectedEof, ""))?,
|
||||
};
|
||||
self.ot.base_receive(base_payload.try_into().unwrap())?;
|
||||
|
||||
let extension_receiver_setup = match stream.next().await {
|
||||
Some(Ok(OtMessage::ReceiverSetup(m))) => m,
|
||||
Some(Ok(m)) => return Err(OtError::Unexpected(m)),
|
||||
Some(Err(e)) => return Err(e)?,
|
||||
None => return Err(IOError::new(ErrorKind::UnexpectedEof, ""))?,
|
||||
};
|
||||
|
||||
self.ot
|
||||
.extension_setup(extension_receiver_setup.try_into().unwrap())?;
|
||||
let payload = self.ot.send(&input)?;
|
||||
|
||||
stream.send(OtMessage::SenderPayload(payload)).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user