Compare commits

..

2 Commits

Author SHA1 Message Date
sinu
a780ac1136 wip 2026-01-13 11:06:12 -08:00
sinu
fc719c960f wip 2026-01-13 10:12:05 -08:00
16 changed files with 931 additions and 818 deletions

728
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -67,21 +67,21 @@ tlsn-harness-runner = { path = "crates/harness/runner" }
tlsn-wasm = { path = "crates/wasm" }
tlsn = { path = "crates/tlsn" }
mpz-circuits = { git = "https://github.com/privacy-ethereum/mpz", rev = "bc1d4ad" }
mpz-circuits-data = { git = "https://github.com/privacy-ethereum/mpz", rev = "bc1d4ad" }
mpz-memory-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "bc1d4ad" }
mpz-common = { git = "https://github.com/privacy-ethereum/mpz", rev = "bc1d4ad" }
mpz-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "bc1d4ad" }
mpz-vm-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "bc1d4ad" }
mpz-garble = { git = "https://github.com/privacy-ethereum/mpz", rev = "bc1d4ad" }
mpz-garble-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "bc1d4ad" }
mpz-ole = { git = "https://github.com/privacy-ethereum/mpz", rev = "bc1d4ad" }
mpz-ot = { git = "https://github.com/privacy-ethereum/mpz", rev = "bc1d4ad" }
mpz-share-conversion = { git = "https://github.com/privacy-ethereum/mpz", rev = "bc1d4ad" }
mpz-fields = { git = "https://github.com/privacy-ethereum/mpz", rev = "bc1d4ad" }
mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", rev = "bc1d4ad" }
mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", rev = "bc1d4ad" }
mpz-ideal-vm = { git = "https://github.com/privacy-ethereum/mpz", rev = "bc1d4ad" }
mpz-circuits = { git = "https://github.com/privacy-ethereum/mpz", rev = "d9baf0f" }
mpz-circuits-data = { git = "https://github.com/privacy-ethereum/mpz", rev = "d9baf0f" }
mpz-memory-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "d9baf0f" }
mpz-common = { git = "https://github.com/privacy-ethereum/mpz", rev = "d9baf0f" }
mpz-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "d9baf0f" }
mpz-vm-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "d9baf0f" }
mpz-garble = { git = "https://github.com/privacy-ethereum/mpz", rev = "d9baf0f" }
mpz-garble-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "d9baf0f" }
mpz-ole = { git = "https://github.com/privacy-ethereum/mpz", rev = "d9baf0f" }
mpz-ot = { git = "https://github.com/privacy-ethereum/mpz", rev = "d9baf0f" }
mpz-share-conversion = { git = "https://github.com/privacy-ethereum/mpz", rev = "d9baf0f" }
mpz-fields = { git = "https://github.com/privacy-ethereum/mpz", rev = "d9baf0f" }
mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", rev = "d9baf0f" }
mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", rev = "d9baf0f" }
mpz-ideal-vm = { git = "https://github.com/privacy-ethereum/mpz", rev = "d9baf0f" }
rangeset = { version = "0.4" }
serio = { version = "0.2" }

View File

@@ -123,8 +123,8 @@ fn build_pair(config: Config) -> (MpcTlsLeader, MpcTlsFollower) {
let (mut mt_a, mut mt_b) = test_mt_context(8);
let ctx_a = futures::executor::block_on(mt_a.new_context()).unwrap();
let ctx_b = futures::executor::block_on(mt_b.new_context()).unwrap();
let ctx_a = mt_a.new_context().unwrap();
let ctx_b = mt_b.new_context().unwrap();
let delta_a = Delta::new(Block::random(&mut rng));
let delta_b = Delta::new(Block::random(&mut rng));

View File

@@ -1,21 +0,0 @@
//! Execution context.
use mpz_common::context::Multithread;
use crate::mux::MuxControl;
/// Maximum concurrency for multi-threaded context.
pub(crate) const MAX_CONCURRENCY: usize = 8;
/// Builds a multi-threaded context with the given muxer.
pub(crate) fn build_mt_context(mux: MuxControl) -> Multithread {
let builder = Multithread::builder().mux(mux).concurrency(MAX_CONCURRENCY);
#[cfg(all(feature = "web", target_arch = "wasm32"))]
let builder = builder.spawn_handler(|f| {
let _ = web_spawn::spawn(f);
Ok(())
});
builder.build().unwrap()
}

87
crates/tlsn/src/error.rs Normal file
View File

@@ -0,0 +1,87 @@
use std::fmt::Display;
/// Crate-level error.
#[derive(Debug, thiserror::Error)]
pub struct Error {
kind: ErrorKind,
msg: Option<String>,
source: Option<Box<dyn std::error::Error + Send + Sync>>,
}
impl Error {
pub(crate) fn io() -> Self {
Self {
kind: ErrorKind::Internal,
msg: None,
source: None,
}
}
pub(crate) fn internal() -> Self {
Self {
kind: ErrorKind::Internal,
msg: None,
source: None,
}
}
pub(crate) fn with_msg(mut self, msg: impl Into<String>) -> Self {
self.msg = Some(msg.into());
self
}
pub(crate) fn with_source<T>(mut self, source: T) -> Self
where
T: Into<Box<dyn std::error::Error + Send + Sync>>,
{
self.source = Some(source.into());
self
}
/// Returns `true` if the error was user created.
pub fn is_user(&self) -> bool {
todo!()
}
/// Returns `true` if the error originated from an IO error.
pub fn is_io(&self) -> bool {
self.kind.is_io()
}
/// Returns `true` if the error originated from an internal bug.
pub fn is_internal(&self) -> bool {
self.kind.is_internal()
}
/// Returns the error message if available.
pub fn msg(&self) -> Option<&str> {
todo!()
}
}
impl Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
todo!()
}
}
#[derive(Debug)]
enum ErrorKind {
User,
Io,
Internal,
}
impl ErrorKind {
fn is_user(&self) -> bool {
matches!(self, ErrorKind::User)
}
fn is_io(&self) -> bool {
matches!(self, ErrorKind::Io)
}
fn is_internal(&self) -> bool {
matches!(self, ErrorKind::Internal)
}
}

