refactor: sans-io TLS IO (#1036)

* refactor: add tls-client trait

* cleanup

* fix clippy

* add start state

* assert that mpc future is pending
This commit is contained in:
th4s
2025-11-25 18:13:11 +01:00
parent 4d449b2a1d
commit 0bf4a857b9
12 changed files with 792 additions and 189 deletions

View File

@@ -378,15 +378,29 @@ impl MpcTlsLeader {
Ok(())
}
/// Defers decryption of any incoming messages.
/// Enables or disables the decryption of any incoming messages.
///
/// # Arguments
///
/// * `enable` - Whether to enable or disable decryption.
#[instrument(level = "debug", skip_all, err)]
pub async fn defer_decryption(&mut self) -> Result<(), MpcTlsError> {
self.is_decrypting = false;
self.notifier.clear();
pub fn enable_decryption(&mut self, enable: bool) -> Result<(), MpcTlsError> {
self.is_decrypting = enable;
if enable {
self.notifier.set();
} else {
self.notifier.clear();
}
Ok(())
}
/// Returns if incoming messages are decrypted.
pub fn is_decrypting(&self) -> bool {
self.is_decrypting
}
/// Stops the actor.
pub fn stop(&mut self, ctx: &mut LudiContext<Self>) {
ctx.stop();

View File

@@ -32,10 +32,14 @@ impl MpcTlsLeaderCtrl {
Self { address }
}
/// Defers decryption of any incoming messages.
pub async fn defer_decryption(&self) -> Result<(), MpcTlsError> {
/// Enables or disables the decryption of any incoming messages.
///
/// # Arguments
///
/// * `enable` - Whether to enable or disable decryption.
pub async fn enable_decryption(&self, enable: bool) -> Result<(), MpcTlsError> {
self.address
.send(DeferDecryption)
.send(EnableDecryption { enable })
.await
.map_err(MpcTlsError::actor)?
}
@@ -981,7 +985,7 @@ impl Handler<BackendMsgServerClosed> for MpcTlsLeader {
}
}
impl Dispatch<MpcTlsLeader> for DeferDecryption {
impl Dispatch<MpcTlsLeader> for EnableDecryption {
fn dispatch<R: FnOnce(Self::Return) + Send>(
self,
actor: &mut MpcTlsLeader,
@@ -992,13 +996,13 @@ impl Dispatch<MpcTlsLeader> for DeferDecryption {
}
}
impl Handler<DeferDecryption> for MpcTlsLeader {
impl Handler<EnableDecryption> for MpcTlsLeader {
async fn handle(
&mut self,
_msg: DeferDecryption,
msg: EnableDecryption,
_ctx: &mut LudiCtx<Self>,
) -> <DeferDecryption as Message>::Return {
self.defer_decryption().await
) -> <EnableDecryption as Message>::Return {
self.enable_decryption(msg.enable)
}
}
@@ -1048,7 +1052,7 @@ pub enum MpcTlsLeaderMsg {
BackendMsgGetNotify(BackendMsgGetNotify),
BackendMsgIsEmpty(BackendMsgIsEmpty),
BackendMsgServerClosed(BackendMsgServerClosed),
DeferDecryption(DeferDecryption),
DeferDecryption(EnableDecryption),
Stop(Stop),
}
@@ -1083,7 +1087,7 @@ pub enum MpcTlsLeaderMsgReturn {
BackendMsgGetNotify(<BackendMsgGetNotify as Message>::Return),
BackendMsgIsEmpty(<BackendMsgIsEmpty as Message>::Return),
BackendMsgServerClosed(<BackendMsgServerClosed as Message>::Return),
DeferDecryption(<DeferDecryption as Message>::Return),
DeferDecryption(<EnableDecryption as Message>::Return),
Stop(<Stop as Message>::Return),
}
@@ -1732,23 +1736,25 @@ impl Wrap<BackendMsgServerClosed> for MpcTlsLeaderMsg {
}
}
/// Message to start deferring the decryption.
/// Message to enable or disable the decryption of messages.
#[allow(missing_docs)]
#[derive(Debug)]
pub struct DeferDecryption;
pub struct EnableDecryption {
pub enable: bool,
}
impl Message for DeferDecryption {
impl Message for EnableDecryption {
type Return = Result<(), MpcTlsError>;
}
impl From<DeferDecryption> for MpcTlsLeaderMsg {
fn from(value: DeferDecryption) -> Self {
impl From<EnableDecryption> for MpcTlsLeaderMsg {
fn from(value: EnableDecryption) -> Self {
MpcTlsLeaderMsg::DeferDecryption(value)
}
}
impl Wrap<DeferDecryption> for MpcTlsLeaderMsg {
fn unwrap_return(ret: Self::Return) -> Result<<DeferDecryption as Message>::Return, Error> {
impl Wrap<EnableDecryption> for MpcTlsLeaderMsg {
fn unwrap_return(ret: Self::Return) -> Result<<EnableDecryption as Message>::Return, Error> {
match ret {
Self::Return::DeferDecryption(value) => Ok(value),
_ => Err(Error::Wrapper),

View File

@@ -88,7 +88,7 @@ async fn leader_task(mut leader: MpcTlsLeader) {
let mut buf = vec![0u8; 48];
conn.read_exact(&mut buf).await.unwrap();
leader_ctrl.defer_decryption().await.unwrap();
leader_ctrl.enable_decryption(false).await.unwrap();
let msg = concat!(
"POST /echo HTTP/1.1\r\n",

View File

@@ -197,8 +197,8 @@ pub fn bind_client<T: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
sent.extend(&data);
client
.write_all_plaintext(&data)
.await?;
.write_all_plaintext(&data)?;
client.process_new_packets().await?;
tx_recv_fut = tx_receiver.next().fuse();
} else {

View File

@@ -690,6 +690,11 @@ impl CommonState {
self.received_plaintext.is_empty()
}
/// Returns true if the buffer for sendable plaintext is full.
pub fn sendable_plaintext_is_full(&self) -> bool {
self.sendable_plaintext.is_full()
}
/// Returns true if the connection is currently performing the TLS
/// handshake.
///

View File

@@ -35,6 +35,15 @@ impl ChunkVecBuffer {
self.chunks.is_empty()
}
/// If the buffer has reached limit.
pub(crate) fn is_full(&self) -> bool {
if let Some(limit) = self.limit {
self.len() >= limit
} else {
false
}
}
/// How many bytes we're storing
pub(crate) fn len(&self) -> usize {
let mut len = 0;

View File

@@ -1,12 +1,11 @@
//! Prover.
mod client;
mod error;
mod future;
mod prove;
pub mod state;
pub use error::ProverError;
pub use future::ProverFuture;
pub use tlsn_core::ProverOutput;
use crate::{
@@ -15,17 +14,17 @@ use crate::{
mpz::{ProverDeps, build_prover_deps, translate_keys},
msg::{ProveRequestMsg, Response, TlsCommitRequestMsg},
mux::attach_mux,
tag::verify_tags,
prover::client::{MpcTlsClient, TlsOutput},
};
use futures::{AsyncRead, AsyncWrite, TryFutureExt};
use mpc_tls::LeaderCtrl;
use mpz_vm_core::prelude::*;
use futures::{AsyncRead, AsyncWrite, FutureExt, TryFutureExt};
use rustls_pki_types::CertificateDer;
use serio::{SinkExt, stream::IoStreamExt};
use std::sync::Arc;
use std::{
sync::Arc,
task::{Context, Poll},
};
use tls_client::{ClientConnection, ServerName as TlsServerName};
use tls_client_async::{TlsConnection, bind_client};
use tlsn_core::{
config::{
prove::ProveConfig,
@@ -36,10 +35,9 @@ use tlsn_core::{
connection::{HandshakeData, ServerName},
transcript::{TlsTranscript, Transcript},
};
use tracing::{Span, debug, info_span, instrument};
use webpki::anchor_from_trusted_cert;
use tracing::{Instrument, Span, debug, info, info_span, instrument};
/// A prover instance.
#[derive(Debug)]
pub struct Prover<T: state::ProverState = state::Initialized> {
@@ -133,31 +131,29 @@ impl Prover<state::Initialized> {
}
impl Prover<state::CommitAccepted> {
/// Connects to the server using the provided socket.
/// Connects the prover.
///
/// Returns a handle to the TLS connection, a future which returns the
/// prover once the connection is closed and the TLS transcript is
/// committed.
/// Returns a connected prover, which can be used to read and write from/to
/// the active TLS connection.
///
/// # Arguments
///
/// * `config` - The TLS client configuration.
/// * `socket` - The socket to the server.
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
pub async fn connect<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
pub async fn connect(
self,
config: TlsClientConfig,
socket: S,
) -> Result<(TlsConnection, ProverFuture), ProverError> {
) -> Result<Prover<state::Connected>, ProverError> {
let state::CommitAccepted {
mux_ctrl,
mut mux_fut,
mux_fut,
mpc_tls,
keys,
vm,
..
} = self.state;
let decrypt = mpc_tls.is_decrypting();
let (mpc_ctrl, mpc_fut) = mpc_tls.run();
let ServerName::Dns(server_name) = config.server_name();
@@ -202,95 +198,160 @@ impl Prover<state::CommitAccepted> {
)
.map_err(ProverError::config)?;
let (conn, conn_fut) = bind_client(socket, client);
let span = self.span.clone();
let fut = Box::pin({
let span = self.span.clone();
let mpc_ctrl = mpc_ctrl.clone();
async move {
let conn_fut = async {
mux_fut
.poll_with(conn_fut.map_err(ProverError::from))
.await?;
let mpc_tls = MpcTlsClient::new(
Box::new(mpc_fut.map_err(ProverError::from)),
keys,
vm,
span,
mpc_ctrl,
client,
decrypt,
);
mpc_ctrl.stop().await?;
Ok::<_, ProverError>(())
};
info!("starting MPC-TLS");
let (_, (mut ctx, tls_transcript)) = futures::try_join!(
conn_fut,
mpc_fut.in_current_span().map_err(ProverError::from)
)?;
info!("finished MPC-TLS");
{
let mut vm = vm.try_lock().expect("VM should not be locked");
debug!("finalizing mpc");
// Finalize DEAP.
mux_fut
.poll_with(vm.finalize(&mut ctx))
.await
.map_err(ProverError::mpc)?;
debug!("mpc finalized");
}
// Pull out ZK VM.
let (_, mut vm) = Arc::into_inner(vm)
.expect("vm should have only 1 reference")
.into_inner()
.into_inner();
// Prove tag verification of received records.
// The prover drops the proof output.
let _ = verify_tags(
&mut vm,
(keys.server_write_key, keys.server_write_iv),
keys.server_write_mac_key,
*tls_transcript.version(),
tls_transcript.recv().to_vec(),
)
.map_err(ProverError::zk)?;
mux_fut
.poll_with(vm.execute_all(&mut ctx).map_err(ProverError::zk))
.await?;
let transcript = tls_transcript
.to_transcript()
.expect("transcript is complete");
Ok(Prover {
config: self.config,
span: self.span,
state: state::Committed {
mux_ctrl,
mux_fut,
ctx,
vm,
server_name: config.server_name().clone(),
keys,
tls_transcript,
transcript,
},
})
}
.instrument(span)
});
Ok((
conn,
ProverFuture {
fut,
ctrl: ProverControl { mpc_ctrl },
let prover = Prover {
config: self.config,
span: self.span,
state: state::Connected {
mux_ctrl,
mux_fut,
server_name: config.server_name().clone(),
tls_client: Box::new(mpc_tls),
output: None,
},
))
};
Ok(prover)
}
}
impl Prover<state::Connected> {
/// Returns `true` if the prover wants to read TLS data from the server.
pub fn wants_read_tls(&self) -> bool {
self.state.tls_client.wants_read_tls()
}
/// Returns `true` if the prover wants to write TLS data to the server.
pub fn wants_write_tls(&self) -> bool {
self.state.tls_client.wants_write_tls()
}
/// Reads TLS data from the server.
///
/// # Arguments
///
/// * `buf` - The buffer to read the TLS data from.
pub fn read_tls(&mut self, buf: &[u8]) -> Result<usize, ProverError> {
self.state.tls_client.read_tls(buf)
}
/// Writes TLS data for the server into the provided buffer.
///
/// # Arguments
///
/// * `buf` - The buffer to write the TLS data to.
pub fn write_tls(&mut self, buf: &mut [u8]) -> Result<usize, ProverError> {
self.state.tls_client.write_tls(buf)
}
/// Returns `true` if the prover wants to read plaintext data.
pub fn wants_read(&self) -> bool {
self.state.tls_client.wants_read()
}
/// Returns `true` if the prover wants to write plaintext data.
pub fn wants_write(&self) -> bool {
self.state.tls_client.wants_write()
}
/// Reads plaintext data from the server into the provided buffer.
///
/// # Arguments
///
/// * `buf` - The buffer where the plaintext data gets written to.
pub fn read(&mut self, buf: &mut [u8]) -> Result<usize, ProverError> {
self.state.tls_client.read(buf)
}
/// Writes plaintext data to be sent to the server.
///
/// # Arguments
///
/// * `buf` - The buffer to read the plaintext data from.
pub fn write(&mut self, buf: &[u8]) -> Result<usize, ProverError> {
self.state.tls_client.write(buf)
}
/// Closes the connection from the client side.
pub fn client_close(&mut self) -> Result<(), ProverError> {
self.state.tls_client.client_close()
}
/// Closes the connection from the server side.
pub fn server_close(&mut self) -> Result<(), ProverError> {
self.state.tls_client.server_close()
}
/// Enables or disables the decryption of data from the server until the
/// server has closed the connection.
///
/// # Arguments
///
/// * `enable` - Whether to enable or disable decryption.
pub fn enable_decryption(&mut self, enable: bool) -> Result<(), ProverError> {
self.state.tls_client.enable_decryption(enable)
}
/// Returns `true` if decryption of TLS traffic from the server is active.
pub fn is_decrypting(&self) -> bool {
self.state.tls_client.is_decrypting()
}
/// Polls the prover to make progress.
///
/// # Arguments
///
/// * `cx` - The async context.
pub fn poll(&mut self, cx: &mut Context) -> Poll<Result<(), ProverError>> {
let _ = self.state.mux_fut.poll_unpin(cx)?;
match self.state.tls_client.poll(cx)? {
Poll::Ready(output) => {
self.state.output = Some(output);
Poll::Ready(Ok(()))
}
Poll::Pending => Poll::Pending,
}
}
/// Returns a committed prover after the TLS session has completed.
pub fn finish(self) -> Result<Prover<state::Committed>, ProverError> {
let TlsOutput {
ctx,
vm,
keys,
tls_transcript,
transcript,
} = self.state.output.ok_or(ProverError::state(
"prover has not yet closed the connection",
))?;
let prover = Prover {
config: self.config,
span: self.span,
state: state::Committed {
mux_ctrl: self.state.mux_ctrl,
mux_fut: self.state.mux_fut,
ctx,
vm,
server_name: self.state.server_name,
keys,
tls_transcript,
transcript,
},
};
Ok(prover)
}
}
@@ -379,29 +440,3 @@ impl Prover<state::Committed> {
Ok(())
}
}
/// A controller for the prover.
#[derive(Clone)]
pub struct ProverControl {
mpc_ctrl: LeaderCtrl,
}
impl ProverControl {
/// Defers decryption of data from the server until the server has closed
/// the connection.
///
/// This is a performance optimization which will significantly reduce the
/// amount of upload bandwidth used by the prover.
///
/// # Notes
///
/// * The prover may need to close the connection to the server in order for
/// it to close the connection on its end. If neither the prover or server
/// close the connection this will cause a deadlock.
pub async fn defer_decryption(&self) -> Result<(), ProverError> {
self.mpc_ctrl
.defer_decryption()
.await
.map_err(ProverError::from)
}
}

View File

@@ -0,0 +1,63 @@
//! Provides a TLS client.
use crate::mpz::ProverZk;
use mpc_tls::SessionKeys;
use std::task::{Context, Poll};
use tlsn_core::transcript::{TlsTranscript, Transcript};
mod mpc;
pub(crate) use mpc::MpcTlsClient;
/// TLS client for MPC and proxy-based TLS implementations.
pub(crate) trait TlsClient {
type Error: std::error::Error + Send + Sync + Unpin + 'static;
/// Returns `true` if the client wants to read TLS data from the server.
fn wants_read_tls(&self) -> bool;
/// Returns `true` if the client wants to write TLS data to the server.
fn wants_write_tls(&self) -> bool;
/// Reads TLS data from the server.
fn read_tls(&mut self, buf: &[u8]) -> Result<usize, Self::Error>;
/// Writes TLS data for the server into the provided buffer.
fn write_tls(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error>;
/// Returns `true` if the client wants to read plaintext data.
fn wants_read(&self) -> bool;
/// Returns `true` if the client wants to write plaintext data.
fn wants_write(&self) -> bool;
/// Reads plaintext data from the server into the provided buffer.
fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error>;
/// Writes plaintext data to be sent to the server.
fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error>;
/// Client closes the connection.
fn client_close(&mut self) -> Result<(), Self::Error>;
/// Server closes the connection.
fn server_close(&mut self) -> Result<(), Self::Error>;
/// Enables or disables decryption of TLS traffic sent by the server.
fn enable_decryption(&mut self, enable: bool) -> Result<(), Self::Error>;
/// Returns `true` if decryption of TLS traffic from the server is active.
fn is_decrypting(&self) -> bool;
/// Polls the client to make progress.
fn poll(&mut self, cx: &mut Context) -> Poll<Result<TlsOutput, Self::Error>>;
}
/// Output of a TLS session.
pub(crate) struct TlsOutput {
pub(crate) ctx: mpz_common::Context,
pub(crate) vm: ProverZk,
pub(crate) keys: SessionKeys,
pub(crate) tls_transcript: TlsTranscript,
pub(crate) transcript: Transcript,
}

View File

@@ -0,0 +1,477 @@
//! Implementation of an MPC-TLS client.
use crate::{
mpz::{ProverMpc, ProverZk},
prover::{
ProverError,
client::{TlsClient, TlsOutput},
},
tag::verify_tags,
};
use futures::{Future, FutureExt};
use mpc_tls::{LeaderCtrl, SessionKeys};
use mpz_common::Context;
use mpz_vm_core::Execute;
use std::{pin::Pin, sync::Arc, task::Poll};
use tls_client::ClientConnection;
use tlsn_core::transcript::TlsTranscript;
use tlsn_deap::Deap;
use tokio::sync::Mutex;
use tracing::{Span, debug, instrument, trace, warn};
pub(crate) type MpcFuture = Box<dyn Future<Output = Result<(Context, TlsTranscript), ProverError>>>;
type FinalizeFuture =
Box<dyn Future<Output = Result<(InnerState, Context, TlsTranscript), ProverError>>>;
pub(crate) struct MpcTlsClient {
state: State,
decrypt: bool,
}
enum State {
Start {
mpc: Pin<MpcFuture>,
inner: Box<InnerState>,
},
Active {
mpc: Pin<MpcFuture>,
inner: Box<InnerState>,
},
Busy {
mpc: Pin<MpcFuture>,
fut: Pin<Box<dyn Future<Output = Result<Box<InnerState>, ProverError>>>>,
},
ClientClose {
mpc: Pin<MpcFuture>,
fut: Pin<Box<dyn Future<Output = Result<Box<InnerState>, ProverError>>>>,
},
ServerClose {
mpc: Pin<MpcFuture>,
fut: Pin<Box<dyn Future<Output = Result<Box<InnerState>, ProverError>>>>,
},
Closing {
ctx: Context,
transcript: Box<TlsTranscript>,
fut: Pin<Box<dyn Future<Output = Result<Box<InnerState>, ProverError>>>>,
},
Finalizing {
fut: Pin<FinalizeFuture>,
},
Finished,
Error,
}
impl MpcTlsClient {
pub(crate) fn new(
mpc: MpcFuture,
keys: SessionKeys,
vm: Arc<Mutex<Deap<ProverMpc, ProverZk>>>,
span: Span,
mpc_ctrl: LeaderCtrl,
tls: ClientConnection,
decrypt: bool,
) -> Self {
let inner = InnerState {
span,
tls,
vm,
keys,
mpc_ctrl,
closed: false,
};
Self {
decrypt,
state: State::Start {
mpc: Box::into_pin(mpc),
inner: Box::new(inner),
},
}
}
fn inner_client_mut(&mut self) -> Option<&mut ClientConnection> {
if let State::Active { inner, .. } = &mut self.state {
Some(&mut inner.tls)
} else {
None
}
}
fn inner_client(&self) -> Option<&ClientConnection> {
if let State::Active { inner, .. } = &self.state {
Some(&inner.tls)
} else {
None
}
}
}
impl TlsClient for MpcTlsClient {
type Error = ProverError;
fn wants_read_tls(&self) -> bool {
if let Some(client) = self.inner_client() {
client.wants_read()
} else {
false
}
}
fn wants_write_tls(&self) -> bool {
if let Some(client) = self.inner_client() {
client.wants_write()
} else {
false
}
}
fn read_tls(&mut self, mut buf: &[u8]) -> Result<usize, Self::Error> {
if let Some(client) = self.inner_client_mut()
&& client.wants_read()
{
client.read_tls(&mut buf).map_err(ProverError::from)
} else {
Ok(0)
}
}
fn write_tls(&mut self, mut buf: &mut [u8]) -> Result<usize, Self::Error> {
if let Some(client) = self.inner_client_mut()
&& client.wants_write()
{
client.write_tls(&mut buf).map_err(ProverError::from)
} else {
Ok(0)
}
}
fn wants_read(&self) -> bool {
if let Some(client) = self.inner_client() {
!client.sendable_plaintext_is_full()
} else {
false
}
}
fn wants_write(&self) -> bool {
if let Some(client) = self.inner_client() {
!client.plaintext_is_empty()
} else {
false
}
}
fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
if let Some(client) = self.inner_client_mut()
&& !client.sendable_plaintext_is_full()
{
client.read_plaintext(buf).map_err(ProverError::from)
} else {
Ok(0)
}
}
fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
if let Some(client) = self.inner_client_mut()
&& !client.plaintext_is_empty()
{
client.write_plaintext(buf).map_err(ProverError::from)
} else {
Ok(0)
}
}
fn client_close(&mut self) -> Result<(), Self::Error> {
match std::mem::replace(&mut self.state, State::Error) {
State::Active { inner, mpc } => {
self.state = State::ClientClose {
mpc,
fut: Box::pin(inner.client_close()),
};
Ok(())
}
other => {
self.state = other;
Err(ProverError::state(
"unable to close connection, client is not in active state",
))
}
}
}
fn server_close(&mut self) -> Result<(), Self::Error> {
match std::mem::replace(&mut self.state, State::Error) {
State::Active { inner, mpc } => {
self.state = State::ServerClose {
mpc,
fut: Box::pin(inner.server_close()),
};
Ok(())
}
other => {
self.state = other;
Err(ProverError::state(
"unable to close connection, client is not in active state",
))
}
}
}
fn enable_decryption(&mut self, enable: bool) -> Result<(), Self::Error> {
match std::mem::replace(&mut self.state, State::Error) {
State::Active { inner, mpc } => {
self.decrypt = enable;
self.state = State::Busy {
mpc,
fut: Box::pin(inner.set_decrypt(enable)),
};
Ok(())
}
other => {
self.state = other;
Err(ProverError::state(
"unable to enable decryption, client is not in active state",
))
}
}
}
fn is_decrypting(&self) -> bool {
self.decrypt
}
fn poll(&mut self, cx: &mut std::task::Context) -> Poll<Result<TlsOutput, Self::Error>> {
match std::mem::replace(&mut self.state, State::Error) {
State::Start { mpc, inner } => {
self.state = State::Busy {
mpc,
fut: Box::pin(inner.start()),
};
self.poll(cx)
}
State::Active { mpc, inner } => {
trace!("inner client is active");
self.state = State::Busy {
mpc,
fut: Box::pin(inner.run()),
};
self.poll(cx)
}
State::Busy { mut mpc, mut fut } => {
trace!("inner client is busy");
let mpc_poll = mpc.as_mut().poll(cx)?;
assert!(
matches!(mpc_poll, Poll::Pending),
"mpc future should not be finished here"
);
match fut.as_mut().poll(cx)? {
Poll::Ready(inner) => {
self.state = State::Active { mpc, inner };
}
Poll::Pending => self.state = State::Busy { mpc, fut },
}
Poll::Pending
}
State::ClientClose { mut mpc, mut fut } => {
debug!("attempting to close connection clientside");
match (fut.poll_unpin(cx)?, mpc.poll_unpin(cx)?) {
(Poll::Ready(inner), Poll::Ready((ctx, transcript))) => {
self.state = State::Finalizing {
fut: Box::pin(inner.finalize(ctx, transcript)),
};
}
(Poll::Ready(inner), Poll::Pending) => {
self.state = State::ClientClose {
mpc,
fut: Box::pin(inner.client_close()),
};
}
(Poll::Pending, Poll::Ready((ctx, transcript))) => {
self.state = State::Closing {
ctx,
transcript: Box::new(transcript),
fut,
};
}
(Poll::Pending, Poll::Pending) => self.state = State::ClientClose { mpc, fut },
}
self.poll(cx)
}
State::ServerClose { mut mpc, mut fut } => {
debug!("attempting to close connection serverside");
match (fut.poll_unpin(cx)?, mpc.poll_unpin(cx)?) {
(Poll::Ready(inner), Poll::Ready((ctx, transcript))) => {
self.state = State::Finalizing {
fut: Box::pin(inner.finalize(ctx, transcript)),
};
}
(Poll::Ready(inner), Poll::Pending) => {
self.state = State::ServerClose {
mpc,
fut: Box::pin(inner.server_close()),
};
}
(Poll::Pending, Poll::Ready((ctx, transcript))) => {
self.state = State::Closing {
ctx,
transcript: Box::new(transcript),
fut,
};
}
(Poll::Pending, Poll::Pending) => self.state = State::ServerClose { mpc, fut },
}
self.poll(cx)
}
State::Closing {
ctx,
transcript,
mut fut,
} => {
if let Poll::Ready(inner) = fut.poll_unpin(cx)? {
self.state = State::Finalizing {
fut: Box::pin(inner.finalize(ctx, *transcript)),
};
} else {
self.state = State::Closing {
ctx,
transcript,
fut,
};
}
self.poll(cx)
}
State::Finalizing { mut fut } => match fut.poll_unpin(cx) {
Poll::Ready(output) => {
let (inner, ctx, tls_transcript) = output?;
let InnerState { vm, keys, .. } = inner;
let transcript = tls_transcript
.to_transcript()
.expect("transcript is complete");
let (_, vm) = Arc::into_inner(vm)
.expect("vm should have only 1 reference")
.into_inner()
.into_inner();
let output = TlsOutput {
ctx,
vm,
keys,
tls_transcript,
transcript,
};
self.state = State::Finished;
Poll::Ready(Ok(output))
}
Poll::Pending => {
self.state = State::Finalizing { fut };
self.poll(cx)
}
},
State::Finished => Poll::Ready(Err(ProverError::state(
"mpc tls client polled again in finished state",
))),
State::Error => {
Poll::Ready(Err(ProverError::state("mpc tls client is in error state")))
}
}
}
}
struct InnerState {
span: Span,
tls: ClientConnection,
vm: Arc<Mutex<Deap<ProverMpc, ProverZk>>>,
keys: SessionKeys,
mpc_ctrl: LeaderCtrl,
closed: bool,
}
impl InnerState {
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
async fn start(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
self.tls.start().await?;
Ok(self)
}
#[instrument(parent = &self.span, level = "trace", skip_all, err)]
async fn run(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
self.tls.process_new_packets().await?;
Ok(self)
}
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
async fn set_decrypt(self: Box<Self>, enable: bool) -> Result<Box<Self>, ProverError> {
self.mpc_ctrl.enable_decryption(enable).await?;
self.run().await
}
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
async fn client_close(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
if self.tls.plaintext_is_empty() && self.tls.is_empty().await? && !self.closed {
if let Err(e) = self.tls.send_close_notify().await {
warn!("failed to send close_notify to server: {}", e);
};
self.mpc_ctrl.stop().await?;
self.closed = true;
debug!("closed connection");
}
self.run().await
}
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
async fn server_close(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
if self.tls.plaintext_is_empty() && self.tls.is_empty().await? && !self.closed {
self.tls.server_closed().await?;
self.mpc_ctrl.stop().await?;
self.closed = true;
debug!("closed connection");
}
self.run().await
}
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
async fn finalize(
self,
mut ctx: Context,
transcript: TlsTranscript,
) -> Result<(Self, Context, TlsTranscript), ProverError> {
{
let mut vm = self.vm.try_lock().expect("VM should not be locked");
// Finalize DEAP.
vm.finalize(&mut ctx).await.map_err(ProverError::mpc)?;
debug!("mpc finalized");
// Pull out ZK VM.
let mut zk = vm.zk();
// Prove tag verification of received records.
// The prover drops the proof output.
let _ = verify_tags(
&mut *zk,
(self.keys.server_write_key, self.keys.server_write_iv),
self.keys.server_write_mac_key,
*transcript.version(),
transcript.recv().to_vec(),
)
.map_err(ProverError::zk)?;
debug!("verified tags from server");
zk.execute_all(&mut ctx).await.map_err(ProverError::zk)?
}
debug!("MPC-TLS done");
Ok((self, ctx, transcript))
}
}

View File

@@ -49,6 +49,13 @@ impl ProverError {
{
Self::new(ErrorKind::Commit, source)
}
pub(crate) fn state<E>(source: E) -> Self
where
E: Into<Box<dyn Error + Send + Sync + 'static>>,
{
Self::new(ErrorKind::State, source)
}
}
#[derive(Debug)]
@@ -58,6 +65,7 @@ enum ErrorKind {
Zk,
Config,
Commit,
State,
}
impl fmt::Display for ProverError {
@@ -70,6 +78,7 @@ impl fmt::Display for ProverError {
ErrorKind::Zk => f.write_str("zk error")?,
ErrorKind::Config => f.write_str("config error")?,
ErrorKind::Commit => f.write_str("commit error")?,
ErrorKind::State => f.write_str("state error")?,
}
if let Some(source) = &self.source {
@@ -86,8 +95,8 @@ impl From<std::io::Error> for ProverError {
}
}
impl From<tls_client_async::ConnectionError> for ProverError {
fn from(e: tls_client_async::ConnectionError) -> Self {
impl From<tls_client::Error> for ProverError {
fn from(e: tls_client::Error) -> Self {
Self::new(ErrorKind::Io, e)
}
}

View File

@@ -1,32 +0,0 @@
//! This module collects futures which are used by the [Prover].
use super::{Prover, ProverControl, ProverError, state};
use futures::Future;
use std::pin::Pin;
/// Prover future which must be polled for the TLS connection to make progress.
pub struct ProverFuture {
#[allow(clippy::type_complexity)]
pub(crate) fut: Pin<
Box<dyn Future<Output = Result<Prover<state::Committed>, ProverError>> + Send + 'static>,
>,
pub(crate) ctrl: ProverControl,
}
impl ProverFuture {
/// Returns a controller for the prover for advanced functionality.
pub fn control(&self) -> ProverControl {
self.ctrl.clone()
}
}
impl Future for ProverFuture {
type Output = Result<Prover<state::Committed>, ProverError>;
fn poll(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
self.fut.as_mut().poll(cx)
}
}

View File

@@ -14,6 +14,10 @@ use tokio::sync::Mutex;
use crate::{
mpz::{ProverMpc, ProverZk},
mux::{MuxControl, MuxFuture},
prover::{
ProverError,
client::{TlsClient, TlsOutput},
},
};
/// Entry state
@@ -33,6 +37,17 @@ pub struct CommitAccepted {
opaque_debug::implement!(CommitAccepted);
/// State during the MPC-TLS connection.
pub struct Connected {
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) server_name: ServerName,
pub(crate) tls_client: Box<dyn TlsClient<Error = ProverError>>,
pub(crate) output: Option<TlsOutput>,
}
opaque_debug::implement!(Connected);
/// State after the TLS transcript has been committed.
pub struct Committed {
pub(crate) mux_ctrl: MuxControl,
@@ -52,11 +67,13 @@ pub trait ProverState: sealed::Sealed {}
impl ProverState for Initialized {}
impl ProverState for CommitAccepted {}
impl ProverState for Connected {}
impl ProverState for Committed {}
mod sealed {
pub trait Sealed {}
impl Sealed for super::Initialized {}
impl Sealed for super::CommitAccepted {}
impl Sealed for super::Connected {}
impl Sealed for super::Committed {}
}