mirror of
https://github.com/tlsnotary/tlsn.git
synced 2026-04-28 03:00:14 -04:00
Refactor OT traits to be generics (#137)
* switch OT traits to generic instead of associated types * blanket OT impls for wire labels * rustfmt * make ObliviousVerify generic * OT verify * update share-conversion-aio * fixes for rebase * rebase fixes * fmt fix * mock ot Co-authored-by: sinu.eth <>
This commit is contained in:
@@ -178,7 +178,7 @@ mod test {
|
||||
|
||||
let send = async { sender.send(data).await.unwrap() };
|
||||
|
||||
let receive = async { receiver.receive(&choices).await.unwrap() };
|
||||
let receive = async { receiver.receive(choices).await.unwrap() };
|
||||
|
||||
let (_, received) = futures::join!(send, receive);
|
||||
|
||||
@@ -287,7 +287,7 @@ mod test {
|
||||
let choices = vec![false; split_size];
|
||||
|
||||
let (send, receive) =
|
||||
tokio::join!(sender.send(messages.clone()), receiver.receive(&choices));
|
||||
tokio::join!(sender.send(messages.clone()), receiver.receive(choices));
|
||||
send.unwrap();
|
||||
_ = receive.unwrap();
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ use crate::{config::ReceiverFactoryConfig, GetReceiver, Setup};
|
||||
use mpc_core::{
|
||||
msgs::ot::{OTFactoryMessage, OTMessage, Split},
|
||||
ot::r_state::RandSetup,
|
||||
Block,
|
||||
};
|
||||
use utils_aio::{mux::MuxChannelControl, Channel};
|
||||
|
||||
@@ -241,7 +242,7 @@ impl<T, S> ReceiverFactoryControl<T>
|
||||
where
|
||||
T: Handler<Setup, Return = Result<(), OTFactoryError>>
|
||||
+ Handler<GetReceiver, Return = oneshot::Receiver<Result<S, OTFactoryError>>>,
|
||||
S: ObliviousReceive,
|
||||
S: ObliviousReceive<bool, Block>,
|
||||
{
|
||||
pub fn new(addr: Address<T>) -> Self {
|
||||
Self(addr)
|
||||
@@ -272,7 +273,7 @@ where
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T, U> OTReceiverFactory for ReceiverFactoryControl<KOSReceiverFactory<T, U>>
|
||||
impl<T, U> OTReceiverFactory<bool, Block> for ReceiverFactoryControl<KOSReceiverFactory<T, U>>
|
||||
where
|
||||
T: Channel<OTFactoryMessage, Error = std::io::Error> + Send + 'static,
|
||||
U: MuxChannelControl<OTMessage> + Send + 'static,
|
||||
|
||||
@@ -11,6 +11,7 @@ use crate::{config::SenderFactoryConfig, GetSender, Setup, Verify};
|
||||
use mpc_core::{
|
||||
msgs::ot::{OTFactoryMessage, OTMessage, Split},
|
||||
ot::s_state::RandSetup,
|
||||
Block,
|
||||
};
|
||||
use utils_aio::{mux::MuxChannelControl, Channel};
|
||||
|
||||
@@ -197,7 +198,7 @@ impl<T, S> SenderFactoryControl<T>
|
||||
where
|
||||
T: Handler<Setup, Return = Result<(), OTFactoryError>>
|
||||
+ Handler<GetSender, Return = Result<S, OTFactoryError>>,
|
||||
S: ObliviousSend,
|
||||
S: ObliviousSend<[Block; 2]>,
|
||||
{
|
||||
pub fn new(addr: Address<T>) -> Self {
|
||||
Self(addr)
|
||||
@@ -226,7 +227,7 @@ where
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T, U> OTSenderFactory for SenderFactoryControl<KOSSenderFactory<T, U>>
|
||||
impl<T, U> OTSenderFactory<[Block; 2]> for SenderFactoryControl<KOSSenderFactory<T, U>>
|
||||
where
|
||||
T: Channel<OTFactoryMessage, Error = std::io::Error> + Send + 'static,
|
||||
U: MuxChannelControl<OTMessage> + Send + 'static,
|
||||
|
||||
@@ -8,23 +8,23 @@
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::protocol::garble::{
|
||||
label::{WireLabelOTReceive, WireLabelOTSend},
|
||||
Evaluator, ExecuteWithLabels, GCError, GarbleChannel, GarbleMessage, Generator,
|
||||
use crate::protocol::{
|
||||
garble::{Evaluator, ExecuteWithLabels, GCError, GarbleChannel, GarbleMessage, Generator},
|
||||
ot::{ObliviousReceive, ObliviousSend},
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use mpc_circuits::{Circuit, InputValue};
|
||||
use mpc_core::garble::{
|
||||
exec::dual as core, gc_state, Delta, GarbledCircuit, InputLabels, WireLabelPair,
|
||||
exec::dual as core, gc_state, Delta, GarbledCircuit, InputLabels, WireLabel, WireLabelPair,
|
||||
};
|
||||
use utils_aio::expect_msg_or_err;
|
||||
|
||||
pub struct DualExLeader<B, S, R>
|
||||
where
|
||||
B: Generator + Evaluator,
|
||||
S: WireLabelOTSend,
|
||||
R: WireLabelOTReceive,
|
||||
S: ObliviousSend<InputLabels<WireLabelPair>>,
|
||||
R: ObliviousReceive<InputValue, InputLabels<WireLabel>>,
|
||||
{
|
||||
channel: GarbleChannel,
|
||||
backend: B,
|
||||
@@ -35,8 +35,8 @@ where
|
||||
impl<B, S, R> DualExLeader<B, S, R>
|
||||
where
|
||||
B: Generator + Evaluator + Send,
|
||||
S: WireLabelOTSend + Send,
|
||||
R: WireLabelOTReceive + Send,
|
||||
S: ObliviousSend<InputLabels<WireLabelPair>> + Send,
|
||||
R: ObliviousReceive<InputValue, InputLabels<WireLabel>> + Send,
|
||||
{
|
||||
pub fn new(channel: GarbleChannel, backend: B, label_sender: S, label_receiver: R) -> Self {
|
||||
Self {
|
||||
@@ -52,8 +52,8 @@ where
|
||||
impl<B, S, R> ExecuteWithLabels for DualExLeader<B, S, R>
|
||||
where
|
||||
B: Generator + Evaluator + Send,
|
||||
S: WireLabelOTSend + Send,
|
||||
R: WireLabelOTReceive + Send,
|
||||
S: ObliviousSend<InputLabels<WireLabelPair>> + Send,
|
||||
R: ObliviousReceive<InputValue, InputLabels<WireLabel>> + Send,
|
||||
{
|
||||
async fn execute_with_labels(
|
||||
&mut self,
|
||||
@@ -84,7 +84,7 @@ where
|
||||
.cloned()
|
||||
.collect::<Vec<InputLabels<WireLabelPair>>>();
|
||||
|
||||
self.label_sender.send_labels(follower_labels).await?;
|
||||
self.label_sender.send(follower_labels).await?;
|
||||
|
||||
let msg = expect_msg_or_err!(
|
||||
self.channel.next().await,
|
||||
@@ -93,7 +93,7 @@ where
|
||||
)?;
|
||||
|
||||
let gc_ev = GarbledCircuit::<gc_state::Partial>::from_unchecked(circ, msg.into())?;
|
||||
let labels_ev = self.label_receiver.receive_labels(inputs.to_vec()).await?;
|
||||
let labels_ev = self.label_receiver.receive(inputs.to_vec()).await?;
|
||||
|
||||
let evaluated_gc = self.backend.evaluate(gc_ev, &labels_ev).await?;
|
||||
let leader = leader.from_evaluated_circuit(evaluated_gc)?;
|
||||
@@ -124,8 +124,8 @@ where
|
||||
pub struct DualExFollower<B, S, R>
|
||||
where
|
||||
B: Generator + Evaluator,
|
||||
S: WireLabelOTSend,
|
||||
R: WireLabelOTReceive,
|
||||
S: ObliviousSend<InputLabels<WireLabelPair>>,
|
||||
R: ObliviousReceive<InputValue, InputLabels<WireLabel>>,
|
||||
{
|
||||
channel: GarbleChannel,
|
||||
backend: B,
|
||||
@@ -136,8 +136,8 @@ where
|
||||
impl<B, S, R> DualExFollower<B, S, R>
|
||||
where
|
||||
B: Generator + Evaluator + Send,
|
||||
S: WireLabelOTSend + Send,
|
||||
R: WireLabelOTReceive + Send,
|
||||
S: ObliviousSend<InputLabels<WireLabelPair>> + Send,
|
||||
R: ObliviousReceive<InputValue, InputLabels<WireLabel>> + Send,
|
||||
{
|
||||
pub fn new(channel: GarbleChannel, backend: B, label_sender: S, label_receiver: R) -> Self {
|
||||
Self {
|
||||
@@ -153,8 +153,8 @@ where
|
||||
impl<B, S, R> ExecuteWithLabels for DualExFollower<B, S, R>
|
||||
where
|
||||
B: Generator + Evaluator + Send,
|
||||
S: WireLabelOTSend + Send,
|
||||
R: WireLabelOTReceive + Send,
|
||||
S: ObliviousSend<InputLabels<WireLabelPair>> + Send,
|
||||
R: ObliviousReceive<InputValue, InputLabels<WireLabel>> + Send,
|
||||
{
|
||||
async fn execute_with_labels(
|
||||
&mut self,
|
||||
@@ -185,7 +185,7 @@ where
|
||||
.cloned()
|
||||
.collect::<Vec<InputLabels<WireLabelPair>>>();
|
||||
|
||||
self.label_sender.send_labels(leader_labels).await?;
|
||||
self.label_sender.send(leader_labels).await?;
|
||||
|
||||
let msg = expect_msg_or_err!(
|
||||
self.channel.next().await,
|
||||
@@ -194,7 +194,7 @@ where
|
||||
)?;
|
||||
|
||||
let gc_ev = GarbledCircuit::<gc_state::Partial>::from_unchecked(circ, msg.into())?;
|
||||
let labels_ev = self.label_receiver.receive_labels(inputs.to_vec()).await?;
|
||||
let labels_ev = self.label_receiver.receive(inputs.to_vec()).await?;
|
||||
|
||||
let evaluated_gc = self.backend.evaluate(gc_ev, &labels_ev).await?;
|
||||
let follower = follower.from_evaluated_circuit(evaluated_gc)?;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::protocol::ot;
|
||||
use crate::protocol::ot::{OTError, ObliviousReceive, ObliviousSend, ObliviousVerify};
|
||||
use async_trait::async_trait;
|
||||
use mpc_circuits::InputValue;
|
||||
use mpc_core::{
|
||||
@@ -9,58 +9,46 @@ use mpc_core::{
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum WireLabelError {
|
||||
#[error("error occurred during OT")]
|
||||
OTError(#[from] ot::OTError),
|
||||
OTError(#[from] OTError),
|
||||
#[error("core error")]
|
||||
CoreError(#[from] mpc_core::garble::Error),
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait WireLabelOTSend: ot::ObliviousSend<Inputs = Vec<[Block; 2]>> {
|
||||
/// Sends labels using oblivious transfer
|
||||
///
|
||||
/// Inputs must be provided sorted ascending by input id
|
||||
async fn send_labels(
|
||||
&mut self,
|
||||
inputs: Vec<InputLabels<WireLabelPair>>,
|
||||
) -> Result<(), WireLabelError> {
|
||||
impl<T> ObliviousSend<InputLabels<WireLabelPair>> for T
|
||||
where
|
||||
T: Send + ObliviousSend<[Block; 2]>,
|
||||
{
|
||||
async fn send(&mut self, inputs: Vec<InputLabels<WireLabelPair>>) -> Result<(), OTError> {
|
||||
self.send(
|
||||
inputs
|
||||
.into_iter()
|
||||
.map(|labels| {
|
||||
labels
|
||||
.as_ref()
|
||||
.iter()
|
||||
.map(|pair| [*pair.low(), *pair.high()])
|
||||
.collect::<Vec<[Block; 2]>>()
|
||||
})
|
||||
.map(|labels| labels.to_blocks())
|
||||
.flatten()
|
||||
.collect::<Vec<[Block; 2]>>(),
|
||||
)
|
||||
.await
|
||||
.map_err(WireLabelError::from)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> WireLabelOTSend for T where T: ot::ObliviousSend<Inputs = Vec<[Block; 2]>> {}
|
||||
|
||||
#[async_trait]
|
||||
pub trait WireLabelOTReceive: ot::ObliviousReceive<Choice = bool, Outputs = Vec<Block>> {
|
||||
/// Receives labels using oblivious transfer
|
||||
///
|
||||
/// Inputs must be provided sorted ascending by input id
|
||||
async fn receive_labels(
|
||||
impl<T> ObliviousReceive<InputValue, InputLabels<WireLabel>> for T
|
||||
where
|
||||
T: Send + ObliviousReceive<bool, Block>,
|
||||
{
|
||||
async fn receive(
|
||||
&mut self,
|
||||
inputs: Vec<InputValue>,
|
||||
) -> Result<Vec<InputLabels<WireLabel>>, WireLabelError> {
|
||||
let choices = inputs
|
||||
choices: Vec<InputValue>,
|
||||
) -> Result<Vec<InputLabels<WireLabel>>, OTError> {
|
||||
let choice_bits = choices
|
||||
.iter()
|
||||
.map(|value| value.wire_values())
|
||||
.flatten()
|
||||
.collect::<Vec<bool>>();
|
||||
|
||||
let mut labels = self.receive(&choices).await?;
|
||||
let mut labels = self.receive(choice_bits).await?;
|
||||
|
||||
inputs
|
||||
Ok(choices
|
||||
.into_iter()
|
||||
.map(|value| {
|
||||
InputLabels::new(
|
||||
@@ -71,13 +59,28 @@ pub trait WireLabelOTReceive: ot::ObliviousReceive<Choice = bool, Outputs = Vec<
|
||||
.map(|(block, id)| WireLabel::new(*id, block))
|
||||
.collect::<Vec<WireLabel>>(),
|
||||
)
|
||||
.map_err(WireLabelError::from)
|
||||
.expect("Input labels should be valid")
|
||||
})
|
||||
.collect::<Result<Vec<InputLabels<WireLabel>>, WireLabelError>>()
|
||||
.collect())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> WireLabelOTReceive for T where T: ot::ObliviousReceive<Choice = bool, Outputs = Vec<Block>> {}
|
||||
#[async_trait]
|
||||
impl<T> ObliviousVerify<InputLabels<WireLabelPair>> for T
|
||||
where
|
||||
T: Send + ObliviousVerify<[Block; 2]>,
|
||||
{
|
||||
async fn verify(self, input: Vec<InputLabels<WireLabelPair>>) -> Result<(), OTError> {
|
||||
self.verify(
|
||||
input
|
||||
.into_iter()
|
||||
.map(|labels| labels.to_blocks())
|
||||
.flatten()
|
||||
.collect(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
@@ -95,8 +98,8 @@ mod tests {
|
||||
let expected = receiver_labels[0].select(&value).unwrap();
|
||||
|
||||
let (mut sender, mut receiver) = mock_ot_pair::<Block>();
|
||||
sender.send_labels(receiver_labels).await.unwrap();
|
||||
let received = receiver.receive_labels(vec![value]).await.unwrap();
|
||||
sender.send(receiver_labels).await.unwrap();
|
||||
let received = receiver.receive(vec![value]).await.unwrap();
|
||||
|
||||
assert_eq!(received[0].as_ref(), expected.as_ref());
|
||||
}
|
||||
|
||||
@@ -15,6 +15,8 @@ use mpc_core::{
|
||||
use rand::thread_rng;
|
||||
use utils_aio::Channel;
|
||||
|
||||
use super::ot::OTError;
|
||||
|
||||
pub type GarbleChannel = Box<dyn Channel<GarbleMessage, Error = std::io::Error>>;
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
@@ -26,7 +28,7 @@ pub enum GCError {
|
||||
#[error("io error")]
|
||||
IOError(#[from] std::io::Error),
|
||||
#[error("ot error")]
|
||||
LabelOTError(#[from] label::WireLabelError),
|
||||
OTError(#[from] OTError),
|
||||
#[error("Received unexpected message: {0:?}")]
|
||||
Unexpected(GarbleMessage),
|
||||
#[error("backend error")]
|
||||
|
||||
@@ -35,7 +35,7 @@ mod tests {
|
||||
});
|
||||
let receive = tokio::spawn(async move {
|
||||
let mut receiver = receiver.rand_setup(ITERATIONS).await.unwrap();
|
||||
receiver.receive(&choices).await.unwrap()
|
||||
receiver.receive(choices).await.unwrap()
|
||||
});
|
||||
|
||||
let (_, output) = tokio::join!(send, receive);
|
||||
@@ -70,7 +70,7 @@ mod tests {
|
||||
let receive = tokio::spawn(async move {
|
||||
receiver.accept_commit().await.unwrap();
|
||||
let mut receiver = receiver.rand_setup(ITERATIONS).await.unwrap();
|
||||
let ot_output = receiver.receive(&choices).await.unwrap();
|
||||
let ot_output = receiver.receive(choices).await.unwrap();
|
||||
let verification = receiver.verify(blocks).await;
|
||||
(ot_output, verification)
|
||||
});
|
||||
|
||||
@@ -93,11 +93,8 @@ impl Kos15IOReceiver<r_state::RandSetup> {
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ObliviousReceive for Kos15IOReceiver<r_state::RandSetup> {
|
||||
type Choice = bool;
|
||||
type Outputs = Vec<Block>;
|
||||
|
||||
async fn receive(&mut self, choices: &[bool]) -> Result<Self::Outputs, OTError> {
|
||||
impl ObliviousReceive<bool, Block> for Kos15IOReceiver<r_state::RandSetup> {
|
||||
async fn receive(&mut self, choices: Vec<bool>) -> Result<Vec<Block>, OTError> {
|
||||
let message = self.inner.derandomize(&choices)?;
|
||||
self.channel
|
||||
.send(OTMessage::ExtDerandomize(message))
|
||||
@@ -127,10 +124,8 @@ impl ObliviousAcceptCommit for Kos15IOReceiver<r_state::Initialized> {
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ObliviousVerify for Kos15IOReceiver<r_state::RandSetup> {
|
||||
type Input = [Block; 2];
|
||||
|
||||
async fn verify(mut self, input: Vec<Self::Input>) -> Result<(), OTError> {
|
||||
impl ObliviousVerify<[Block; 2]> for Kos15IOReceiver<r_state::RandSetup> {
|
||||
async fn verify(mut self, input: Vec<[Block; 2]>) -> Result<(), OTError> {
|
||||
let reveal = expect_msg_or_err!(
|
||||
self.channel.next().await,
|
||||
OTMessage::ExtSenderReveal,
|
||||
|
||||
@@ -105,10 +105,8 @@ impl Kos15IOSender<s_state::RandSetup> {
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ObliviousSend for Kos15IOSender<s_state::RandSetup> {
|
||||
type Inputs = Vec<[Block; 2]>;
|
||||
|
||||
async fn send(&mut self, inputs: Self::Inputs) -> Result<(), OTError> {
|
||||
impl ObliviousSend<[Block; 2]> for Kos15IOSender<s_state::RandSetup> {
|
||||
async fn send(&mut self, inputs: Vec<[Block; 2]>) -> Result<(), OTError> {
|
||||
let message = expect_msg_or_err!(
|
||||
self.channel.next().await,
|
||||
OTMessage::ExtDerandomize,
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use super::{
|
||||
OTError, OTFactoryError, OTReceiverFactory, OTSenderFactory, ObliviousReceive, ObliviousSend,
|
||||
OTError, OTFactoryError, OTReceiverFactory, OTSenderFactory, ObliviousReceive, ObliviousReveal,
|
||||
ObliviousSend, ObliviousVerify,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use futures::{channel::mpsc, StreamExt};
|
||||
@@ -13,7 +14,7 @@ pub struct MockOTFactory<T> {
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T: Send + 'static> OTSenderFactory for Arc<Mutex<MockOTFactory<T>>> {
|
||||
impl<T: Send + 'static> OTSenderFactory<[T; 2]> for Arc<Mutex<MockOTFactory<T>>> {
|
||||
type Protocol = MockOTSender<T>;
|
||||
|
||||
async fn new_sender(
|
||||
@@ -33,7 +34,7 @@ impl<T: Send + 'static> OTSenderFactory for Arc<Mutex<MockOTFactory<T>>> {
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T: Send + 'static> OTReceiverFactory for Arc<Mutex<MockOTFactory<T>>> {
|
||||
impl<T: Send + 'static> OTReceiverFactory<bool, T> for Arc<Mutex<MockOTFactory<T>>> {
|
||||
type Protocol = MockOTReceiver<T>;
|
||||
|
||||
async fn new_receiver(
|
||||
@@ -66,13 +67,11 @@ pub fn mock_ot_pair<T: Send + 'static>() -> (MockOTSender<T>, MockOTReceiver<T>)
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T> ObliviousSend for MockOTSender<T>
|
||||
impl<T> ObliviousSend<[T; 2]> for MockOTSender<T>
|
||||
where
|
||||
T: Send + 'static,
|
||||
{
|
||||
type Inputs = Vec<[T; 2]>;
|
||||
|
||||
async fn send(&mut self, inputs: Self::Inputs) -> Result<(), OTError> {
|
||||
async fn send(&mut self, inputs: Vec<[T; 2]>) -> Result<(), OTError> {
|
||||
self.sender
|
||||
.try_send(inputs)
|
||||
.expect("DummySender should be able to send");
|
||||
@@ -81,14 +80,11 @@ where
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T> ObliviousReceive for MockOTReceiver<T>
|
||||
impl<T> ObliviousReceive<bool, T> for MockOTReceiver<T>
|
||||
where
|
||||
T: Send + 'static,
|
||||
{
|
||||
type Choice = bool;
|
||||
type Outputs = Vec<T>;
|
||||
|
||||
async fn receive(&mut self, choices: &[bool]) -> Result<Vec<T>, OTError> {
|
||||
async fn receive(&mut self, choices: Vec<bool>) -> Result<Vec<T>, OTError> {
|
||||
let payload = self
|
||||
.receiver
|
||||
.next()
|
||||
@@ -99,7 +95,7 @@ where
|
||||
.zip(choices)
|
||||
.map(|(v, c)| {
|
||||
let [low, high] = v;
|
||||
if *c {
|
||||
if c {
|
||||
high
|
||||
} else {
|
||||
low
|
||||
@@ -109,6 +105,27 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T> ObliviousVerify<[T; 2]> for MockOTReceiver<T>
|
||||
where
|
||||
T: Send + 'static,
|
||||
{
|
||||
async fn verify(self, _input: Vec<[T; 2]>) -> Result<(), OTError> {
|
||||
// MockOT is always honest
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T> ObliviousReveal for MockOTSender<T>
|
||||
where
|
||||
T: Send + 'static,
|
||||
{
|
||||
async fn reveal(mut self) -> Result<(), OTError> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -121,7 +138,7 @@ mod tests {
|
||||
|
||||
sender.send(values).await.unwrap();
|
||||
|
||||
let received = receiver.receive(&choice).await.unwrap();
|
||||
let received = receiver.receive(choice).await.unwrap();
|
||||
assert_eq!(received, vec![0, 3]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,18 +49,13 @@ pub enum OTFactoryError {
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait ObliviousSend {
|
||||
type Inputs;
|
||||
|
||||
async fn send(&mut self, inputs: Self::Inputs) -> Result<(), OTError>;
|
||||
pub trait ObliviousSend<T> {
|
||||
async fn send(&mut self, inputs: Vec<T>) -> Result<(), OTError>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait ObliviousReceive {
|
||||
type Choice;
|
||||
type Outputs;
|
||||
|
||||
async fn receive(&mut self, choices: &[Self::Choice]) -> Result<Self::Outputs, OTError>;
|
||||
pub trait ObliviousReceive<T, U> {
|
||||
async fn receive(&mut self, choices: Vec<T>) -> Result<Vec<U>, OTError>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -82,16 +77,14 @@ pub trait ObliviousAcceptCommit {
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait ObliviousVerify {
|
||||
type Input;
|
||||
|
||||
pub trait ObliviousVerify<T> {
|
||||
/// Verifies the correctness of the revealed OT seed
|
||||
async fn verify(self, input: Vec<Self::Input>) -> Result<(), OTError>;
|
||||
async fn verify(self, input: Vec<T>) -> Result<(), OTError>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait OTSenderFactory {
|
||||
type Protocol: ObliviousSend + Send;
|
||||
pub trait OTSenderFactory<T> {
|
||||
type Protocol: ObliviousSend<T> + Send;
|
||||
|
||||
/// Constructs a new Sender
|
||||
///
|
||||
@@ -105,8 +98,8 @@ pub trait OTSenderFactory {
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait OTReceiverFactory {
|
||||
type Protocol: ObliviousReceive + Send;
|
||||
pub trait OTReceiverFactory<T, U> {
|
||||
type Protocol: ObliviousReceive<T, U> + Send;
|
||||
|
||||
/// Constructs a new Receiver
|
||||
///
|
||||
@@ -124,9 +117,7 @@ mockall::mock! {
|
||||
pub ObliviousSender {}
|
||||
|
||||
#[async_trait]
|
||||
impl ObliviousSend for ObliviousSender {
|
||||
type Inputs = Vec<[mpc_core::Block; 2]>;
|
||||
|
||||
impl ObliviousSend<[mpc_core::Block; 2]> for ObliviousSender {
|
||||
async fn send(
|
||||
&mut self,
|
||||
inputs: Vec<[mpc_core::Block; 2]>,
|
||||
@@ -139,13 +130,10 @@ mockall::mock! {
|
||||
pub ObliviousReceiver {}
|
||||
|
||||
#[async_trait]
|
||||
impl ObliviousReceive for ObliviousReceiver {
|
||||
type Choice = bool;
|
||||
type Outputs = Vec<mpc_core::Block>;
|
||||
|
||||
impl ObliviousReceive<bool, mpc_core::Block> for ObliviousReceiver {
|
||||
async fn receive(
|
||||
&mut self,
|
||||
choices: &[bool],
|
||||
choices: Vec<bool>,
|
||||
) -> Result<Vec<mpc_core::Block>, OTError>;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -63,6 +63,14 @@ where
|
||||
}
|
||||
|
||||
impl InputLabels<WireLabelPair> {
|
||||
/// Returns input labels in block representation
|
||||
pub fn to_blocks(self) -> Vec<[Block; 2]> {
|
||||
self.labels
|
||||
.into_iter()
|
||||
.map(|labels| labels.to_inner())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Generates a full set of input [`WireLabelPair`] for the provided [`Circuit`]
|
||||
pub fn generate<R: Rng + CryptoRng>(
|
||||
rng: &mut R,
|
||||
|
||||
@@ -65,6 +65,12 @@ impl WireLabel {
|
||||
Self { id, value }
|
||||
}
|
||||
|
||||
/// Returns inner block
|
||||
#[inline]
|
||||
pub fn to_inner(self) -> Block {
|
||||
self.value
|
||||
}
|
||||
|
||||
/// Returns wire id of label
|
||||
#[inline]
|
||||
pub fn id(&self) -> usize {
|
||||
@@ -144,6 +150,12 @@ impl WireLabelPair {
|
||||
Self { id, low, high }
|
||||
}
|
||||
|
||||
/// Returns inner blocks
|
||||
#[inline]
|
||||
pub fn to_inner(self) -> [Block; 2] {
|
||||
[self.low, self.high]
|
||||
}
|
||||
|
||||
/// Generates pairs of wire labels \[W_0, W_0 ^ delta\]
|
||||
pub fn generate<R: Rng + CryptoRng>(
|
||||
rng: &mut R,
|
||||
|
||||
@@ -13,7 +13,7 @@ use mpc_core::Block;
|
||||
/// The receiver for the conversion
|
||||
///
|
||||
/// Will be the OT receiver
|
||||
pub struct Receiver<T: OTReceiverFactory, U: Gf2_128ShareConvert, V = Void> {
|
||||
pub struct Receiver<T: OTReceiverFactory<bool, Block>, U: Gf2_128ShareConvert, V = Void> {
|
||||
/// Provides initialized OTs for the OT receiver
|
||||
receiver_factory: T,
|
||||
id: String,
|
||||
@@ -24,12 +24,8 @@ pub struct Receiver<T: OTReceiverFactory, U: Gf2_128ShareConvert, V = Void> {
|
||||
counter: usize,
|
||||
}
|
||||
|
||||
impl<
|
||||
T: OTReceiverFactory<Protocol = U> + Send,
|
||||
U: ObliviousReceive<Choice = bool, Outputs = Vec<Block>>,
|
||||
V: Gf2_128ShareConvert,
|
||||
W: Recorder<V>,
|
||||
> Receiver<T, V, W>
|
||||
impl<T: OTReceiverFactory<bool, Block> + Send, V: Gf2_128ShareConvert, W: Recorder<V>>
|
||||
Receiver<T, V, W>
|
||||
{
|
||||
/// Create a new receiver
|
||||
pub fn new(receiver_factory: T, id: String, channel: Gf2ConversionChannel) -> Self {
|
||||
@@ -64,7 +60,7 @@ impl<
|
||||
.await?;
|
||||
|
||||
self.counter += 1;
|
||||
let ot_output = ot_receiver.receive(&choices).await?;
|
||||
let ot_output = ot_receiver.receive(choices).await?;
|
||||
|
||||
// Aggregate chunks of OTs to get back u128 values
|
||||
let converted_shares = ot_output
|
||||
@@ -80,11 +76,8 @@ impl<
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<
|
||||
T: OTReceiverFactory<Protocol = U> + Send,
|
||||
U: ObliviousReceive<Choice = bool, Outputs = Vec<Block>> + Send,
|
||||
V: Recorder<AddShare> + Send,
|
||||
> AdditiveToMultiplicative for Receiver<T, AddShare, V>
|
||||
impl<T: OTReceiverFactory<bool, Block> + Send, V: Recorder<AddShare> + Send>
|
||||
AdditiveToMultiplicative for Receiver<T, AddShare, V>
|
||||
{
|
||||
type FieldElement = u128;
|
||||
|
||||
@@ -99,11 +92,8 @@ impl<
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<
|
||||
T: OTReceiverFactory<Protocol = U> + Send,
|
||||
U: ObliviousReceive<Choice = bool, Outputs = Vec<Block>> + Send,
|
||||
V: Recorder<MulShare> + Send,
|
||||
> MultiplicativeToAdditive for Receiver<T, MulShare, V>
|
||||
impl<T: OTReceiverFactory<bool, Block> + Send, V: Recorder<MulShare> + Send>
|
||||
MultiplicativeToAdditive for Receiver<T, MulShare, V>
|
||||
{
|
||||
type FieldElement = u128;
|
||||
|
||||
@@ -120,7 +110,7 @@ impl<
|
||||
#[async_trait]
|
||||
impl<T, U> VerifyTape for Receiver<T, U, Tape>
|
||||
where
|
||||
T: OTReceiverFactory + Send,
|
||||
T: OTReceiverFactory<bool, Block> + Send,
|
||||
U: Gf2_128ShareConvert + Send,
|
||||
{
|
||||
async fn verify_tape(mut self) -> Result<(), ShareConversionError> {
|
||||
|
||||
@@ -8,6 +8,7 @@ use crate::{AdditiveToMultiplicative, MultiplicativeToAdditive, ShareConversionE
|
||||
use async_trait::async_trait;
|
||||
use futures::SinkExt;
|
||||
use mpc_aio::protocol::ot::{OTSenderFactory, ObliviousSend};
|
||||
use mpc_core::Block;
|
||||
use rand::SeedableRng;
|
||||
use rand_chacha::ChaCha12Rng;
|
||||
use utils_aio::adaptive_barrier::AdaptiveBarrier;
|
||||
@@ -17,7 +18,7 @@ use utils_aio::adaptive_barrier::AdaptiveBarrier;
|
||||
/// Will be the OT sender
|
||||
pub struct Sender<T, U, V = Void>
|
||||
where
|
||||
T: OTSenderFactory,
|
||||
T: OTSenderFactory<[Block; 2]>,
|
||||
U: Gf2_128ShareConvert,
|
||||
V: Recorder<U>,
|
||||
{
|
||||
@@ -35,8 +36,7 @@ where
|
||||
|
||||
impl<T, U, V> Sender<T, U, V>
|
||||
where
|
||||
T: OTSenderFactory + Send,
|
||||
<<T as OTSenderFactory>::Protocol as ObliviousSend>::Inputs: From<OTEnvelope> + Send,
|
||||
T: OTSenderFactory<[Block; 2]> + Send,
|
||||
U: Gf2_128ShareConvert,
|
||||
V: Recorder<U>,
|
||||
{
|
||||
@@ -96,8 +96,7 @@ where
|
||||
#[cfg(test)]
|
||||
impl<T, U> Sender<T, U, Tape>
|
||||
where
|
||||
T: OTSenderFactory + Send,
|
||||
<<T as OTSenderFactory>::Protocol as ObliviousSend>::Inputs: From<OTEnvelope> + Send,
|
||||
T: OTSenderFactory<[Block; 2]> + Send,
|
||||
U: Gf2_128ShareConvert,
|
||||
{
|
||||
pub fn tape_mut(&mut self) -> &mut Tape {
|
||||
@@ -108,8 +107,7 @@ where
|
||||
#[async_trait]
|
||||
impl<T, V> AdditiveToMultiplicative for Sender<T, AddShare, V>
|
||||
where
|
||||
T: OTSenderFactory + Send,
|
||||
<<T as OTSenderFactory>::Protocol as ObliviousSend>::Inputs: From<OTEnvelope> + Send,
|
||||
T: OTSenderFactory<[Block; 2]> + Send,
|
||||
V: Recorder<AddShare> + Send,
|
||||
{
|
||||
type FieldElement = u128;
|
||||
@@ -127,8 +125,7 @@ where
|
||||
#[async_trait]
|
||||
impl<T, V> MultiplicativeToAdditive for Sender<T, MulShare, V>
|
||||
where
|
||||
T: OTSenderFactory + Send,
|
||||
<<T as OTSenderFactory>::Protocol as ObliviousSend>::Inputs: From<OTEnvelope> + Send,
|
||||
T: OTSenderFactory<[Block; 2]> + Send,
|
||||
V: Recorder<MulShare> + Send,
|
||||
{
|
||||
type FieldElement = u128;
|
||||
@@ -146,7 +143,7 @@ where
|
||||
#[async_trait]
|
||||
impl<T, U> SendTape for Sender<T, U, Tape>
|
||||
where
|
||||
T: OTSenderFactory + Send,
|
||||
T: OTSenderFactory<[Block; 2]> + Send,
|
||||
U: Gf2_128ShareConvert + Send,
|
||||
{
|
||||
async fn send_tape(mut self) -> Result<(), ShareConversionError> {
|
||||
|
||||
Reference in New Issue
Block a user