View File

@@ -4,20 +4,25 @@
#![deny(clippy::all)]
#![forbid(unsafe_code)]
pub(crate) mod context;
mod error;
pub(crate) mod ghash;
pub(crate) mod map;
pub(crate) mod mpz;
pub(crate) mod msg;
pub(crate) mod mux;
pub mod prover;
mod session;
pub(crate) mod tag;
pub(crate) mod transcript_internal;
pub mod verifier;
pub use error::Error;
pub use session::Session;
pub use tlsn_attestation as attestation;
pub use tlsn_core::{config, connection, hash, transcript, webpki};
/// Result type.
pub type Result<T, E = Error> = core::result::Result<T, E>;
use std::sync::LazyLock;
use semver::Version;

View File

@@ -1,107 +0,0 @@
//! Multiplexer used in the TLSNotary protocol.
use futures::{
AsyncRead, AsyncWrite, Future,
future::{FusedFuture, FutureExt},
};
use mpz_common::{ThreadId, io::Io, mux::Mux};
use tlsn_mux::{Connection, Handle};
use tracing::error;
/// Multiplexer controller providing streams.
pub(crate) struct MuxControl {
handle: Handle,
}
impl Mux for MuxControl {
fn open(&self, id: ThreadId) -> Result<Io, std::io::Error> {
let stream = self
.handle
.new_stream(id.as_ref())
.map_err(std::io::Error::other)?;
let io = Io::from_io(stream);
Ok(io)
}
}
impl From<MuxControl> for Box<dyn Mux + Send> {
fn from(val: MuxControl) -> Self {
Box::new(val)
}
}
/// Multiplexer future which must be polled for the muxer to make progress.
#[derive(Debug)]
pub(crate) struct MuxFuture<T> {
conn: Connection<T>,
}
impl<T: AsyncRead + AsyncWrite + Unpin> MuxFuture<T> {
pub(crate) fn new(socket: T) -> Self {
let mut mux_config = tlsn_mux::Config::default();
mux_config.set_max_num_streams(36);
mux_config.set_keep_alive(true);
mux_config.set_close_sync(true);
let conn = tlsn_mux::Connection::new(socket, mux_config);
Self { conn }
}
pub(crate) fn handle(&self) -> Result<MuxControl, std::io::Error> {
let handle = self.conn.handle().map_err(std::io::Error::other)?;
Ok(MuxControl { handle })
}
pub(crate) fn close(&mut self) {
self.conn.close();
}
pub(crate) fn into_io(self) -> Result<T, std::io::Error> {
self.conn
.try_into_io()
.map_err(|_| std::io::Error::other("unable to return IO, connection is not closed"))
}
}
impl<T: AsyncRead + AsyncWrite + Unpin> FusedFuture for MuxFuture<T> {
fn is_terminated(&self) -> bool {
self.conn.is_complete()
}
}
impl<T: AsyncRead + AsyncWrite + Unpin> MuxFuture<T> {
/// Awaits a future, polling the muxer future concurrently.
pub(crate) async fn poll_with<F, R>(&mut self, fut: F) -> R
where
F: Future<Output = R>,
{
let mut fut = Box::pin(fut.fuse());
let mut mux = self;
// Poll the future concurrently with the muxer future.
// If the muxer returns an error, continue polling the future
// until it completes.
loop {
futures::select! {
res = fut => return res,
res = mux => if let Err(e) = res {
error!("mux error: {:?}", e);
},
}
}
}
}
impl<T: AsyncRead + AsyncWrite + Unpin> Future for MuxFuture<T> {
type Output = Result<(), tlsn_mux::ConnectionError>;
fn poll(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
self.conn.poll(cx)
}
}

View File

