mirror of
https://github.com/tlsnotary/tlsn.git
synced 2026-01-09 21:38:00 -05:00
refactor: sans-io TLS IO (#1036)
* refactor: add tls-client trait * cleanup * fix clippy * add start state * assert that mpc future is pending
This commit is contained in:
@@ -378,15 +378,29 @@ impl MpcTlsLeader {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Defers decryption of any incoming messages.
|
||||
/// Enables or disables the decryption of any incoming messages.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `enable` - Whether to enable or disable decryption.
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
pub async fn defer_decryption(&mut self) -> Result<(), MpcTlsError> {
|
||||
self.is_decrypting = false;
|
||||
self.notifier.clear();
|
||||
pub fn enable_decryption(&mut self, enable: bool) -> Result<(), MpcTlsError> {
|
||||
self.is_decrypting = enable;
|
||||
|
||||
if enable {
|
||||
self.notifier.set();
|
||||
} else {
|
||||
self.notifier.clear();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns if incoming messages are decrypted.
|
||||
pub fn is_decrypting(&self) -> bool {
|
||||
self.is_decrypting
|
||||
}
|
||||
|
||||
/// Stops the actor.
|
||||
pub fn stop(&mut self, ctx: &mut LudiContext<Self>) {
|
||||
ctx.stop();
|
||||
|
||||
@@ -32,10 +32,14 @@ impl MpcTlsLeaderCtrl {
|
||||
Self { address }
|
||||
}
|
||||
|
||||
/// Defers decryption of any incoming messages.
|
||||
pub async fn defer_decryption(&self) -> Result<(), MpcTlsError> {
|
||||
/// Enables or disables the decryption of any incoming messages.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `enable` - Whether to enable or disable decryption.
|
||||
pub async fn enable_decryption(&self, enable: bool) -> Result<(), MpcTlsError> {
|
||||
self.address
|
||||
.send(DeferDecryption)
|
||||
.send(EnableDecryption { enable })
|
||||
.await
|
||||
.map_err(MpcTlsError::actor)?
|
||||
}
|
||||
@@ -981,7 +985,7 @@ impl Handler<BackendMsgServerClosed> for MpcTlsLeader {
|
||||
}
|
||||
}
|
||||
|
||||
impl Dispatch<MpcTlsLeader> for DeferDecryption {
|
||||
impl Dispatch<MpcTlsLeader> for EnableDecryption {
|
||||
fn dispatch<R: FnOnce(Self::Return) + Send>(
|
||||
self,
|
||||
actor: &mut MpcTlsLeader,
|
||||
@@ -992,13 +996,13 @@ impl Dispatch<MpcTlsLeader> for DeferDecryption {
|
||||
}
|
||||
}
|
||||
|
||||
impl Handler<DeferDecryption> for MpcTlsLeader {
|
||||
impl Handler<EnableDecryption> for MpcTlsLeader {
|
||||
async fn handle(
|
||||
&mut self,
|
||||
_msg: DeferDecryption,
|
||||
msg: EnableDecryption,
|
||||
_ctx: &mut LudiCtx<Self>,
|
||||
) -> <DeferDecryption as Message>::Return {
|
||||
self.defer_decryption().await
|
||||
) -> <EnableDecryption as Message>::Return {
|
||||
self.enable_decryption(msg.enable)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1048,7 +1052,7 @@ pub enum MpcTlsLeaderMsg {
|
||||
BackendMsgGetNotify(BackendMsgGetNotify),
|
||||
BackendMsgIsEmpty(BackendMsgIsEmpty),
|
||||
BackendMsgServerClosed(BackendMsgServerClosed),
|
||||
DeferDecryption(DeferDecryption),
|
||||
DeferDecryption(EnableDecryption),
|
||||
Stop(Stop),
|
||||
}
|
||||
|
||||
@@ -1083,7 +1087,7 @@ pub enum MpcTlsLeaderMsgReturn {
|
||||
BackendMsgGetNotify(<BackendMsgGetNotify as Message>::Return),
|
||||
BackendMsgIsEmpty(<BackendMsgIsEmpty as Message>::Return),
|
||||
BackendMsgServerClosed(<BackendMsgServerClosed as Message>::Return),
|
||||
DeferDecryption(<DeferDecryption as Message>::Return),
|
||||
DeferDecryption(<EnableDecryption as Message>::Return),
|
||||
Stop(<Stop as Message>::Return),
|
||||
}
|
||||
|
||||
@@ -1732,23 +1736,25 @@ impl Wrap<BackendMsgServerClosed> for MpcTlsLeaderMsg {
|
||||
}
|
||||
}
|
||||
|
||||
/// Message to start deferring the decryption.
|
||||
/// Message to enable or disable the decryption of messages.
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Debug)]
|
||||
pub struct DeferDecryption;
|
||||
pub struct EnableDecryption {
|
||||
pub enable: bool,
|
||||
}
|
||||
|
||||
impl Message for DeferDecryption {
|
||||
impl Message for EnableDecryption {
|
||||
type Return = Result<(), MpcTlsError>;
|
||||
}
|
||||
|
||||
impl From<DeferDecryption> for MpcTlsLeaderMsg {
|
||||
fn from(value: DeferDecryption) -> Self {
|
||||
impl From<EnableDecryption> for MpcTlsLeaderMsg {
|
||||
fn from(value: EnableDecryption) -> Self {
|
||||
MpcTlsLeaderMsg::DeferDecryption(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl Wrap<DeferDecryption> for MpcTlsLeaderMsg {
|
||||
fn unwrap_return(ret: Self::Return) -> Result<<DeferDecryption as Message>::Return, Error> {
|
||||
impl Wrap<EnableDecryption> for MpcTlsLeaderMsg {
|
||||
fn unwrap_return(ret: Self::Return) -> Result<<EnableDecryption as Message>::Return, Error> {
|
||||
match ret {
|
||||
Self::Return::DeferDecryption(value) => Ok(value),
|
||||
_ => Err(Error::Wrapper),
|
||||
|
||||
@@ -88,7 +88,7 @@ async fn leader_task(mut leader: MpcTlsLeader) {
|
||||
let mut buf = vec![0u8; 48];
|
||||
conn.read_exact(&mut buf).await.unwrap();
|
||||
|
||||
leader_ctrl.defer_decryption().await.unwrap();
|
||||
leader_ctrl.enable_decryption(false).await.unwrap();
|
||||
|
||||
let msg = concat!(
|
||||
"POST /echo HTTP/1.1\r\n",
|
||||
|
||||
@@ -197,8 +197,8 @@ pub fn bind_client<T: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
|
||||
|
||||
sent.extend(&data);
|
||||
client
|
||||
.write_all_plaintext(&data)
|
||||
.await?;
|
||||
.write_all_plaintext(&data)?;
|
||||
client.process_new_packets().await?;
|
||||
|
||||
tx_recv_fut = tx_receiver.next().fuse();
|
||||
} else {
|
||||
|
||||
@@ -690,6 +690,11 @@ impl CommonState {
|
||||
self.received_plaintext.is_empty()
|
||||
}
|
||||
|
||||
/// Returns true if the buffer for sendable plaintext is full.
|
||||
pub fn sendable_plaintext_is_full(&self) -> bool {
|
||||
self.sendable_plaintext.is_full()
|
||||
}
|
||||
|
||||
/// Returns true if the connection is currently performing the TLS
|
||||
/// handshake.
|
||||
///
|
||||
|
||||
@@ -35,6 +35,15 @@ impl ChunkVecBuffer {
|
||||
self.chunks.is_empty()
|
||||
}
|
||||
|
||||
/// If the buffer has reached limit.
|
||||
pub(crate) fn is_full(&self) -> bool {
|
||||
if let Some(limit) = self.limit {
|
||||
self.len() >= limit
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// How many bytes we're storing
|
||||
pub(crate) fn len(&self) -> usize {
|
||||
let mut len = 0;
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
//! Prover.
|
||||
|
||||
mod client;
|
||||
mod error;
|
||||
mod future;
|
||||
mod prove;
|
||||
pub mod state;
|
||||
|
||||
pub use error::ProverError;
|
||||
pub use future::ProverFuture;
|
||||
pub use tlsn_core::ProverOutput;
|
||||
|
||||
use crate::{
|
||||
@@ -15,17 +14,17 @@ use crate::{
|
||||
mpz::{ProverDeps, build_prover_deps, translate_keys},
|
||||
msg::{ProveRequestMsg, Response, TlsCommitRequestMsg},
|
||||
mux::attach_mux,
|
||||
tag::verify_tags,
|
||||
prover::client::{MpcTlsClient, TlsOutput},
|
||||
};
|
||||
|
||||
use futures::{AsyncRead, AsyncWrite, TryFutureExt};
|
||||
use mpc_tls::LeaderCtrl;
|
||||
use mpz_vm_core::prelude::*;
|
||||
use futures::{AsyncRead, AsyncWrite, FutureExt, TryFutureExt};
|
||||
use rustls_pki_types::CertificateDer;
|
||||
use serio::{SinkExt, stream::IoStreamExt};
|
||||
use std::sync::Arc;
|
||||
use std::{
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use tls_client::{ClientConnection, ServerName as TlsServerName};
|
||||
use tls_client_async::{TlsConnection, bind_client};
|
||||
use tlsn_core::{
|
||||
config::{
|
||||
prove::ProveConfig,
|
||||
@@ -36,10 +35,9 @@ use tlsn_core::{
|
||||
connection::{HandshakeData, ServerName},
|
||||
transcript::{TlsTranscript, Transcript},
|
||||
};
|
||||
use tracing::{Span, debug, info_span, instrument};
|
||||
use webpki::anchor_from_trusted_cert;
|
||||
|
||||
use tracing::{Instrument, Span, debug, info, info_span, instrument};
|
||||
|
||||
/// A prover instance.
|
||||
#[derive(Debug)]
|
||||
pub struct Prover<T: state::ProverState = state::Initialized> {
|
||||
@@ -133,31 +131,29 @@ impl Prover<state::Initialized> {
|
||||
}
|
||||
|
||||
impl Prover<state::CommitAccepted> {
|
||||
/// Connects to the server using the provided socket.
|
||||
/// Connects the prover.
|
||||
///
|
||||
/// Returns a handle to the TLS connection, a future which returns the
|
||||
/// prover once the connection is closed and the TLS transcript is
|
||||
/// committed.
|
||||
/// Returns a connected prover, which can be used to read and write from/to
|
||||
/// the active TLS connection.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `config` - The TLS client configuration.
|
||||
/// * `socket` - The socket to the server.
|
||||
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
|
||||
pub async fn connect<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
|
||||
pub async fn connect(
|
||||
self,
|
||||
config: TlsClientConfig,
|
||||
socket: S,
|
||||
) -> Result<(TlsConnection, ProverFuture), ProverError> {
|
||||
) -> Result<Prover<state::Connected>, ProverError> {
|
||||
let state::CommitAccepted {
|
||||
mux_ctrl,
|
||||
mut mux_fut,
|
||||
mux_fut,
|
||||
mpc_tls,
|
||||
keys,
|
||||
vm,
|
||||
..
|
||||
} = self.state;
|
||||
|
||||
let decrypt = mpc_tls.is_decrypting();
|
||||
let (mpc_ctrl, mpc_fut) = mpc_tls.run();
|
||||
|
||||
let ServerName::Dns(server_name) = config.server_name();
|
||||
@@ -202,95 +198,160 @@ impl Prover<state::CommitAccepted> {
|
||||
)
|
||||
.map_err(ProverError::config)?;
|
||||
|
||||
let (conn, conn_fut) = bind_client(socket, client);
|
||||
let span = self.span.clone();
|
||||
|
||||
let fut = Box::pin({
|
||||
let span = self.span.clone();
|
||||
let mpc_ctrl = mpc_ctrl.clone();
|
||||
async move {
|
||||
let conn_fut = async {
|
||||
mux_fut
|
||||
.poll_with(conn_fut.map_err(ProverError::from))
|
||||
.await?;
|
||||
let mpc_tls = MpcTlsClient::new(
|
||||
Box::new(mpc_fut.map_err(ProverError::from)),
|
||||
keys,
|
||||
vm,
|
||||
span,
|
||||
mpc_ctrl,
|
||||
client,
|
||||
decrypt,
|
||||
);
|
||||
|
||||
mpc_ctrl.stop().await?;
|
||||
|
||||
Ok::<_, ProverError>(())
|
||||
};
|
||||
|
||||
info!("starting MPC-TLS");
|
||||
|
||||
let (_, (mut ctx, tls_transcript)) = futures::try_join!(
|
||||
conn_fut,
|
||||
mpc_fut.in_current_span().map_err(ProverError::from)
|
||||
)?;
|
||||
|
||||
info!("finished MPC-TLS");
|
||||
|
||||
{
|
||||
let mut vm = vm.try_lock().expect("VM should not be locked");
|
||||
|
||||
debug!("finalizing mpc");
|
||||
|
||||
// Finalize DEAP.
|
||||
mux_fut
|
||||
.poll_with(vm.finalize(&mut ctx))
|
||||
.await
|
||||
.map_err(ProverError::mpc)?;
|
||||
|
||||
debug!("mpc finalized");
|
||||
}
|
||||
|
||||
// Pull out ZK VM.
|
||||
let (_, mut vm) = Arc::into_inner(vm)
|
||||
.expect("vm should have only 1 reference")
|
||||
.into_inner()
|
||||
.into_inner();
|
||||
|
||||
// Prove tag verification of received records.
|
||||
// The prover drops the proof output.
|
||||
let _ = verify_tags(
|
||||
&mut vm,
|
||||
(keys.server_write_key, keys.server_write_iv),
|
||||
keys.server_write_mac_key,
|
||||
*tls_transcript.version(),
|
||||
tls_transcript.recv().to_vec(),
|
||||
)
|
||||
.map_err(ProverError::zk)?;
|
||||
|
||||
mux_fut
|
||||
.poll_with(vm.execute_all(&mut ctx).map_err(ProverError::zk))
|
||||
.await?;
|
||||
|
||||
let transcript = tls_transcript
|
||||
.to_transcript()
|
||||
.expect("transcript is complete");
|
||||
|
||||
Ok(Prover {
|
||||
config: self.config,
|
||||
span: self.span,
|
||||
state: state::Committed {
|
||||
mux_ctrl,
|
||||
mux_fut,
|
||||
ctx,
|
||||
vm,
|
||||
server_name: config.server_name().clone(),
|
||||
keys,
|
||||
tls_transcript,
|
||||
transcript,
|
||||
},
|
||||
})
|
||||
}
|
||||
.instrument(span)
|
||||
});
|
||||
|
||||
Ok((
|
||||
conn,
|
||||
ProverFuture {
|
||||
fut,
|
||||
ctrl: ProverControl { mpc_ctrl },
|
||||
let prover = Prover {
|
||||
config: self.config,
|
||||
span: self.span,
|
||||
state: state::Connected {
|
||||
mux_ctrl,
|
||||
mux_fut,
|
||||
server_name: config.server_name().clone(),
|
||||
tls_client: Box::new(mpc_tls),
|
||||
output: None,
|
||||
},
|
||||
))
|
||||
};
|
||||
Ok(prover)
|
||||
}
|
||||
}
|
||||
|
||||
impl Prover<state::Connected> {
|
||||
/// Returns `true` if the prover wants to read TLS data from the server.
|
||||
pub fn wants_read_tls(&self) -> bool {
|
||||
self.state.tls_client.wants_read_tls()
|
||||
}
|
||||
|
||||
/// Returns `true` if the prover wants to write TLS data to the server.
|
||||
pub fn wants_write_tls(&self) -> bool {
|
||||
self.state.tls_client.wants_write_tls()
|
||||
}
|
||||
|
||||
/// Reads TLS data from the server.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `buf` - The buffer to read the TLS data from.
|
||||
pub fn read_tls(&mut self, buf: &[u8]) -> Result<usize, ProverError> {
|
||||
self.state.tls_client.read_tls(buf)
|
||||
}
|
||||
|
||||
/// Writes TLS data for the server into the provided buffer.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `buf` - The buffer to write the TLS data to.
|
||||
pub fn write_tls(&mut self, buf: &mut [u8]) -> Result<usize, ProverError> {
|
||||
self.state.tls_client.write_tls(buf)
|
||||
}
|
||||
|
||||
/// Returns `true` if the prover wants to read plaintext data.
|
||||
pub fn wants_read(&self) -> bool {
|
||||
self.state.tls_client.wants_read()
|
||||
}
|
||||
|
||||
/// Returns `true` if the prover wants to write plaintext data.
|
||||
pub fn wants_write(&self) -> bool {
|
||||
self.state.tls_client.wants_write()
|
||||
}
|
||||
|
||||
/// Reads plaintext data from the server into the provided buffer.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `buf` - The buffer where the plaintext data gets written to.
|
||||
pub fn read(&mut self, buf: &mut [u8]) -> Result<usize, ProverError> {
|
||||
self.state.tls_client.read(buf)
|
||||
}
|
||||
|
||||
/// Writes plaintext data to be sent to the server.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `buf` - The buffer to read the plaintext data from.
|
||||
pub fn write(&mut self, buf: &[u8]) -> Result<usize, ProverError> {
|
||||
self.state.tls_client.write(buf)
|
||||
}
|
||||
|
||||
/// Closes the connection from the client side.
|
||||
pub fn client_close(&mut self) -> Result<(), ProverError> {
|
||||
self.state.tls_client.client_close()
|
||||
}
|
||||
|
||||
/// Closes the connection from the server side.
|
||||
pub fn server_close(&mut self) -> Result<(), ProverError> {
|
||||
self.state.tls_client.server_close()
|
||||
}
|
||||
|
||||
/// Enables or disables the decryption of data from the server until the
|
||||
/// server has closed the connection.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `enable` - Whether to enable or disable decryption.
|
||||
pub fn enable_decryption(&mut self, enable: bool) -> Result<(), ProverError> {
|
||||
self.state.tls_client.enable_decryption(enable)
|
||||
}
|
||||
|
||||
/// Returns `true` if decryption of TLS traffic from the server is active.
|
||||
pub fn is_decrypting(&self) -> bool {
|
||||
self.state.tls_client.is_decrypting()
|
||||
}
|
||||
|
||||
/// Polls the prover to make progress.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `cx` - The async context.
|
||||
pub fn poll(&mut self, cx: &mut Context) -> Poll<Result<(), ProverError>> {
|
||||
let _ = self.state.mux_fut.poll_unpin(cx)?;
|
||||
|
||||
match self.state.tls_client.poll(cx)? {
|
||||
Poll::Ready(output) => {
|
||||
self.state.output = Some(output);
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a committed prover after the TLS session has completed.
|
||||
pub fn finish(self) -> Result<Prover<state::Committed>, ProverError> {
|
||||
let TlsOutput {
|
||||
ctx,
|
||||
vm,
|
||||
keys,
|
||||
tls_transcript,
|
||||
transcript,
|
||||
} = self.state.output.ok_or(ProverError::state(
|
||||
"prover has not yet closed the connection",
|
||||
))?;
|
||||
|
||||
let prover = Prover {
|
||||
config: self.config,
|
||||
span: self.span,
|
||||
state: state::Committed {
|
||||
mux_ctrl: self.state.mux_ctrl,
|
||||
mux_fut: self.state.mux_fut,
|
||||
ctx,
|
||||
vm,
|
||||
server_name: self.state.server_name,
|
||||
keys,
|
||||
tls_transcript,
|
||||
transcript,
|
||||
},
|
||||
};
|
||||
|
||||
Ok(prover)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -379,29 +440,3 @@ impl Prover<state::Committed> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// A controller for the prover.
|
||||
#[derive(Clone)]
|
||||
pub struct ProverControl {
|
||||
mpc_ctrl: LeaderCtrl,
|
||||
}
|
||||
|
||||
impl ProverControl {
|
||||
/// Defers decryption of data from the server until the server has closed
|
||||
/// the connection.
|
||||
///
|
||||
/// This is a performance optimization which will significantly reduce the
|
||||
/// amount of upload bandwidth used by the prover.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// * The prover may need to close the connection to the server in order for
|
||||
/// it to close the connection on its end. If neither the prover or server
|
||||
/// close the connection this will cause a deadlock.
|
||||
pub async fn defer_decryption(&self) -> Result<(), ProverError> {
|
||||
self.mpc_ctrl
|
||||
.defer_decryption()
|
||||
.await
|
||||
.map_err(ProverError::from)
|
||||
}
|
||||
}
|
||||
|
||||
63
crates/tlsn/src/prover/client.rs
Normal file
63
crates/tlsn/src/prover/client.rs
Normal file
@@ -0,0 +1,63 @@
|
||||
//! Provides a TLS client.
|
||||
|
||||
use crate::mpz::ProverZk;
|
||||
use mpc_tls::SessionKeys;
|
||||
use std::task::{Context, Poll};
|
||||
use tlsn_core::transcript::{TlsTranscript, Transcript};
|
||||
|
||||
mod mpc;
|
||||
|
||||
pub(crate) use mpc::MpcTlsClient;
|
||||
|
||||
/// TLS client for MPC and proxy-based TLS implementations.
|
||||
pub(crate) trait TlsClient {
|
||||
type Error: std::error::Error + Send + Sync + Unpin + 'static;
|
||||
|
||||
/// Returns `true` if the client wants to read TLS data from the server.
|
||||
fn wants_read_tls(&self) -> bool;
|
||||
|
||||
/// Returns `true` if the client wants to write TLS data to the server.
|
||||
fn wants_write_tls(&self) -> bool;
|
||||
|
||||
/// Reads TLS data from the server.
|
||||
fn read_tls(&mut self, buf: &[u8]) -> Result<usize, Self::Error>;
|
||||
|
||||
/// Writes TLS data for the server into the provided buffer.
|
||||
fn write_tls(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error>;
|
||||
|
||||
/// Returns `true` if the client wants to read plaintext data.
|
||||
fn wants_read(&self) -> bool;
|
||||
|
||||
/// Returns `true` if the client wants to write plaintext data.
|
||||
fn wants_write(&self) -> bool;
|
||||
|
||||
/// Reads plaintext data from the server into the provided buffer.
|
||||
fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error>;
|
||||
|
||||
/// Writes plaintext data to be sent to the server.
|
||||
fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error>;
|
||||
|
||||
/// Client closes the connection.
|
||||
fn client_close(&mut self) -> Result<(), Self::Error>;
|
||||
|
||||
/// Server closes the connection.
|
||||
fn server_close(&mut self) -> Result<(), Self::Error>;
|
||||
|
||||
/// Enables or disables decryption of TLS traffic sent by the server.
|
||||
fn enable_decryption(&mut self, enable: bool) -> Result<(), Self::Error>;
|
||||
|
||||
/// Returns `true` if decryption of TLS traffic from the server is active.
|
||||
fn is_decrypting(&self) -> bool;
|
||||
|
||||
/// Polls the client to make progress.
|
||||
fn poll(&mut self, cx: &mut Context) -> Poll<Result<TlsOutput, Self::Error>>;
|
||||
}
|
||||
|
||||
/// Output of a TLS session.
|
||||
pub(crate) struct TlsOutput {
|
||||
pub(crate) ctx: mpz_common::Context,
|
||||
pub(crate) vm: ProverZk,
|
||||
pub(crate) keys: SessionKeys,
|
||||
pub(crate) tls_transcript: TlsTranscript,
|
||||
pub(crate) transcript: Transcript,
|
||||
}
|
||||
477
crates/tlsn/src/prover/client/mpc.rs
Normal file
477
crates/tlsn/src/prover/client/mpc.rs
Normal file
@@ -0,0 +1,477 @@
|
||||
//! Implementation of an MPC-TLS client.
|
||||
|
||||
use crate::{
|
||||
mpz::{ProverMpc, ProverZk},
|
||||
prover::{
|
||||
ProverError,
|
||||
client::{TlsClient, TlsOutput},
|
||||
},
|
||||
tag::verify_tags,
|
||||
};
|
||||
use futures::{Future, FutureExt};
|
||||
use mpc_tls::{LeaderCtrl, SessionKeys};
|
||||
use mpz_common::Context;
|
||||
use mpz_vm_core::Execute;
|
||||
use std::{pin::Pin, sync::Arc, task::Poll};
|
||||
use tls_client::ClientConnection;
|
||||
use tlsn_core::transcript::TlsTranscript;
|
||||
use tlsn_deap::Deap;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::{Span, debug, instrument, trace, warn};
|
||||
|
||||
pub(crate) type MpcFuture = Box<dyn Future<Output = Result<(Context, TlsTranscript), ProverError>>>;
|
||||
|
||||
type FinalizeFuture =
|
||||
Box<dyn Future<Output = Result<(InnerState, Context, TlsTranscript), ProverError>>>;
|
||||
|
||||
pub(crate) struct MpcTlsClient {
|
||||
state: State,
|
||||
decrypt: bool,
|
||||
}
|
||||
|
||||
enum State {
|
||||
Start {
|
||||
mpc: Pin<MpcFuture>,
|
||||
inner: Box<InnerState>,
|
||||
},
|
||||
Active {
|
||||
mpc: Pin<MpcFuture>,
|
||||
inner: Box<InnerState>,
|
||||
},
|
||||
Busy {
|
||||
mpc: Pin<MpcFuture>,
|
||||
fut: Pin<Box<dyn Future<Output = Result<Box<InnerState>, ProverError>>>>,
|
||||
},
|
||||
ClientClose {
|
||||
mpc: Pin<MpcFuture>,
|
||||
fut: Pin<Box<dyn Future<Output = Result<Box<InnerState>, ProverError>>>>,
|
||||
},
|
||||
ServerClose {
|
||||
mpc: Pin<MpcFuture>,
|
||||
fut: Pin<Box<dyn Future<Output = Result<Box<InnerState>, ProverError>>>>,
|
||||
},
|
||||
Closing {
|
||||
ctx: Context,
|
||||
transcript: Box<TlsTranscript>,
|
||||
fut: Pin<Box<dyn Future<Output = Result<Box<InnerState>, ProverError>>>>,
|
||||
},
|
||||
Finalizing {
|
||||
fut: Pin<FinalizeFuture>,
|
||||
},
|
||||
Finished,
|
||||
Error,
|
||||
}
|
||||
|
||||
impl MpcTlsClient {
|
||||
pub(crate) fn new(
|
||||
mpc: MpcFuture,
|
||||
keys: SessionKeys,
|
||||
vm: Arc<Mutex<Deap<ProverMpc, ProverZk>>>,
|
||||
span: Span,
|
||||
mpc_ctrl: LeaderCtrl,
|
||||
tls: ClientConnection,
|
||||
decrypt: bool,
|
||||
) -> Self {
|
||||
let inner = InnerState {
|
||||
span,
|
||||
tls,
|
||||
vm,
|
||||
keys,
|
||||
mpc_ctrl,
|
||||
closed: false,
|
||||
};
|
||||
|
||||
Self {
|
||||
decrypt,
|
||||
state: State::Start {
|
||||
mpc: Box::into_pin(mpc),
|
||||
inner: Box::new(inner),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn inner_client_mut(&mut self) -> Option<&mut ClientConnection> {
|
||||
if let State::Active { inner, .. } = &mut self.state {
|
||||
Some(&mut inner.tls)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn inner_client(&self) -> Option<&ClientConnection> {
|
||||
if let State::Active { inner, .. } = &self.state {
|
||||
Some(&inner.tls)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TlsClient for MpcTlsClient {
|
||||
type Error = ProverError;
|
||||
|
||||
fn wants_read_tls(&self) -> bool {
|
||||
if let Some(client) = self.inner_client() {
|
||||
client.wants_read()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn wants_write_tls(&self) -> bool {
|
||||
if let Some(client) = self.inner_client() {
|
||||
client.wants_write()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn read_tls(&mut self, mut buf: &[u8]) -> Result<usize, Self::Error> {
|
||||
if let Some(client) = self.inner_client_mut()
|
||||
&& client.wants_read()
|
||||
{
|
||||
client.read_tls(&mut buf).map_err(ProverError::from)
|
||||
} else {
|
||||
Ok(0)
|
||||
}
|
||||
}
|
||||
|
||||
fn write_tls(&mut self, mut buf: &mut [u8]) -> Result<usize, Self::Error> {
|
||||
if let Some(client) = self.inner_client_mut()
|
||||
&& client.wants_write()
|
||||
{
|
||||
client.write_tls(&mut buf).map_err(ProverError::from)
|
||||
} else {
|
||||
Ok(0)
|
||||
}
|
||||
}
|
||||
|
||||
fn wants_read(&self) -> bool {
|
||||
if let Some(client) = self.inner_client() {
|
||||
!client.sendable_plaintext_is_full()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn wants_write(&self) -> bool {
|
||||
if let Some(client) = self.inner_client() {
|
||||
!client.plaintext_is_empty()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
|
||||
if let Some(client) = self.inner_client_mut()
|
||||
&& !client.sendable_plaintext_is_full()
|
||||
{
|
||||
client.read_plaintext(buf).map_err(ProverError::from)
|
||||
} else {
|
||||
Ok(0)
|
||||
}
|
||||
}
|
||||
|
||||
fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
|
||||
if let Some(client) = self.inner_client_mut()
|
||||
&& !client.plaintext_is_empty()
|
||||
{
|
||||
client.write_plaintext(buf).map_err(ProverError::from)
|
||||
} else {
|
||||
Ok(0)
|
||||
}
|
||||
}
|
||||
|
||||
fn client_close(&mut self) -> Result<(), Self::Error> {
|
||||
match std::mem::replace(&mut self.state, State::Error) {
|
||||
State::Active { inner, mpc } => {
|
||||
self.state = State::ClientClose {
|
||||
mpc,
|
||||
fut: Box::pin(inner.client_close()),
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
other => {
|
||||
self.state = other;
|
||||
Err(ProverError::state(
|
||||
"unable to close connection, client is not in active state",
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn server_close(&mut self) -> Result<(), Self::Error> {
|
||||
match std::mem::replace(&mut self.state, State::Error) {
|
||||
State::Active { inner, mpc } => {
|
||||
self.state = State::ServerClose {
|
||||
mpc,
|
||||
fut: Box::pin(inner.server_close()),
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
other => {
|
||||
self.state = other;
|
||||
Err(ProverError::state(
|
||||
"unable to close connection, client is not in active state",
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn enable_decryption(&mut self, enable: bool) -> Result<(), Self::Error> {
|
||||
match std::mem::replace(&mut self.state, State::Error) {
|
||||
State::Active { inner, mpc } => {
|
||||
self.decrypt = enable;
|
||||
self.state = State::Busy {
|
||||
mpc,
|
||||
fut: Box::pin(inner.set_decrypt(enable)),
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
other => {
|
||||
self.state = other;
|
||||
Err(ProverError::state(
|
||||
"unable to enable decryption, client is not in active state",
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn is_decrypting(&self) -> bool {
|
||||
self.decrypt
|
||||
}
|
||||
|
||||
fn poll(&mut self, cx: &mut std::task::Context) -> Poll<Result<TlsOutput, Self::Error>> {
|
||||
match std::mem::replace(&mut self.state, State::Error) {
|
||||
State::Start { mpc, inner } => {
|
||||
self.state = State::Busy {
|
||||
mpc,
|
||||
fut: Box::pin(inner.start()),
|
||||
};
|
||||
self.poll(cx)
|
||||
}
|
||||
State::Active { mpc, inner } => {
|
||||
trace!("inner client is active");
|
||||
|
||||
self.state = State::Busy {
|
||||
mpc,
|
||||
fut: Box::pin(inner.run()),
|
||||
};
|
||||
self.poll(cx)
|
||||
}
|
||||
State::Busy { mut mpc, mut fut } => {
|
||||
trace!("inner client is busy");
|
||||
|
||||
let mpc_poll = mpc.as_mut().poll(cx)?;
|
||||
|
||||
assert!(
|
||||
matches!(mpc_poll, Poll::Pending),
|
||||
"mpc future should not be finished here"
|
||||
);
|
||||
|
||||
match fut.as_mut().poll(cx)? {
|
||||
Poll::Ready(inner) => {
|
||||
self.state = State::Active { mpc, inner };
|
||||
}
|
||||
Poll::Pending => self.state = State::Busy { mpc, fut },
|
||||
}
|
||||
Poll::Pending
|
||||
}
|
||||
State::ClientClose { mut mpc, mut fut } => {
|
||||
debug!("attempting to close connection clientside");
|
||||
match (fut.poll_unpin(cx)?, mpc.poll_unpin(cx)?) {
|
||||
(Poll::Ready(inner), Poll::Ready((ctx, transcript))) => {
|
||||
self.state = State::Finalizing {
|
||||
fut: Box::pin(inner.finalize(ctx, transcript)),
|
||||
};
|
||||
}
|
||||
(Poll::Ready(inner), Poll::Pending) => {
|
||||
self.state = State::ClientClose {
|
||||
mpc,
|
||||
fut: Box::pin(inner.client_close()),
|
||||
};
|
||||
}
|
||||
(Poll::Pending, Poll::Ready((ctx, transcript))) => {
|
||||
self.state = State::Closing {
|
||||
ctx,
|
||||
transcript: Box::new(transcript),
|
||||
fut,
|
||||
};
|
||||
}
|
||||
(Poll::Pending, Poll::Pending) => self.state = State::ClientClose { mpc, fut },
|
||||
}
|
||||
self.poll(cx)
|
||||
}
|
||||
State::ServerClose { mut mpc, mut fut } => {
|
||||
debug!("attempting to close connection serverside");
|
||||
match (fut.poll_unpin(cx)?, mpc.poll_unpin(cx)?) {
|
||||
(Poll::Ready(inner), Poll::Ready((ctx, transcript))) => {
|
||||
self.state = State::Finalizing {
|
||||
fut: Box::pin(inner.finalize(ctx, transcript)),
|
||||
};
|
||||
}
|
||||
(Poll::Ready(inner), Poll::Pending) => {
|
||||
self.state = State::ServerClose {
|
||||
mpc,
|
||||
fut: Box::pin(inner.server_close()),
|
||||
};
|
||||
}
|
||||
(Poll::Pending, Poll::Ready((ctx, transcript))) => {
|
||||
self.state = State::Closing {
|
||||
ctx,
|
||||
transcript: Box::new(transcript),
|
||||
fut,
|
||||
};
|
||||
}
|
||||
(Poll::Pending, Poll::Pending) => self.state = State::ServerClose { mpc, fut },
|
||||
}
|
||||
self.poll(cx)
|
||||
}
|
||||
State::Closing {
|
||||
ctx,
|
||||
transcript,
|
||||
mut fut,
|
||||
} => {
|
||||
if let Poll::Ready(inner) = fut.poll_unpin(cx)? {
|
||||
self.state = State::Finalizing {
|
||||
fut: Box::pin(inner.finalize(ctx, *transcript)),
|
||||
};
|
||||
} else {
|
||||
self.state = State::Closing {
|
||||
ctx,
|
||||
transcript,
|
||||
fut,
|
||||
};
|
||||
}
|
||||
self.poll(cx)
|
||||
}
|
||||
State::Finalizing { mut fut } => match fut.poll_unpin(cx) {
|
||||
Poll::Ready(output) => {
|
||||
let (inner, ctx, tls_transcript) = output?;
|
||||
let InnerState { vm, keys, .. } = inner;
|
||||
|
||||
let transcript = tls_transcript
|
||||
.to_transcript()
|
||||
.expect("transcript is complete");
|
||||
|
||||
let (_, vm) = Arc::into_inner(vm)
|
||||
.expect("vm should have only 1 reference")
|
||||
.into_inner()
|
||||
.into_inner();
|
||||
|
||||
let output = TlsOutput {
|
||||
ctx,
|
||||
vm,
|
||||
keys,
|
||||
tls_transcript,
|
||||
transcript,
|
||||
};
|
||||
|
||||
self.state = State::Finished;
|
||||
Poll::Ready(Ok(output))
|
||||
}
|
||||
Poll::Pending => {
|
||||
self.state = State::Finalizing { fut };
|
||||
self.poll(cx)
|
||||
}
|
||||
},
|
||||
State::Finished => Poll::Ready(Err(ProverError::state(
|
||||
"mpc tls client polled again in finished state",
|
||||
))),
|
||||
State::Error => {
|
||||
Poll::Ready(Err(ProverError::state("mpc tls client is in error state")))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct InnerState {
|
||||
span: Span,
|
||||
tls: ClientConnection,
|
||||
vm: Arc<Mutex<Deap<ProverMpc, ProverZk>>>,
|
||||
keys: SessionKeys,
|
||||
mpc_ctrl: LeaderCtrl,
|
||||
closed: bool,
|
||||
}
|
||||
|
||||
impl InnerState {
|
||||
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
|
||||
async fn start(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
|
||||
self.tls.start().await?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
#[instrument(parent = &self.span, level = "trace", skip_all, err)]
|
||||
async fn run(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
|
||||
self.tls.process_new_packets().await?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
|
||||
async fn set_decrypt(self: Box<Self>, enable: bool) -> Result<Box<Self>, ProverError> {
|
||||
self.mpc_ctrl.enable_decryption(enable).await?;
|
||||
self.run().await
|
||||
}
|
||||
|
||||
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
|
||||
async fn client_close(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
|
||||
if self.tls.plaintext_is_empty() && self.tls.is_empty().await? && !self.closed {
|
||||
if let Err(e) = self.tls.send_close_notify().await {
|
||||
warn!("failed to send close_notify to server: {}", e);
|
||||
};
|
||||
|
||||
self.mpc_ctrl.stop().await?;
|
||||
self.closed = true;
|
||||
debug!("closed connection");
|
||||
}
|
||||
self.run().await
|
||||
}
|
||||
|
||||
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
|
||||
async fn server_close(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
|
||||
if self.tls.plaintext_is_empty() && self.tls.is_empty().await? && !self.closed {
|
||||
self.tls.server_closed().await?;
|
||||
|
||||
self.mpc_ctrl.stop().await?;
|
||||
self.closed = true;
|
||||
debug!("closed connection");
|
||||
}
|
||||
self.run().await
|
||||
}
|
||||
|
||||
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
|
||||
async fn finalize(
|
||||
self,
|
||||
mut ctx: Context,
|
||||
transcript: TlsTranscript,
|
||||
) -> Result<(Self, Context, TlsTranscript), ProverError> {
|
||||
{
|
||||
let mut vm = self.vm.try_lock().expect("VM should not be locked");
|
||||
|
||||
// Finalize DEAP.
|
||||
vm.finalize(&mut ctx).await.map_err(ProverError::mpc)?;
|
||||
|
||||
debug!("mpc finalized");
|
||||
|
||||
// Pull out ZK VM.
|
||||
let mut zk = vm.zk();
|
||||
|
||||
// Prove tag verification of received records.
|
||||
// The prover drops the proof output.
|
||||
let _ = verify_tags(
|
||||
&mut *zk,
|
||||
(self.keys.server_write_key, self.keys.server_write_iv),
|
||||
self.keys.server_write_mac_key,
|
||||
*transcript.version(),
|
||||
transcript.recv().to_vec(),
|
||||
)
|
||||
.map_err(ProverError::zk)?;
|
||||
debug!("verified tags from server");
|
||||
|
||||
zk.execute_all(&mut ctx).await.map_err(ProverError::zk)?
|
||||
}
|
||||
|
||||
debug!("MPC-TLS done");
|
||||
Ok((self, ctx, transcript))
|
||||
}
|
||||
}
|
||||
@@ -49,6 +49,13 @@ impl ProverError {
|
||||
{
|
||||
Self::new(ErrorKind::Commit, source)
|
||||
}
|
||||
|
||||
pub(crate) fn state<E>(source: E) -> Self
|
||||
where
|
||||
E: Into<Box<dyn Error + Send + Sync + 'static>>,
|
||||
{
|
||||
Self::new(ErrorKind::State, source)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -58,6 +65,7 @@ enum ErrorKind {
|
||||
Zk,
|
||||
Config,
|
||||
Commit,
|
||||
State,
|
||||
}
|
||||
|
||||
impl fmt::Display for ProverError {
|
||||
@@ -70,6 +78,7 @@ impl fmt::Display for ProverError {
|
||||
ErrorKind::Zk => f.write_str("zk error")?,
|
||||
ErrorKind::Config => f.write_str("config error")?,
|
||||
ErrorKind::Commit => f.write_str("commit error")?,
|
||||
ErrorKind::State => f.write_str("state error")?,
|
||||
}
|
||||
|
||||
if let Some(source) = &self.source {
|
||||
@@ -86,8 +95,8 @@ impl From<std::io::Error> for ProverError {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<tls_client_async::ConnectionError> for ProverError {
|
||||
fn from(e: tls_client_async::ConnectionError) -> Self {
|
||||
impl From<tls_client::Error> for ProverError {
|
||||
fn from(e: tls_client::Error) -> Self {
|
||||
Self::new(ErrorKind::Io, e)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
//! This module collects futures which are used by the [Prover].
|
||||
|
||||
use super::{Prover, ProverControl, ProverError, state};
|
||||
use futures::Future;
|
||||
use std::pin::Pin;
|
||||
|
||||
/// Prover future which must be polled for the TLS connection to make progress.
|
||||
pub struct ProverFuture {
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub(crate) fut: Pin<
|
||||
Box<dyn Future<Output = Result<Prover<state::Committed>, ProverError>> + Send + 'static>,
|
||||
>,
|
||||
pub(crate) ctrl: ProverControl,
|
||||
}
|
||||
|
||||
impl ProverFuture {
|
||||
/// Returns a controller for the prover for advanced functionality.
|
||||
pub fn control(&self) -> ProverControl {
|
||||
self.ctrl.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl Future for ProverFuture {
|
||||
type Output = Result<Prover<state::Committed>, ProverError>;
|
||||
|
||||
fn poll(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Self::Output> {
|
||||
self.fut.as_mut().poll(cx)
|
||||
}
|
||||
}
|
||||
@@ -14,6 +14,10 @@ use tokio::sync::Mutex;
|
||||
use crate::{
|
||||
mpz::{ProverMpc, ProverZk},
|
||||
mux::{MuxControl, MuxFuture},
|
||||
prover::{
|
||||
ProverError,
|
||||
client::{TlsClient, TlsOutput},
|
||||
},
|
||||
};
|
||||
|
||||
/// Entry state
|
||||
@@ -33,6 +37,17 @@ pub struct CommitAccepted {
|
||||
|
||||
opaque_debug::implement!(CommitAccepted);
|
||||
|
||||
/// State during the MPC-TLS connection.
|
||||
pub struct Connected {
|
||||
pub(crate) mux_ctrl: MuxControl,
|
||||
pub(crate) mux_fut: MuxFuture,
|
||||
pub(crate) server_name: ServerName,
|
||||
pub(crate) tls_client: Box<dyn TlsClient<Error = ProverError>>,
|
||||
pub(crate) output: Option<TlsOutput>,
|
||||
}
|
||||
|
||||
opaque_debug::implement!(Connected);
|
||||
|
||||
/// State after the TLS transcript has been committed.
|
||||
pub struct Committed {
|
||||
pub(crate) mux_ctrl: MuxControl,
|
||||
@@ -52,11 +67,13 @@ pub trait ProverState: sealed::Sealed {}
|
||||
|
||||
impl ProverState for Initialized {}
|
||||
impl ProverState for CommitAccepted {}
|
||||
impl ProverState for Connected {}
|
||||
impl ProverState for Committed {}
|
||||
|
||||
mod sealed {
|
||||
pub trait Sealed {}
|
||||
impl Sealed for super::Initialized {}
|
||||
impl Sealed for super::CommitAccepted {}
|
||||
impl Sealed for super::Connected {}
|
||||
impl Sealed for super::Committed {}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user