@@ -7,13 +7,13 @@ pub mod state;
pub use error::ProverError;
pub use future::ProverFuture;
use mpz_common::Context;
pub use tlsn_core::ProverOutput;
use crate::{
context::build_mt_context,
mpz::{ProverDeps, build_prover_deps, translate_keys},
msg::{ProveRequestMsg, Response, TlsCommitRequestMsg},
mux::MuxFuture,
prover::error::ErrorKind,
tag::verify_tags,
};
@@ -44,6 +44,7 @@ use tracing::{Instrument, Span, debug, info, info_span, instrument};
pub struct Prover<T: state::ProverState = state::Initialized> {
config: ProverConfig,
span: Span,
ctx: Option<Context>,
state: T,
}
@@ -52,12 +53,14 @@ impl Prover<state::Initialized> {
///
/// # Arguments
///
/// * `ctx` - A thread context.
/// * `config` - The configuration for the prover.
pub fn new(config: ProverConfig) -> Self {
pub(crate) fn new(ctx: Context, config: ProverConfig) -> Self {
let span = info_span!("prover");
Self {
config,
span,
ctx: Some(ctx),
state: state::Initialized,
}
}
@@ -70,37 +73,30 @@ impl Prover<state::Initialized> {
/// # Arguments
///
/// * `config` - The TLS commitment configuration.
/// * `socket` - The socket to the TLS verifier.
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
pub async fn commit<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
self,
pub async fn commit(
mut self,
config: TlsCommitConfig,
socket: S,
) -> Result<Prover<state::CommitAccepted<S>>, ProverError> {
let mut mux_fut = MuxFuture::new(socket);
let mux_ctrl = mux_fut.handle()?;
let mut mt = build_mt_context(mux_ctrl);
let mut ctx = mux_fut.poll_with(mt.new_context()).await?;
) -> Result<Prover<state::CommitAccepted>, ProverError> {
let mut ctx = self
.ctx
.take()
.ok_or_else(|| ProverError::new(ErrorKind::Io, "context was dropped"))?;
// Sends protocol configuration to verifier for compatibility check.
mux_fut
.poll_with(async {
ctx.io_mut()
.send(TlsCommitRequestMsg {
request: config.to_request(),
version: crate::VERSION.clone(),
})
.await?;
ctx.io_mut()
.expect_next::<Response>()
.await?
.result
.map_err(ProverError::from)
ctx.io_mut()
.send(TlsCommitRequestMsg {
request: config.to_request(),
version: crate::VERSION.clone(),
})
.await?;
ctx.io_mut()
.expect_next::<Response>()
.await?
.result
.map_err(ProverError::from)?;
let TlsCommitProtocolConfig::Mpc(mpc_tls_config) = config.protocol().clone() else {
unreachable!("only MPC TLS is supported");
};
@@ -115,27 +111,20 @@ impl Prover<state::Initialized> {
debug!("setting up mpc-tls");
mux_fut.poll_with(mpc_tls.preprocess()).await?;
mpc_tls.preprocess().await?;
debug!("mpc-tls setup complete");
Ok(Prover {
config: self.config,
span: self.span,
state: state::CommitAccepted {
mux_fut,
mpc_tls,
keys,
vm,
},
ctx: None,
state: state::CommitAccepted { mpc_tls, keys, vm },
})
}
}
impl<Io> Prover<state::CommitAccepted<Io>>
where
Io: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
impl Prover<state::CommitAccepted> {
/// Connects to the server using the provided socket.
///
/// Returns a handle to the TLS connection, a future which returns the
@@ -151,13 +140,9 @@ where
self,
config: TlsClientConfig,
socket: S,
) -> Result<(TlsConnection, ProverFuture<Io>), ProverError> {
) -> Result<(TlsConnection, ProverFuture), ProverError> {
let state::CommitAccepted {
mut mux_fut,
mpc_tls,
keys,
vm,
..
mpc_tls, keys, vm, ..
} = self.state;
let (mpc_ctrl, mpc_fut) = mpc_tls.run();
@@ -211,10 +196,7 @@ where
let mpc_ctrl = mpc_ctrl.clone();
async move {
let conn_fut = async {
mux_fut
.poll_with(conn_fut.map_err(ProverError::from))
.await?;
conn_fut.await.map_err(ProverError::from)?;
mpc_ctrl.stop().await?;
Ok::<_, ProverError>(())
@@ -235,10 +217,7 @@ where
debug!("finalizing mpc");
// Finalize DEAP.
mux_fut
.poll_with(vm.finalize(&mut ctx))
.await
.map_err(ProverError::mpc)?;
vm.finalize(&mut ctx).await.map_err(ProverError::mpc)?;
debug!("mpc finalized");
}
@@ -260,9 +239,7 @@ where
)
.map_err(ProverError::zk)?;
mux_fut
.poll_with(vm.execute_all(&mut ctx).map_err(ProverError::zk))
.await?;
vm.execute_all(&mut ctx).await.map_err(ProverError::zk)?;
let transcript = tls_transcript
.to_transcript()
@@ -271,9 +248,8 @@ where
Ok(Prover {
config: self.config,
span: self.span,
ctx: Some(ctx),
state: state::Committed {
mux_fut,
ctx,
vm,
server_name: config.server_name().clone(),
keys,
@@ -295,10 +271,7 @@ where
}
}
impl<Io> Prover<state::Committed<Io>>
where
Io: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
impl Prover<state::Committed> {
/// Returns the TLS transcript.
pub fn tls_transcript(&self) -> &TlsTranscript {
&self.state.tls_transcript
@@ -316,9 +289,11 @@ where
/// * `config` - The disclosure configuration.
#[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn prove(&mut self, config: &ProveConfig) -> Result<ProverOutput, ProverError> {
let ctx = self
.ctx
.as_mut()
.ok_or_else(|| ProverError::new(ErrorKind::Io, "context was dropped"))?;
let state::Committed {
mux_fut,
ctx,
vm,
keys,
server_name,
@@ -354,27 +329,18 @@ where
transcript: partial_transcript,
};
let output = mux_fut
.poll_with(async {
ctx.io_mut().send(msg).await.map_err(ProverError::from)?;
ctx.io_mut().send(msg).await.map_err(ProverError::from)?;
ctx.io_mut().expect_next::<Response>().await?.result?;
ctx.io_mut().expect_next::<Response>().await?.result?;
prove::prove(ctx, vm, keys, transcript, tls_transcript, config).await
})
.await?;
let output = prove::prove(ctx, vm, keys, transcript, tls_transcript, config).await?;
Ok(output)
}
/// Closes the connection with the verifier.
#[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn close(mut self) -> Result<Io, ProverError> {
let mux_fut = &mut self.state.mux_fut;
mux_fut.close();
mux_fut.await?;
self.state.mux_fut.into_io().map_err(ProverError::from)
pub async fn close(self) -> Result<(), ProverError> {
Ok(())
}
}

View File

@@ -10,7 +10,7 @@ pub struct ProverError {
}
impl ProverError {
fn new<E>(kind: ErrorKind, source: E) -> Self
pub(crate) fn new<E>(kind: ErrorKind, source: E) -> Self
where
E: Into<Box<dyn Error + Send + Sync + 'static>>,
{
@@ -50,7 +50,7 @@ impl ProverError {
}
#[derive(Debug)]
enum ErrorKind {
pub(crate) enum ErrorKind {
Io,
Mpc,
Zk,

View File

@@ -5,25 +5,23 @@ use futures::Future;
use std::pin::Pin;
/// Prover future which must be polled for the TLS connection to make progress.
pub struct ProverFuture<Io> {
pub struct ProverFuture {
#[allow(clippy::type_complexity)]
pub(crate) fut: Pin<
Box<
dyn Future<Output = Result<Prover<state::Committed<Io>>, ProverError>> + Send + 'static,
>,
Box<dyn Future<Output = Result<Prover<state::Committed>, ProverError>> + Send + 'static>,
>,
pub(crate) ctrl: ProverControl,
}
impl<Io> ProverFuture<Io> {
impl ProverFuture {
/// Returns a controller for the prover for advanced functionality.
pub fn control(&self) -> ProverControl {
self.ctrl.clone()
}
}
impl<Io> Future for ProverFuture<Io> {
type Output = Result<Prover<state::Committed<Io>>, ProverError>;
impl Future for ProverFuture {
type Output = Result<Prover<state::Committed>, ProverError>;
fn poll(
mut self: Pin<&mut Self>,

View File

@@ -3,7 +3,6 @@
use std::sync::Arc;
use mpc_tls::{MpcTlsLeader, SessionKeys};
use mpz_common::Context;
use tlsn_core::{
connection::ServerName,
transcript::{TlsTranscript, Transcript},
@@ -11,10 +10,7 @@ use tlsn_core::{
use tlsn_deap::Deap;
use tokio::sync::Mutex;
use crate::{
mpz::{ProverMpc, ProverZk},
mux::MuxFuture,
};
use crate::mpz::{ProverMpc, ProverZk};
/// Entry state
pub struct Initialized;
@@ -23,19 +19,16 @@ opaque_debug::implement!(Initialized);
/// State after the verifier has accepted the proposed TLS commitment protocol
/// configuration and preprocessing has completed.
pub struct CommitAccepted<Io> {
pub(crate) mux_fut: MuxFuture<Io>,
pub struct CommitAccepted {
pub(crate) mpc_tls: MpcTlsLeader,
pub(crate) keys: SessionKeys,
pub(crate) vm: Arc<Mutex<Deap<ProverMpc, ProverZk>>>,
}
opaque_debug::implement!(CommitAccepted<Io>);
opaque_debug::implement!(CommitAccepted);
/// State after the TLS transcript has been committed.
pub struct Committed<Io> {
pub(crate) mux_fut: MuxFuture<Io>,
pub(crate) ctx: Context,
pub struct Committed {
pub(crate) vm: ProverZk,
pub(crate) server_name: ServerName,
pub(crate) keys: SessionKeys,
@@ -43,18 +36,18 @@ pub struct Committed<Io> {
pub(crate) transcript: Transcript,
}
opaque_debug::implement!(Committed<Io>);
opaque_debug::implement!(Committed);
#[allow(missing_docs)]
pub trait ProverState: sealed::Sealed {}
impl ProverState for Initialized {}
impl<Io> ProverState for CommitAccepted<Io> {}
impl<Io> ProverState for Committed<Io> {}
impl ProverState for CommitAccepted {}
impl ProverState for Committed {}
mod sealed {
pub trait Sealed {}
impl Sealed for super::Initialized {}
impl<Io> Sealed for super::CommitAccepted<Io> {}
impl<Io> Sealed for super::Committed<Io> {}
impl Sealed for super::CommitAccepted {}
impl Sealed for super::Committed {}
}

293
crates/tlsn/src/session.rs Normal file
View File

@@ -0,0 +1,293 @@
use std::{
future::Future,
pin::Pin,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
task::{Context, Poll},
};
use futures::{AsyncRead, AsyncWrite};
use mpz_common::{ThreadId, context::Multithread, io::Io, mux::Mux};
use tlsn_core::config::{prover::ProverConfig, verifier::VerifierConfig};
use tlsn_mux::{Connection, Handle};
use crate::{
Error, Result,
prover::{Prover, state as prover_state},
verifier::{Verifier, state as verifier_state},
};
/// Maximum concurrency for multi-threaded context.
const MAX_CONCURRENCY: usize = 8;
/// Session state.
#[must_use = "session must be polled continuously to make progress, including during closing."]
pub struct Session<Io> {
conn: Option<Connection<Io>>,
mt: Multithread,
}
impl<Io> Session<Io>
where
Io: AsyncRead + AsyncWrite + Unpin,
{
/// Creates a new session.
pub fn new(io: Io) -> Self {
let mut mux_config = tlsn_mux::Config::default();
mux_config.set_max_num_streams(36);
mux_config.set_keep_alive(true);
mux_config.set_close_sync(true);
let conn = tlsn_mux::Connection::new(io, mux_config);
let handle = conn.handle().expect("handle should be available");
let mt = build_mt_context(MuxHandle { handle: handle });
Self {
conn: Some(conn),
mt,
}
}
/// Creates a new prover.
pub fn new_prover(
&mut self,
config: ProverConfig,
) -> Result<Prover<prover_state::Initialized>> {
let ctx = self.mt.new_context().map_err(|e| {
Error::internal()
.with_msg("failed to created new prover")
.with_source(e)
})?;
Ok(Prover::new(ctx, config))
}
/// Creates a new verifier.
pub fn new_verifier(
&mut self,
config: VerifierConfig,
) -> Result<Verifier<verifier_state::Initialized>> {
let ctx = self.mt.new_context().map_err(|e| {
Error::internal()
.with_msg("failed to created new verifier")
.with_source(e)
})?;
Ok(Verifier::new(ctx, config))
}
/// Returns `true` if the session is closed.
pub fn is_closed(&self) -> bool {
self.conn
.as_ref()
.map(|mux| mux.is_complete())
.unwrap_or_default()
}
/// Closes the session.
///
/// This will cause the session to begin closing. Session must continue to be polled until completion.
pub fn close(&mut self) {
self.conn.as_mut().map(|conn| conn.close());
}
/// Attempts to take the IO, returning an error if it is not available.
pub fn try_take(&mut self) -> Result<Io> {
let conn = self.conn.take().ok_or_else(|| {
Error::io().with_msg("failed to take the session io, it was already taken")
})?;
match conn.try_into_io() {
Err(conn) => {
self.conn = Some(conn);
Err(Error::io()
.with_msg("failed to take the session io, session was not completed yet"))
}
Ok(conn) => Ok(conn),
}
}
/// Polls the session.
pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
self.conn
.as_mut()
.ok_or_else(|| {
Error::io()
.with_msg("failed to poll the session connection because it has been taken")
})?
.poll(cx)
.map_err(|e| {
Error::io()
.with_msg("error occurred while polling the session connection")
.with_source(e)
})
}
/// Splits the session into a driver and handle.
///
/// The driver must be polled to make progress. The handle is used
/// for creating provers/verifiers and closing the session.
pub fn split(self) -> (SessionDriver<Io>, SessionHandle) {
let should_close = Arc::new(AtomicBool::new(false));
(
SessionDriver {
conn: self.conn,
should_close: should_close.clone(),
},
SessionHandle {
mt: self.mt,
should_close,
},
)
}
}
impl<Io> Future for Session<Io>
where
Io: AsyncRead + AsyncWrite + Unpin,
{
type Output = Result<()>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Session::poll(&mut (*self), cx)
}
}
/// The polling half of a split session.
///
/// Must be polled continuously to drive the session. Returns the underlying
/// IO when the session closes.
#[must_use = "driver must be polled to make progress"]
pub struct SessionDriver<Io> {
conn: Option<Connection<Io>>,
should_close: Arc<AtomicBool>,
}
impl<Io> SessionDriver<Io>
where
Io: AsyncRead + AsyncWrite + Unpin,
{
/// Polls the driver.
pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<Io>> {
let conn = self.conn.as_mut().ok_or_else(|| {
Error::io().with_msg("session driver already completed")
})?;
if self.should_close.load(Ordering::Acquire) {
conn.close();
}
match conn.poll(cx) {
Poll::Ready(Ok(())) => {}
Poll::Ready(Err(e)) => {
return Poll::Ready(Err(Error::io()
.with_msg("error polling session connection")
.with_source(e)));
}
Poll::Pending => return Poll::Pending,
}
let conn = self.conn.take().unwrap();
Poll::Ready(conn.try_into_io().map_err(|_| {
Error::io().with_msg("failed to take session io")
}))
}
}
impl<Io> Future for SessionDriver<Io>
where
Io: AsyncRead + AsyncWrite + Unpin,
{
type Output = Result<Io>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
SessionDriver::poll(&mut *self, cx)
}
}
/// The control half of a split session.
///
/// Used to create provers/verifiers and control the session lifecycle.
pub struct SessionHandle {
mt: Multithread,
should_close: Arc<AtomicBool>,
}
impl SessionHandle {
/// Creates a new prover.
pub fn new_prover(
&mut self,
config: ProverConfig,
) -> Result<Prover<prover_state::Initialized>> {
let ctx = self.mt.new_context().map_err(|e| {
Error::internal()
.with_msg("failed to create new prover")
.with_source(e)
})?;
Ok(Prover::new(ctx, config))
}
/// Creates a new verifier.
pub fn new_verifier(
&mut self,
config: VerifierConfig,
) -> Result<Verifier<verifier_state::Initialized>> {
let ctx = self.mt.new_context().map_err(|e| {
Error::internal()
.with_msg("failed to create new verifier")
.with_source(e)
})?;
Ok(Verifier::new(ctx, config))
}
/// Signals the session to close.
///
/// The driver must continue to be polled until it completes.
pub fn close(&self) {
self.should_close.store(true, Ordering::Release);
}
}
/// Multiplexer controller providing streams.
struct MuxHandle {
handle: Handle,
}
impl std::fmt::Debug for MuxHandle {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("MuxHandle").finish_non_exhaustive()
}
}
impl Mux for MuxHandle {
fn open(&self, id: ThreadId) -> Result<Io, std::io::Error> {
let stream = self
.handle
.new_stream(id.as_ref())
.map_err(std::io::Error::other)?;
let io = Io::from_io(stream);
Ok(io)
}
}
/// Builds a multi-threaded context with the given muxer.
fn build_mt_context(mux: MuxHandle) -> Multithread {
let builder = Multithread::builder()
.mux(Box::new(mux) as Box<_>)
.concurrency(MAX_CONCURRENCY);
#[cfg(all(feature = "web", target_arch = "wasm32"))]
let builder = builder.spawn_handler(|f| {
let _ = web_spawn::spawn(f);
Ok(())
});
builder.build().unwrap()
}

View File

@@ -7,16 +7,16 @@ mod verify;
use std::sync::Arc;
pub use error::VerifierError;
use mpz_common::Context;
pub use tlsn_core::{VerifierOutput, webpki::ServerCertVerifier};
use crate::{
context::build_mt_context,
mpz::{VerifierDeps, build_verifier_deps, translate_keys},
msg::{ProveRequestMsg, Response, TlsCommitRequestMsg},
mux::MuxFuture,
tag::verify_tags,
verifier::error::ErrorKind,
};
use futures::{AsyncRead, AsyncWrite, TryFutureExt};
use futures::TryFutureExt;
use mpz_vm_core::prelude::*;
use serio::{SinkExt, stream::IoStreamExt};
use tlsn_core::{
@@ -44,16 +44,18 @@ pub struct SessionInfo {
pub struct Verifier<T: state::VerifierState = state::Initialized> {
config: VerifierConfig,
span: Span,
ctx: Option<Context>,
state: T,
}
impl Verifier<state::Initialized> {
/// Creates a new verifier.
pub fn new(config: VerifierConfig) -> Self {
pub(crate) fn new(ctx: Context, config: VerifierConfig) -> Self {
let span = info_span!("verifier");
Self {
config,
span,
ctx: Some(ctx),
state: state::Initialized,
}
}
@@ -62,36 +64,22 @@ impl Verifier<state::Initialized> {
///
/// This initiates the TLS commitment protocol, receiving the prover's
/// configuration and providing the opportunity to accept or reject it.
///
/// # Arguments
///
/// * `socket` - The socket to the prover.
#[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn commit<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
self,
socket: S,
) -> Result<Verifier<state::CommitStart<S>>, VerifierError> {
let mut mux_fut = MuxFuture::new(socket);
let mux_ctrl = mux_fut.handle()?;
let mut mt = build_mt_context(mux_ctrl);
let mut ctx = mux_fut.poll_with(mt.new_context()).await?;
pub async fn commit(mut self) -> Result<Verifier<state::CommitStart>, VerifierError> {
let mut ctx = self
.ctx
.take()
.ok_or_else(|| VerifierError::new(ErrorKind::Io, "context was dropped"))?;
// Receives protocol configuration from prover to perform compatibility check.
let TlsCommitRequestMsg { request, version } =
mux_fut.poll_with(ctx.io_mut().expect_next()).await?;
let TlsCommitRequestMsg { request, version } = ctx.io_mut().expect_next().await?;
if version != *crate::VERSION {
let msg = format!(
"prover version does not match with verifier: {version} != {}",
*crate::VERSION
);
mux_fut
.poll_with(ctx.io_mut().send(Response::err(Some(msg.clone()))))
.await?;
mux_fut.close();
mux_fut.await?;
ctx.io_mut().send(Response::err(Some(msg.clone()))).await?;
return Err(VerifierError::config(msg));
}
@@ -99,19 +87,13 @@ impl Verifier<state::Initialized> {
Ok(Verifier {
config: self.config,
span: self.span,
state: state::CommitStart {
mux_fut,
ctx,
request,
},
ctx: Some(ctx),
state: state::CommitStart { request },
})
}
}
impl<Io> Verifier<state::CommitStart<Io>>
where
Io: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
impl Verifier<state::CommitStart> {
/// Returns the TLS commitment request.
pub fn request(&self) -> &TlsCommitRequest {
&self.state.request
@@ -119,14 +101,14 @@ where
/// Accepts the proposed protocol configuration.
#[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn accept(self) -> Result<Verifier<state::CommitAccepted<Io>>, VerifierError> {
let state::CommitStart {
mut mux_fut,
mut ctx,
request,
} = self.state;
pub async fn accept(mut self) -> Result<Verifier<state::CommitAccepted>, VerifierError> {
let mut ctx = self
.ctx
.take()
.ok_or_else(|| VerifierError::new(ErrorKind::Io, "context was dropped"))?;
let state::CommitStart { request } = self.state;
mux_fut.poll_with(ctx.io_mut().send(Response::ok())).await?;
ctx.io_mut().send(Response::ok()).await?;
let TlsCommitProtocolConfig::Mpc(mpc_tls_config) = request.protocol().clone() else {
unreachable!("only MPC TLS is supported");
@@ -142,59 +124,41 @@ where
debug!("setting up mpc-tls");
mux_fut.poll_with(mpc_tls.preprocess()).await?;
mpc_tls.preprocess().await?;
debug!("mpc-tls setup complete");
Ok(Verifier {
config: self.config,
span: self.span,
state: state::CommitAccepted {
mux_fut,
mpc_tls,
keys,
vm,
},
ctx: None,
state: state::CommitAccepted { mpc_tls, keys, vm },
})
}
/// Rejects the proposed protocol configuration.
#[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn reject(self, msg: Option<&str>) -> Result<(), VerifierError> {
let state::CommitStart {
mut mux_fut,
mut ctx,
..
} = self.state;
pub async fn reject(mut self, msg: Option<&str>) -> Result<(), VerifierError> {
let mut ctx = self
.ctx
.take()
.ok_or_else(|| VerifierError::new(ErrorKind::Io, "context was dropped"))?;
mux_fut
.poll_with(ctx.io_mut().send(Response::err(msg)))
.await?;
mux_fut.close();
mux_fut.await?;
ctx.io_mut().send(Response::err(msg)).await?;
Ok(())
}
}
impl<Io> Verifier<state::CommitAccepted<Io>>
where
Io: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
impl Verifier<state::CommitAccepted> {
/// Runs the verifier until the TLS connection is closed.
#[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn run(self) -> Result<Verifier<state::Committed<Io>>, VerifierError> {
let state::CommitAccepted {
mut mux_fut,
mpc_tls,
vm,
keys,
} = self.state;
pub async fn run(self) -> Result<Verifier<state::Committed>, VerifierError> {
let state::CommitAccepted { mpc_tls, vm, keys } = self.state;
info!("starting MPC-TLS");
let (mut ctx, tls_transcript) = mux_fut.poll_with(mpc_tls.run()).await?;
let (mut ctx, tls_transcript) = mpc_tls.run().await?;
info!("finished MPC-TLS");
@@ -203,10 +167,7 @@ where
debug!("finalizing mpc");
mux_fut
.poll_with(vm.finalize(&mut ctx))
.await
.map_err(VerifierError::mpc)?;
vm.finalize(&mut ctx).await.map_err(VerifierError::mpc)?;
debug!("mpc finalized");
}
@@ -228,9 +189,7 @@ where
)
.map_err(VerifierError::zk)?;
mux_fut
.poll_with(vm.execute_all(&mut ctx).map_err(VerifierError::zk))
.await?;
vm.execute_all(&mut ctx).map_err(VerifierError::zk).await?;
// Verify the tags.
// After the verification, the entire TLS trancript becomes
@@ -240,9 +199,8 @@ where
Ok(Verifier {
config: self.config,
span: self.span,
ctx: Some(ctx),
state: state::Committed {
mux_fut,
ctx,
vm,
keys,
tls_transcript,
@@ -251,10 +209,7 @@ where
}
}
impl<Io> Verifier<state::Committed<Io>>
where
Io: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
impl Verifier<state::Committed> {
/// Returns the TLS transcript.
pub fn tls_transcript(&self) -> &TlsTranscript {
&self.state.tls_transcript
@@ -262,10 +217,12 @@ where
/// Begins verification of statements from the prover.
#[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn verify(self) -> Result<Verifier<state::Verify<Io>>, VerifierError> {
pub async fn verify(mut self) -> Result<Verifier<state::Verify>, VerifierError> {
let mut ctx = self
.ctx
.take()
.ok_or_else(|| VerifierError::new(ErrorKind::Io, "context was dropped"))?;
let state::Committed {
mut mux_fut,
mut ctx,
vm,
keys,
tls_transcript,
@@ -275,16 +232,17 @@ where
request,
handshake,
transcript,
} = mux_fut
.poll_with(ctx.io_mut().expect_next().map_err(VerifierError::from))
} = ctx
.io_mut()
.expect_next()
.map_err(VerifierError::from)
.await?;
Ok(Verifier {
config: self.config,
span: self.span,
ctx: Some(ctx),
state: state::Verify {
mux_fut,
ctx,
vm,
keys,
tls_transcript,
@@ -297,19 +255,12 @@ where
/// Closes the connection with the prover.
#[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn close(mut self) -> Result<Io, VerifierError> {
let mux_fut = &mut self.state.mux_fut;
mux_fut.close();
mux_fut.await?;
self.state.mux_fut.into_io().map_err(VerifierError::from)
pub async fn close(self) -> Result<(), VerifierError> {
Ok(())
}
}
impl<Io> Verifier<state::Verify<Io>>
where
Io: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
impl Verifier<state::Verify> {
/// Returns the proving request.
pub fn request(&self) -> &ProveRequest {
&self.state.request
@@ -317,11 +268,13 @@ where
/// Accepts the proving request.
pub async fn accept(
self,
) -> Result<(VerifierOutput, Verifier<state::Committed<Io>>), VerifierError> {
mut self,
) -> Result<(VerifierOutput, Verifier<state::Committed>), VerifierError> {
let mut ctx = self
.ctx
.take()
.ok_or_else(|| VerifierError::new(ErrorKind::Io, "context was dropped"))?;
let state::Verify {
mut mux_fut,
mut ctx,
mut vm,
keys,
tls_transcript,
@@ -330,32 +283,30 @@ where
transcript,
} = self.state;
mux_fut.poll_with(ctx.io_mut().send(Response::ok())).await?;
ctx.io_mut().send(Response::ok()).await?;
let cert_verifier =
ServerCertVerifier::new(self.config.root_store()).map_err(VerifierError::config)?;
let output = mux_fut
.poll_with(verify::verify(
&mut ctx,
&mut vm,
&keys,
&cert_verifier,
&tls_transcript,
request,
handshake,
transcript,
))
.await?;
let output = verify::verify(
&mut ctx,
&mut vm,
&keys,
&cert_verifier,
&tls_transcript,
request,
handshake,
transcript,
)
.await?;
Ok((
output,
Verifier {
config: self.config,
span: self.span,
ctx: Some(ctx),
state: state::Committed {
mux_fut,
ctx,
vm,
keys,
tls_transcript,
@@ -366,28 +317,27 @@ where
/// Rejects the proving request.
pub async fn reject(
self,
mut self,
msg: Option<&str>,
) -> Result<Verifier<state::Committed<Io>>, VerifierError> {
) -> Result<Verifier<state::Committed>, VerifierError> {
let mut ctx = self
.ctx
.take()
.ok_or_else(|| VerifierError::new(ErrorKind::Io, "context was dropped"))?;
let state::Verify {
mut mux_fut,
mut ctx,
vm,
keys,
tls_transcript,
..
} = self.state;
mux_fut
.poll_with(ctx.io_mut().send(Response::err(msg)))
.await?;
ctx.io_mut().send(Response::err(msg)).await?;
Ok(Verifier {
config: self.config,
span: self.span,
ctx: Some(ctx),
state: state::Committed {
mux_fut,
ctx,
vm,
keys,
tls_transcript,

View File

@@ -10,7 +10,7 @@ pub struct VerifierError {
}
impl VerifierError {
fn new<E>(kind: ErrorKind, source: E) -> Self
pub(crate) fn new<E>(kind: ErrorKind, source: E) -> Self
where
E: Into<Box<dyn Error + Send + Sync + 'static>>,
{
@@ -50,7 +50,7 @@ impl VerifierError {
}
#[derive(Debug)]
enum ErrorKind {
pub(crate) enum ErrorKind {
Io,
Config,
Mpc,

View File

@@ -2,9 +2,7 @@
use std::sync::Arc;
use crate::mux::MuxFuture;
use mpc_tls::{MpcTlsFollower, SessionKeys};
use mpz_common::Context;
use tlsn_core::{
config::{prove::ProveRequest, tls_commit::TlsCommitRequest},
connection::{HandshakeData, ServerName},
@@ -24,40 +22,33 @@ pub struct Initialized;
opaque_debug::implement!(Initialized);
/// State after receiving protocol configuration from the prover.
pub struct CommitStart<Io> {
pub(crate) mux_fut: MuxFuture<Io>,
pub(crate) ctx: Context,
pub struct CommitStart {
pub(crate) request: TlsCommitRequest,
}
opaque_debug::implement!(CommitStart<Io>);
opaque_debug::implement!(CommitStart);
/// State after accepting the proposed TLS commitment protocol configuration and
/// performing preprocessing.
pub struct CommitAccepted<Io> {
pub(crate) mux_fut: MuxFuture<Io>,
pub struct CommitAccepted {
pub(crate) mpc_tls: MpcTlsFollower,
pub(crate) keys: SessionKeys,
pub(crate) vm: Arc<Mutex<Deap<VerifierMpc, VerifierZk>>>,
}
opaque_debug::implement!(CommitAccepted<Io>);
opaque_debug::implement!(CommitAccepted);
/// State after the TLS transcript has been committed.
pub struct Committed<Io> {
pub(crate) mux_fut: MuxFuture<Io>,
pub(crate) ctx: Context,
pub struct Committed {
pub(crate) vm: VerifierZk,
pub(crate) keys: SessionKeys,
pub(crate) tls_transcript: TlsTranscript,
}
opaque_debug::implement!(Committed<Io>);
opaque_debug::implement!(Committed);
/// State after receiving a proving request.
pub struct Verify<Io> {
pub(crate) mux_fut: MuxFuture<Io>,
pub(crate) ctx: Context,
pub struct Verify {
pub(crate) vm: VerifierZk,
pub(crate) keys: SessionKeys,
pub(crate) tls_transcript: TlsTranscript,
@@ -66,19 +57,19 @@ pub struct Verify<Io> {
pub(crate) transcript: Option<PartialTranscript>,
}
opaque_debug::implement!(Verify<Io>);
opaque_debug::implement!(Verify);
impl VerifierState for Initialized {}
impl<Io> VerifierState for CommitStart<Io> {}
impl<Io> VerifierState for CommitAccepted<Io> {}
impl<Io> VerifierState for Committed<Io> {}
impl<Io> VerifierState for Verify<Io> {}
impl VerifierState for CommitStart {}
impl VerifierState for CommitAccepted {}
impl VerifierState for Committed {}
impl VerifierState for Verify {}
mod sealed {
pub trait Sealed {}
impl Sealed for super::Initialized {}
impl<Io> Sealed for super::CommitStart<Io> {}
impl<Io> Sealed for super::CommitAccepted<Io> {}
impl<Io> Sealed for super::Committed<Io> {}
impl<Io> Sealed for super::Verify<Io> {}
impl Sealed for super::CommitStart {}
impl Sealed for super::CommitAccepted {}
impl Sealed for super::Committed {}
impl Sealed for super::Verify {}
}

View File

@@ -1,5 +1,6 @@
use futures::{AsyncReadExt, AsyncWriteExt};
use tlsn::{
Session,
config::{
prove::ProveConfig,
prover::ProverConfig,
@@ -18,9 +19,7 @@ use tlsn_core::ProverOutput;
use tlsn_server_fixture::bind;
use tlsn_server_fixture_certs::{CA_CERT_DER, SERVER_DOMAIN};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::compat::TokioAsyncReadCompatExt;
use tracing::instrument;
// Maximum number of bytes that can be sent from prover to server
const MAX_SENT_DATA: usize = 1 << 12;
@@ -37,9 +36,34 @@ async fn test() {
tracing_subscriber::fmt::init();
let (socket_0, socket_1) = tokio::io::duplex(2 << 23);
let mut session_p = Session::new(socket_0.compat());
let mut session_v = Session::new(socket_1.compat());
let prover = session_p
.new_prover(ProverConfig::builder().build().unwrap())
.unwrap();
let verifier = session_v
.new_verifier(
VerifierConfig::builder()
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.build()
.unwrap(),
)
.unwrap();
let (session_p_driver, session_p_handle) = session_p.split();
let (session_v_driver, session_v_handle) = session_v.split();
tokio::spawn(session_p_driver);
tokio::spawn(session_v_driver);
let ((_full_transcript, _prover_output), verifier_output) =
tokio::join!(prover(socket_0), verifier(socket_1));
tokio::join!(run_prover(prover), run_verifier(verifier));
session_p_handle.close();
session_v_handle.close();
let partial_transcript = verifier_output.transcript.unwrap();
let ServerName::Dns(server_name) = verifier_output.server_name.unwrap();
@@ -56,15 +80,12 @@ async fn test() {
);
}
#[instrument(skip(verifier_socket))]
async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
verifier_socket: T,
) -> (Transcript, ProverOutput) {
async fn run_prover(prover: Prover) -> (Transcript, ProverOutput) {
let (client_socket, server_socket) = tokio::io::duplex(2 << 16);
let server_task = tokio::spawn(bind(server_socket.compat()));
let prover = Prover::new(ProverConfig::builder().build().unwrap())
let prover = prover
.commit(
TlsCommitConfig::builder()
.protocol(
@@ -78,7 +99,6 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
)
.build()
.unwrap(),
verifier_socket.compat(),
)
.await
.unwrap();
@@ -150,21 +170,9 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
(transcript, output)
}
#[instrument(skip(socket))]
async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
socket: T,
) -> VerifierOutput {
let verifier = Verifier::new(
VerifierConfig::builder()
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.build()
.unwrap(),
);
async fn run_verifier(verifier: Verifier) -> VerifierOutput {
let verifier = verifier
.commit(socket.compat())
.commit()
.await
.unwrap()
.accept()