Compare commits

...

4 Commits

Author SHA1 Message Date
th4s
4727df6fd4 WIP: removing actor... 2025-12-08 14:26:48 +01:00
th4s
a0c7d469f6 feat(tlsn): implement sans-io api for prover and verifier 2025-12-08 12:23:05 +01:00
th4s
0bf4a857b9 refactor: sans-io TLS IO (#1036)
* refactor: add tls-client trait

* cleanup

* fix clippy

* add start state

* assert that mpc future is pending
2025-12-01 18:55:10 +01:00
th4s
4d449b2a1d refactor(tls-client): make tls_client compatible with a synchronous API (#1027)
* refactor(tls-client): make `write_plaintext` sync and remove async api

* restore `complete_io`

* do not potentially block in `write_all_plaintext`
2025-12-01 18:55:10 +01:00
26 changed files with 950 additions and 3158 deletions

70
Cargo.lock generated
View File

@@ -2930,7 +2930,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
dependencies = [
"libc",
"windows-sys 0.61.2",
"windows-sys 0.52.0",
]
[[package]]
@@ -3173,6 +3173,16 @@ dependencies = [
"syn 2.0.106",
]
[[package]]
name = "futures-plex"
version = "0.1.0"
source = "git+https://github.com/tlsnotary/tlsn-utils?rev=0b46dc0#0b46dc0b11229bd606c6262de0de5ac8f2b76f41"
dependencies = [
"bytes",
"futures-io",
"futures-util",
]
[[package]]
name = "futures-rustls"
version = "0.25.1"
@@ -3916,7 +3926,7 @@ checksum = "e04d7f318608d35d4b61ddd75cbdaee86b023ebe2bd5a66ee0915f0bf93095a9"
dependencies = [
"hermit-abi",
"libc",
"windows-sys 0.59.0",
"windows-sys 0.52.0",
]
[[package]]
@@ -4147,25 +4157,6 @@ version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154"
[[package]]
name = "ludi"
version = "0.1.0"
source = "git+https://github.com/sinui0/ludi?rev=e511c3b#e511c3b330dc298613cc3fd168244619e81ac740"
dependencies = [
"futures-util",
"ludi-core",
]
[[package]]
name = "ludi-core"
version = "0.1.0"
source = "git+https://github.com/sinui0/ludi?rev=e511c3b#e511c3b330dc298613cc3fd168244619e81ac740"
dependencies = [
"futures-channel",
"futures-core",
"futures-util",
]
[[package]]
name = "macro-string"
version = "0.1.4"
@@ -5041,7 +5032,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7d8fae84b431384b68627d0f9b3b1245fcf9f46f6c0e3dc902e9dce64edd1967"
dependencies = [
"libc",
"windows-sys 0.61.2",
"windows-sys 0.48.0",
]
[[package]]
@@ -5615,7 +5606,7 @@ dependencies = [
"once_cell",
"socket2",
"tracing",
"windows-sys 0.60.2",
"windows-sys 0.52.0",
]
[[package]]
@@ -6107,7 +6098,7 @@ dependencies = [
"errno",
"libc",
"linux-raw-sys 0.4.15",
"windows-sys 0.59.0",
"windows-sys 0.52.0",
]
[[package]]
@@ -6120,7 +6111,7 @@ dependencies = [
"errno",
"libc",
"linux-raw-sys 0.11.0",
"windows-sys 0.61.2",
"windows-sys 0.52.0",
]
[[package]]
@@ -7010,7 +7001,7 @@ dependencies = [
"getrandom 0.3.3",
"once_cell",
"rustix 1.1.2",
"windows-sys 0.61.2",
"windows-sys 0.52.0",
]
[[package]]
@@ -7178,6 +7169,7 @@ dependencies = [
"aes 0.8.4",
"ctr 0.9.2",
"futures",
"futures-plex",
"ghash 0.5.1",
"http-body-util",
"hyper",
@@ -7212,7 +7204,6 @@ dependencies = [
"tlsn-server-fixture",
"tlsn-server-fixture-certs",
"tlsn-tls-client",
"tlsn-tls-client-async",
"tlsn-tls-core",
"tokio",
"tokio-util",
@@ -7500,7 +7491,6 @@ dependencies = [
"futures",
"generic-array",
"ghash 0.5.1",
"ludi",
"mpz-common",
"mpz-core",
"mpz-fields",
@@ -7528,7 +7518,6 @@ dependencies = [
"tlsn-key-exchange",
"tlsn-tls-backend",
"tlsn-tls-client",
"tlsn-tls-client-async",
"tlsn-tls-core",
"tokio",
"tokio-util",
@@ -7603,26 +7592,6 @@ dependencies = [
"webpki-roots 1.0.3",
]
[[package]]
name = "tlsn-tls-client-async"
version = "0.1.0-alpha.14-pre"
dependencies = [
"bytes",
"futures",
"http-body-util",
"hyper",
"hyper-util",
"rstest",
"rustls-pki-types",
"rustls-webpki 0.103.7",
"thiserror 1.0.69",
"tls-server-fixture",
"tlsn-tls-client",
"tokio",
"tokio-util",
"tracing",
]
[[package]]
name = "tlsn-tls-core"
version = "0.1.0-alpha.14-pre"
@@ -7669,7 +7638,6 @@ dependencies = [
"tlsn",
"tlsn-core",
"tlsn-server-fixture-certs",
"tlsn-tls-client-async",
"tlsn-tls-core",
"tracing",
"tracing-subscriber",
@@ -8543,7 +8511,7 @@ version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22"
dependencies = [
"windows-sys 0.61.2",
"windows-sys 0.48.0",
]
[[package]]

View File

@@ -13,7 +13,6 @@ members = [
"crates/server-fixture/server",
"crates/tls/backend",
"crates/tls/client",
"crates/tls/client-async",
"crates/tls/core",
"crates/mpc-tls",
"crates/tls/server-fixture",
@@ -57,7 +56,6 @@ tlsn-server-fixture = { path = "crates/server-fixture/server" }
tlsn-server-fixture-certs = { path = "crates/server-fixture/certs" }
tlsn-tls-backend = { path = "crates/tls/backend" }
tlsn-tls-client = { path = "crates/tls/client" }
tlsn-tls-client-async = { path = "crates/tls/client-async" }
tlsn-tls-core = { path = "crates/tls/core" }
tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6168663" }
tlsn-harness-core = { path = "crates/harness/core" }
@@ -82,6 +80,7 @@ mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-ideal-vm = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
futures-plex = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "0b46dc0" }
rangeset = { version = "0.2" }
serio = { version = "0.2" }
spansy = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6168663" }

View File

@@ -34,7 +34,6 @@ mpz-share-conversion = { workspace = true }
mpz-vm-core = { workspace = true }
mpz-memory-core = { workspace = true }
ludi = { git = "https://github.com/sinui0/ludi", rev = "e511c3b", default-features = false }
serio = { workspace = true }
async-trait = { workspace = true }
@@ -66,7 +65,6 @@ rand_chacha = { workspace = true }
rstest = { workspace = true }
tls-server-fixture = { workspace = true }
tlsn-tls-client = { workspace = true }
tlsn-tls-client-async = { workspace = true }
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }
tokio-util = { workspace = true, features = ["compat"] }
tracing-subscriber = { workspace = true }

View File

@@ -15,13 +15,6 @@ impl MpcTlsError {
Self(ErrorRepr::Peer(err.into()))
}
pub(crate) fn actor<E>(err: E) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
Self(ErrorRepr::Actor(err.into()))
}
pub(crate) fn state<E>(err: E) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
@@ -72,8 +65,6 @@ enum ErrorRepr {
Peer(Box<dyn std::error::Error + Send + Sync>),
#[error("I/O error: {0}")]
Io(std::io::Error),
#[error("actor error: {0}")]
Actor(Box<dyn std::error::Error + Send + Sync>),
#[error("state error: {0}")]
State(Box<dyn std::error::Error + Send + Sync>),
#[error("allocation error: {0}")]

View File

@@ -1,5 +1,3 @@
mod actor;
use crate::{
error::MpcTlsError,
msg::{
@@ -14,7 +12,6 @@ use async_trait::async_trait;
use hmac_sha256::{MpcPrf, PrfOutput};
use ke::KeyExchange;
use key_exchange::{self as ke, MpcKeyExchange};
use ludi::Context as LudiContext;
use mpz_common::{Context, Flush};
use mpz_core::{bitvec::BitVec, Block};
use mpz_memory_core::DecodeFutureTyped;
@@ -50,13 +47,9 @@ use tlsn_core::{
};
use tracing::{debug, instrument, trace, warn};
/// Controller for MPC-TLS leader.
pub type LeaderCtrl = actor::MpcTlsLeaderCtrl;
/// MPC-TLS leader.
#[derive(Debug)]
pub struct MpcTlsLeader {
self_handle: Option<LeaderCtrl>,
config: Config,
state: State,
@@ -114,7 +107,6 @@ impl MpcTlsLeader {
let is_decrypting = !config.defer_decryption;
Self {
self_handle: None,
config,
state: State::Init {
ctx,
@@ -378,18 +370,42 @@ impl MpcTlsLeader {
Ok(())
}
/// Defers decryption of any incoming messages.
/// Enables or disables the decryption of any incoming messages.
///
/// # Arguments
///
/// * `enable` - Whether to enable or disable decryption.
#[instrument(level = "debug", skip_all, err)]
pub async fn defer_decryption(&mut self) -> Result<(), MpcTlsError> {
self.is_decrypting = false;
self.notifier.clear();
pub fn enable_decryption(&mut self, enable: bool) -> Result<(), MpcTlsError> {
self.is_decrypting = enable;
if enable {
self.notifier.set();
} else {
self.notifier.clear();
}
Ok(())
}
/// Stops the actor.
pub fn stop(&mut self, ctx: &mut LudiContext<Self>) {
ctx.stop();
/// Returns if incoming messages are decrypted.
pub fn is_decrypting(&self) -> bool {
self.is_decrypting
}
/// Returns the context and transcript.
///
/// Should be called after a successful call to [`Backend::server_closed`].
pub fn finish(&mut self) -> Option<(Context, TlsTranscript)> {
match self.state.take() {
State::Closed {
ctx, transcript, ..
} => Some((ctx, transcript)),
state => {
self.state = state;
None
}
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -16,7 +16,7 @@ pub(crate) mod utils;
pub use config::{Config, ConfigBuilder, ConfigBuilderError};
pub use error::MpcTlsError;
pub use follower::MpcTlsFollower;
pub use leader::{LeaderCtrl, MpcTlsLeader};
pub use leader::MpcTlsLeader;
use std::{future::Future, pin::Pin, sync::Arc};

View File

@@ -1,160 +0,0 @@
use std::sync::Arc;
use futures::{AsyncReadExt, AsyncWriteExt};
use mpc_tls::{Config, MpcTlsFollower, MpcTlsLeader};
use mpz_common::context::test_mt_context;
use mpz_core::Block;
use mpz_ideal_vm::IdealVm;
use mpz_memory_core::correlated::Delta;
use mpz_ot::{
ideal::rcot::ideal_rcot,
rcot::shared::{SharedRCOTReceiver, SharedRCOTSender},
};
use rand::{rngs::StdRng, SeedableRng};
use rustls_pki_types::CertificateDer;
use tls_client::RootCertStore;
use tls_client_async::bind_client;
use tls_server_fixture::{bind_test_server_hyper, CA_CERT_DER, SERVER_DOMAIN};
use tokio::sync::Mutex;
use tokio_util::compat::TokioAsyncReadCompatExt;
use webpki::anchor_from_trusted_cert;
const CA_CERT: CertificateDer = CertificateDer::from_slice(CA_CERT_DER);
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn mpc_tls_test() {
tracing_subscriber::fmt::init();
let config = Config::builder()
.defer_decryption(false)
.max_sent(1 << 13)
.max_recv_online(1 << 13)
.max_recv(1 << 13)
.build()
.unwrap();
let (leader, follower) = build_pair(config);
tokio::try_join!(
tokio::spawn(leader_task(leader)),
tokio::spawn(follower_task(follower))
)
.unwrap();
}
async fn leader_task(mut leader: MpcTlsLeader) {
leader.alloc().unwrap();
leader.preprocess().await.unwrap();
let (leader_ctrl, leader_fut) = leader.run();
tokio::spawn(async { leader_fut.await.unwrap() });
let config = tls_client::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(RootCertStore {
roots: vec![anchor_from_trusted_cert(&CA_CERT).unwrap().to_owned()],
})
.with_no_client_auth();
let server_name = SERVER_DOMAIN.try_into().unwrap();
let client = tls_client::ClientConnection::new(
Arc::new(config),
Box::new(leader_ctrl.clone()),
server_name,
)
.unwrap();
let (client_socket, server_socket) = tokio::io::duplex(1 << 16);
tokio::spawn(bind_test_server_hyper(server_socket.compat()));
let (mut conn, conn_fut) = bind_client(client_socket.compat(), client);
let handle = tokio::spawn(async { conn_fut.await.unwrap() });
let msg = concat!(
"POST /echo HTTP/1.1\r\n",
"Host: test-server.io\r\n",
"Connection: keep-alive\r\n",
"Accept-Encoding: identity\r\n",
"Content-Length: 5\r\n",
"\r\n",
"hello",
"\r\n"
);
conn.write_all(msg.as_bytes()).await.unwrap();
let mut buf = vec![0u8; 48];
conn.read_exact(&mut buf).await.unwrap();
leader_ctrl.defer_decryption().await.unwrap();
let msg = concat!(
"POST /echo HTTP/1.1\r\n",
"Host: test-server.io\r\n",
"Connection: close\r\n",
"Accept-Encoding: identity\r\n",
"Content-Length: 5\r\n",
"\r\n",
"hello",
"\r\n"
);
conn.write_all(msg.as_bytes()).await.unwrap();
conn.close().await.unwrap();
let mut buf = vec![0u8; 1024];
conn.read_to_end(&mut buf).await.unwrap();
leader_ctrl.stop().await.unwrap();
handle.await.unwrap();
}
async fn follower_task(mut follower: MpcTlsFollower) {
follower.alloc().unwrap();
follower.preprocess().await.unwrap();
follower.run().await.unwrap();
}
fn build_pair(config: Config) -> (MpcTlsLeader, MpcTlsFollower) {
let mut rng = StdRng::seed_from_u64(0);
let (mut mt_a, mut mt_b) = test_mt_context(8);
let ctx_a = futures::executor::block_on(mt_a.new_context()).unwrap();
let ctx_b = futures::executor::block_on(mt_b.new_context()).unwrap();
let delta_a = Delta::new(Block::random(&mut rng));
let delta_b = Delta::new(Block::random(&mut rng));
let (rcot_send_a, rcot_recv_b) = ideal_rcot(Block::random(&mut rng), delta_a.into_inner());
let (rcot_send_b, rcot_recv_a) = ideal_rcot(Block::random(&mut rng), delta_b.into_inner());
let rcot_send_a = SharedRCOTSender::new(rcot_send_a);
let rcot_send_b = SharedRCOTSender::new(rcot_send_b);
let rcot_recv_a = SharedRCOTReceiver::new(rcot_recv_a);
let rcot_recv_b = SharedRCOTReceiver::new(rcot_recv_b);
let mpc_a = Arc::new(Mutex::new(IdealVm::new()));
let mpc_b = Arc::new(Mutex::new(IdealVm::new()));
let leader = MpcTlsLeader::new(
config.clone(),
ctx_a,
mpc_a,
(rcot_send_a.clone(), rcot_send_a.clone(), rcot_send_a),
rcot_recv_a,
);
let follower = MpcTlsFollower::new(
config,
ctx_b,
mpc_b,
rcot_send_b,
(rcot_recv_b.clone(), rcot_recv_b.clone(), rcot_recv_b),
);
(leader, follower)
}

View File

@@ -1,39 +0,0 @@
[package]
name = "tlsn-tls-client-async"
authors = ["TLSNotary Team"]
description = "An async TLS client for TLSNotary"
keywords = ["tls", "mpc", "2pc", "client", "async"]
categories = ["cryptography"]
license = "MIT OR Apache-2.0"
version = "0.1.0-alpha.14-pre"
edition = "2021"
[lints]
workspace = true
[lib]
name = "tls_client_async"
[features]
default = ["tracing"]
tracing = ["dep:tracing"]
[dependencies]
tlsn-tls-client = { workspace = true }
bytes = { workspace = true }
futures = { workspace = true }
thiserror = { workspace = true }
tokio-util = { workspace = true, features = ["io", "compat"] }
tracing = { workspace = true, optional = true }
[dev-dependencies]
tls-server-fixture = { workspace = true }
http-body-util = { workspace = true }
hyper = { workspace = true, features = ["client", "http1"] }
hyper-util = { workspace = true, features = ["full"] }
rstest = { workspace = true }
tokio = { workspace = true, features = ["rt", "rt-multi-thread", "macros"] }
rustls-webpki = { workspace = true }
rustls-pki-types = { workspace = true }

View File

@@ -1,89 +0,0 @@
use bytes::Bytes;
use futures::{
channel::mpsc::{Receiver, SendError, Sender},
sink::SinkMapErr,
AsyncRead, AsyncWrite, SinkExt,
};
use std::{
io::{Error as IoError, ErrorKind as IoErrorKind},
pin::Pin,
task::{Context, Poll},
};
use tokio_util::{
compat::{Compat, TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt},
io::{CopyToBytes, SinkWriter, StreamReader},
};
type CompatSinkWriter =
Compat<SinkWriter<CopyToBytes<SinkMapErr<Sender<Bytes>, fn(SendError) -> IoError>>>>;
/// A TLS connection to a server.
///
/// This type implements `AsyncRead` and `AsyncWrite` and can be used to
/// communicate with a server using TLS.
///
/// # Note
///
/// This connection is closed on a best-effort basis if this is dropped. To
/// ensure a clean close, you should call
/// [`AsyncWriteExt::close`](futures::io::AsyncWriteExt::close) to close the
/// connection.
#[derive(Debug)]
pub struct TlsConnection {
/// The data to be transmitted to the server is sent to this sink.
tx_sender: CompatSinkWriter,
/// The data to be received from the server is received from this stream.
rx_receiver: Compat<StreamReader<Receiver<Result<Bytes, IoError>>, Bytes>>,
}
impl TlsConnection {
/// Creates a new TLS connection.
pub(crate) fn new(
tx_sender: Sender<Bytes>,
rx_receiver: Receiver<Result<Bytes, IoError>>,
) -> Self {
fn convert_error(err: SendError) -> IoError {
if err.is_disconnected() {
IoErrorKind::BrokenPipe.into()
} else {
IoErrorKind::WouldBlock.into()
}
}
Self {
tx_sender: SinkWriter::new(CopyToBytes::new(
tx_sender.sink_map_err(convert_error as fn(SendError) -> IoError),
))
.compat_write(),
rx_receiver: StreamReader::new(rx_receiver).compat(),
}
}
}
impl AsyncRead for TlsConnection {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<Result<usize, IoError>> {
Pin::new(&mut self.rx_receiver).poll_read(cx, buf)
}
}
impl AsyncWrite for TlsConnection {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, IoError>> {
Pin::new(&mut self.tx_sender).poll_write(cx, buf)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
Pin::new(&mut self.tx_sender).poll_flush(cx)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
Pin::new(&mut self.tx_sender).poll_close(cx)
}
}

View File

@@ -1,269 +0,0 @@
//! Provides a TLS client which exposes an async socket.
//!
//! This library provides the [bind_client] function which attaches a TLS client
//! to a socket connection and then exposes a [TlsConnection] object, which
//! provides an async socket API for reading and writing cleartext. The TLS
//! client will then automatically encrypt and decrypt traffic and forward that
//! to the provided socket.
#![deny(missing_docs, unreachable_pub, unused_must_use)]
#![deny(clippy::all)]
#![forbid(unsafe_code)]
mod conn;
use bytes::{Buf, Bytes};
use futures::{
channel::mpsc, future::Fuse, select_biased, stream::Next, AsyncRead, AsyncReadExt, AsyncWrite,
AsyncWriteExt, Future, FutureExt, SinkExt, StreamExt,
};
use std::{
pin::Pin,
task::{Context, Poll},
};
#[cfg(feature = "tracing")]
use tracing::{debug, debug_span, error, trace, warn, Instrument};
use tls_client::ClientConnection;
pub use conn::TlsConnection;
const RX_TLS_BUF_SIZE: usize = 1 << 13; // 8 KiB
const RX_BUF_SIZE: usize = 1 << 13; // 8 KiB
/// An error that can occur during a TLS connection.
#[allow(missing_docs)]
#[derive(Debug, thiserror::Error)]
pub enum ConnectionError {
#[error(transparent)]
TlsError(#[from] tls_client::Error),
#[error(transparent)]
IOError(#[from] std::io::Error),
}
/// Closed connection data.
#[derive(Debug)]
pub struct ClosedConnection {
/// The connection for the client
pub client: ClientConnection,
/// Sent plaintext bytes
pub sent: Vec<u8>,
/// Received plaintext bytes
pub recv: Vec<u8>,
}
/// A future which runs the TLS connection to completion.
///
/// This future must be polled in order for the connection to make progress.
#[must_use = "futures do nothing unless polled"]
pub struct ConnectionFuture {
fut: Pin<Box<dyn Future<Output = Result<ClosedConnection, ConnectionError>> + Send>>,
}
impl Future for ConnectionFuture {
type Output = Result<ClosedConnection, ConnectionError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.fut.poll_unpin(cx)
}
}
/// Binds a client connection to the provided socket.
///
/// Returns a connection handle and a future which runs the connection to
/// completion.
///
/// # Errors
///
/// Any connection errors that occur will be returned from the future, not
/// [`TlsConnection`].
pub fn bind_client<T: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
socket: T,
mut client: ClientConnection,
) -> (TlsConnection, ConnectionFuture) {
let (tx_sender, mut tx_receiver) = mpsc::channel(1 << 14);
let (mut rx_sender, rx_receiver) = mpsc::channel(1 << 14);
let conn = TlsConnection::new(tx_sender, rx_receiver);
let fut = async move {
client.start().await?;
let mut notify = client.get_notify().await?;
let (mut server_rx, mut server_tx) = socket.split();
let mut rx_tls_buf = [0u8; RX_TLS_BUF_SIZE];
let mut rx_buf = [0u8; RX_BUF_SIZE];
let mut handshake_done = false;
let mut client_closed = false;
let mut server_closed = false;
let mut sent = Vec::with_capacity(1024);
let mut recv = Vec::with_capacity(1024);
let mut rx_tls_fut = server_rx.read(&mut rx_tls_buf).fuse();
// We don't start writing application data until the handshake is complete.
let mut tx_recv_fut: Fuse<Next<'_, mpsc::Receiver<Bytes>>> = Fuse::terminated();
// Runs both the tx and rx halves of the connection to completion.
// This loop does not terminate until the *SERVER* closes the connection and
// we've processed all received data. If an error occurs, the `TlsConnection`
// channels will be closed and the error will be returned from this future.
'conn: loop {
// Write all pending TLS data to the server.
if client.wants_write() && !client_closed {
#[cfg(feature = "tracing")]
trace!("client wants to write");
while client.wants_write() {
let _sent = client.write_tls_async(&mut server_tx).await?;
#[cfg(feature = "tracing")]
trace!("sent {} tls bytes to server", _sent);
}
server_tx.flush().await?;
}
// Forward received plaintext to `TlsConnection`.
while !client.plaintext_is_empty() {
let read = client.read_plaintext(&mut rx_buf)?;
recv.extend(&rx_buf[..read]);
// Ignore if the receiver has hung up.
_ = rx_sender
.send(Ok(Bytes::copy_from_slice(&rx_buf[..read])))
.await;
#[cfg(feature = "tracing")]
trace!("forwarded {} plaintext bytes to conn", read);
}
if !client.is_handshaking() && !handshake_done {
#[cfg(feature = "tracing")]
debug!("handshake complete");
handshake_done = true;
// Start reading application data that needs to be transmitted from the
// `TlsConnection`.
tx_recv_fut = tx_receiver.next().fuse();
}
if server_closed && client.plaintext_is_empty() && client.is_empty().await? {
break 'conn;
}
select_biased! {
// Reads TLS data from the server and writes it into the client.
received = &mut rx_tls_fut => {
let received = received?;
#[cfg(feature = "tracing")]
trace!("received {} tls bytes from server", received);
// Loop until we've processed all the data we received in this read.
// Note that we must make one iteration even if `received == 0`.
let mut processed = 0;
let mut reader = rx_tls_buf[..received].reader();
loop {
processed += client.read_tls(&mut reader)?;
client.process_new_packets().await?;
debug_assert!(processed <= received);
if processed >= received {
break;
}
}
#[cfg(feature = "tracing")]
trace!("processed {} tls bytes from server", processed);
// By convention if `AsyncRead::read` returns 0, it means EOF, i.e. the peer
// has closed the socket.
if received == 0 {
#[cfg(feature = "tracing")]
debug!("server closed connection");
server_closed = true;
client.server_closed().await?;
// Do not read from the socket again.
rx_tls_fut = Fuse::terminated();
} else {
// Reset the read future so next iteration we can read again.
rx_tls_fut = server_rx.read(&mut rx_tls_buf).fuse();
}
}
// If we receive None from `TlsConnection`, it has closed, so we
// send a close_notify to the server.
data = &mut tx_recv_fut => {
if let Some(data) = data {
#[cfg(feature = "tracing")]
trace!("writing {} plaintext bytes to client", data.len());
sent.extend(&data);
client
.write_all_plaintext(&data)
.await?;
tx_recv_fut = tx_receiver.next().fuse();
} else {
if !server_closed {
if let Err(e) = send_close_notify(&mut client, &mut server_tx).await {
#[cfg(feature = "tracing")]
warn!("failed to send close_notify to server: {}", e);
}
}
client_closed = true;
tx_recv_fut = Fuse::terminated();
}
}
// Waits for a notification from the backend that it is ready to decrypt data.
_ = &mut notify => {
#[cfg(feature = "tracing")]
trace!("backend is ready to decrypt");
client.process_new_packets().await?;
}
}
}
#[cfg(feature = "tracing")]
debug!("client shutdown");
_ = server_tx.close().await;
tx_receiver.close();
rx_sender.close_channel();
#[cfg(feature = "tracing")]
trace!(
"server close notify: {}, sent: {}, recv: {}",
client.received_close_notify(),
sent.len(),
recv.len()
);
Ok(ClosedConnection { client, sent, recv })
};
#[cfg(feature = "tracing")]
let fut = fut.instrument(debug_span!("tls_connection"));
let fut = ConnectionFuture { fut: Box::pin(fut) };
(conn, fut)
}
async fn send_close_notify(
client: &mut ClientConnection,
server_tx: &mut (impl AsyncWrite + Unpin),
) -> Result<(), ConnectionError> {
#[cfg(feature = "tracing")]
trace!("sending close_notify to server");
client.send_close_notify().await?;
client.process_new_packets().await?;
// Flush all remaining plaintext
while client.wants_write() {
client.write_tls_async(server_tx).await?;
}
server_tx.flush().await?;
Ok(())
}

View File

@@ -1,438 +0,0 @@
use std::{str, sync::Arc};
use core::future::Future;
use futures::{AsyncReadExt, AsyncWriteExt};
use http_body_util::{BodyExt as _, Full};
use hyper::{body::Bytes, Request, StatusCode};
use hyper_util::rt::TokioIo;
use rstest::{fixture, rstest};
use rustls_pki_types::CertificateDer;
use tls_client::{ClientConfig, ClientConnection, RustCryptoBackend, ServerName};
use tls_client_async::{bind_client, ClosedConnection, ConnectionError, TlsConnection};
use tls_server_fixture::{
bind_test_server, bind_test_server_hyper, APP_RECORD_LENGTH, CA_CERT_DER, CLOSE_DELAY,
SERVER_DOMAIN,
};
use tokio::task::JoinHandle;
use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
use webpki::anchor_from_trusted_cert;
const CA_CERT: CertificateDer = CertificateDer::from_slice(CA_CERT_DER);
// An established client TLS connection
struct TlsFixture {
client_tls_conn: TlsConnection,
// a handle that must be `.await`ed to get the result of a TLS connection
closed_tls_task: JoinHandle<Result<ClosedConnection, ConnectionError>>,
}
// Sets up a TLS connection between client and server and sends a hello message
#[fixture]
async fn set_up_tls() -> TlsFixture {
let (client_socket, server_socket) = tokio::io::duplex(1 << 16);
let _server_task = tokio::spawn(bind_test_server(server_socket.compat()));
let mut root_store = tls_client::RootCertStore::empty();
root_store
.roots
.push(anchor_from_trusted_cert(&CA_CERT).unwrap().to_owned());
let config = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_no_client_auth();
let client = ClientConnection::new(
Arc::new(config),
Box::new(RustCryptoBackend::new()),
ServerName::try_from(SERVER_DOMAIN).unwrap(),
)
.unwrap();
let (mut client_tls_conn, tls_fut) = bind_client(client_socket.compat(), client);
let closed_tls_task = tokio::spawn(tls_fut);
client_tls_conn
.write_all(&pad("expecting you to send back hello".to_string()))
.await
.unwrap();
// give the server some time to respond
std::thread::sleep(std::time::Duration::from_millis(10));
let mut plaintext = vec![0u8; 320];
let n = client_tls_conn.read(&mut plaintext).await.unwrap();
let s = str::from_utf8(&plaintext[0..n]).unwrap();
assert_eq!(s, "hello");
TlsFixture {
client_tls_conn,
closed_tls_task,
}
}
// Expect the async tls client wrapped in `hyper::client` to make a successful
// request and receive the expected response
#[tokio::test]
async fn test_hyper_ok() {
let (client_socket, server_socket) = tokio::io::duplex(1 << 16);
let server_task = tokio::spawn(bind_test_server_hyper(server_socket.compat()));
let mut root_store = tls_client::RootCertStore::empty();
root_store
.roots
.push(anchor_from_trusted_cert(&CA_CERT).unwrap().to_owned());
let config = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_no_client_auth();
let client = ClientConnection::new(
Arc::new(config),
Box::new(RustCryptoBackend::new()),
ServerName::try_from(SERVER_DOMAIN).unwrap(),
)
.unwrap();
let (conn, tls_fut) = bind_client(client_socket.compat(), client);
let closed_tls_task = tokio::spawn(tls_fut);
let (mut request_sender, connection) =
hyper::client::conn::http1::handshake(TokioIo::new(conn.compat()))
.await
.unwrap();
tokio::spawn(connection);
let request = Request::builder()
.uri(format!("https://{SERVER_DOMAIN}/echo"))
.header("Host", SERVER_DOMAIN)
.header("Connection", "close")
.method("POST")
.body(Full::<Bytes>::new("hello".into()))
.unwrap();
let response = request_sender.send_request(request).await.unwrap();
assert!(response.status() == StatusCode::OK);
// Process the response body
response.into_body().collect().await.unwrap().to_bytes();
let _ = server_task.await.unwrap();
let closed_conn = closed_tls_task.await.unwrap().unwrap();
assert!(closed_conn.client.received_close_notify());
}
// Expect a clean TLS connection closure when server responds to the client's
// close_notify but doesn't close the socket
#[rstest]
#[tokio::test]
async fn test_ok_server_no_socket_close(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
// instruct the server to send close_notify back to us after 10 ms
client_tls_conn
.write_all(&pad("send_close_notify".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// closing `client_tls_conn` will cause close_notify to be sent by the client;
client_tls_conn.close().await.unwrap();
let closed_conn = closed_tls_task.await.unwrap().unwrap();
assert!(closed_conn.client.received_close_notify());
}
// Expect a clean TLS connection closure when server responds to the client's
// close_notify AND also closes the socket
#[rstest]
#[tokio::test]
async fn test_ok_server_socket_close(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
// instruct the server to send close_notify back to us AND close the socket
// after 10 ms
client_tls_conn
.write_all(&pad("send_close_notify_and_close_socket".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// closing `client_tls_conn` will cause close_notify to be sent by the client;
client_tls_conn.close().await.unwrap();
let closed_conn = closed_tls_task.await.unwrap().unwrap();
assert!(closed_conn.client.received_close_notify());
}
// Expect a clean TLS connection closure when server sends close_notify first
// but doesn't close the socket
#[rstest]
#[tokio::test]
async fn test_ok_server_close_notify(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
// instruct the server to send close_notify back to us after 10 ms
client_tls_conn
.write_all(&pad("send_close_notify".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// give enough time for server's close_notify to arrive
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
client_tls_conn.close().await.unwrap();
let closed_conn = closed_tls_task.await.unwrap().unwrap();
assert!(closed_conn.client.received_close_notify());
}
// Expect a clean TLS connection closure when server sends close_notify first
// AND also closes the socket
#[rstest]
#[tokio::test]
async fn test_ok_server_close_notify_and_socket_close(
set_up_tls: impl Future<Output = TlsFixture>,
) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
// instruct the server to send close_notify back to us after 10 ms
client_tls_conn
.write_all(&pad("send_close_notify_and_close_socket".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// give enough time for server's close_notify to arrive
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
client_tls_conn.close().await.unwrap();
let closed_conn = closed_tls_task.await.unwrap().unwrap();
assert!(closed_conn.client.received_close_notify());
}
// Expect to be able to read the data after server closes the socket abruptly
#[rstest]
#[tokio::test]
async fn test_ok_read_after_close(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
..
} = set_up_tls.await;
// instruct the server to send us a hello message
client_tls_conn
.write_all(&pad("send a hello message".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// instruct the server to close the socket
client_tls_conn
.write_all(&pad("close_socket".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// give enough time to close the socket
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
// try to read some more data
let mut buf = vec![0u8; 10];
let n = client_tls_conn.read(&mut buf).await.unwrap();
assert_eq!(std::str::from_utf8(&buf[0..n]).unwrap(), "hello");
}
// Expect there to be no error when server DOES NOT send close_notify but just
// closes the socket
#[rstest]
#[tokio::test]
async fn test_ok_server_no_close_notify(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
// instruct the server to close the socket
client_tls_conn
.write_all(&pad("close_socket".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// give enough time to close the socket
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
client_tls_conn.close().await.unwrap();
let closed_conn = closed_tls_task.await.unwrap().unwrap();
assert!(!closed_conn.client.received_close_notify());
}
// Expect to register a delay when the server delays closing the socket
#[rstest]
#[tokio::test]
async fn test_ok_delay_close(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
client_tls_conn
.write_all(&pad("must_delay_when_closing".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// closing `client_tls_conn` will cause close_notify to be sent by the client
client_tls_conn.close().await.unwrap();
use std::time::Instant;
let now = Instant::now();
// this will resolve when the server stops delaying closing the socket
let closed_conn = closed_tls_task.await.unwrap().unwrap();
let elapsed = now.elapsed();
// the elapsed time must be roughly equal to the server's delay
// (give or take timing variations)
assert!(elapsed.as_millis() as u64 > CLOSE_DELAY - 50);
assert!(!closed_conn.client.received_close_notify());
}
// Expect client to error when server sends a corrupted message
#[rstest]
#[tokio::test]
async fn test_err_corrupted(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
// instruct the server to send a corrupted message
client_tls_conn
.write_all(&pad("send_corrupted_message".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
client_tls_conn.close().await.unwrap();
assert_eq!(
closed_tls_task.await.unwrap().err().unwrap().to_string(),
"received corrupt message"
);
}
// Expect client to error when server sends a TLS record with a bad MAC
#[rstest]
#[tokio::test]
async fn test_err_bad_mac(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
// instruct the server to send us a TLS record with a bad MAC
client_tls_conn
.write_all(&pad("send_record_with_bad_mac".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
client_tls_conn.close().await.unwrap();
assert_eq!(
closed_tls_task.await.unwrap().err().unwrap().to_string(),
"backend error: Decryption error: \"aead::Error\""
);
}
// Expect client to error when server sends a fatal alert
#[rstest]
#[tokio::test]
async fn test_err_alert(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
// instruct the server to send us a TLS record with a bad MAC
client_tls_conn
.write_all(&pad("send_alert".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
client_tls_conn.close().await.unwrap();
assert_eq!(
closed_tls_task.await.unwrap().err().unwrap().to_string(),
"received fatal alert: BadRecordMac"
);
}
// Expect an error when trying to write data to a connection which server closed
// abruptly
#[rstest]
#[tokio::test]
async fn test_err_write_after_close(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
..
} = set_up_tls.await;
// instruct the server to close the socket
client_tls_conn
.write_all(&pad("close_socket".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// give enough time to close the socket
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
// try to send some more data
let res = client_tls_conn
.write_all(&pad("more data".to_string()))
.await;
assert_eq!(res.err().unwrap().kind(), std::io::ErrorKind::BrokenPipe);
}
// Converts a string into a slice zero-padded to APP_RECORD_LENGTH
fn pad(s: String) -> Vec<u8> {
assert!(s.len() <= APP_RECORD_LENGTH);
let mut buf = vec![0u8; APP_RECORD_LENGTH];
buf[..s.len()].copy_from_slice(s.as_bytes());
buf
}

View File

@@ -227,6 +227,7 @@ impl ConnectionCommon {
/// Signals that the server has closed the connection.
pub async fn server_closed(&mut self) -> Result<(), Error> {
self.common_state.has_seen_eof = true;
self.common_state.backend.server_closed().await?;
Ok(())
}
@@ -457,6 +458,9 @@ impl ConnectionCommon {
return Err(Error::CorruptMessage);
}
// Process outgoing plaintext buffer and encrypt messages.
self.flush_plaintext().await?;
// Process new messages.
while let Some(msg) = self.message_deframer.frames.pop_front() {
// If we're not decrypting yet, we process it immediately. Otherwise it will be
@@ -508,25 +512,22 @@ impl ConnectionCommon {
Ok(state)
}
/// Write buffer into connection.
pub async fn write_plaintext(&mut self, buf: &[u8]) -> Result<usize, Error> {
if let Ok(st) = &mut self.state {
st.perhaps_write_key_update(&mut self.common_state).await;
/// Writes plaintext `buf` into an internal buffer. May not fully process the
/// whole buffer and returns the processed length.
pub fn write_plaintext(&mut self, buf: &[u8]) -> Result<usize, Error> {
if buf.is_empty() {
// Don't send empty fragments.
return Ok(0);
}
self.common_state.send_some_plaintext(buf).await
let len = self.sendable_plaintext.append_limited_copy(buf);
Ok(len)
}
/// Write entire buffer into connection.
pub async fn write_all_plaintext(&mut self, buf: &[u8]) -> Result<usize, Error> {
let mut pos = 0;
while pos < buf.len() {
pos += self.write_plaintext(&buf[pos..]).await?;
}
self.backend.flush().await?;
while let Some(msg) = self.backend.next_outgoing().await? {
self.queue_tls_message(msg);
}
Ok(pos)
/// Writes the entire plaintext `buf` into an internal buffer.
pub fn write_all_plaintext(&mut self, buf: &[u8]) -> Result<(), Error> {
self.sendable_plaintext.append(buf.to_vec());
Ok(())
}
/// Read TLS content from `rd`. This method does internal
@@ -690,6 +691,11 @@ impl CommonState {
self.received_plaintext.is_empty()
}
/// Returns true if the buffer for sendable plaintext is full.
pub fn sendable_plaintext_is_full(&self) -> bool {
self.sendable_plaintext.is_full()
}
/// Returns true if the connection is currently performing the TLS
/// handshake.
///
@@ -782,15 +788,6 @@ impl CommonState {
}
}
/// Send plaintext application data, fragmenting and
/// encrypting it as it goes out.
///
/// If internal buffers are too small, this function will not accept
/// all the data.
pub(crate) async fn send_some_plaintext(&mut self, data: &[u8]) -> Result<usize, Error> {
self.send_plain(data, Limit::Yes).await
}
// Changing the keys must not span any fragmented handshake
// messages. Otherwise the defragmented messages will have
// been protected with two different record layer protections,
@@ -931,32 +928,6 @@ impl CommonState {
self.sendable_tls.write_to_async(wr).await
}
/// Encrypt and send some plaintext `data`. `limit` controls
/// whether the per-connection buffer limits apply.
///
/// Returns the number of bytes written from `data`: this might
/// be less than `data.len()` if buffer limits were exceeded.
async fn send_plain(&mut self, data: &[u8], limit: Limit) -> Result<usize, Error> {
if !self.may_send_application_data {
// If we haven't completed handshaking, buffer
// plaintext to send once we do.
let len = match limit {
Limit::Yes => self.sendable_plaintext.append_limited_copy(data),
Limit::No => self.sendable_plaintext.append(data.to_vec()),
};
return Ok(len);
}
debug_assert!(self.record_layer.is_encrypting());
if data.is_empty() {
// Don't send empty fragments.
return Ok(0);
}
self.send_appdata_encrypt(data, limit).await
}
pub(crate) async fn start_outgoing_traffic(&mut self) -> Result<(), Error> {
self.may_send_application_data = true;
self.flush_plaintext().await
@@ -1012,15 +983,14 @@ impl CommonState {
self.sendable_tls.set_limit(limit);
}
/// Send any buffered plaintext. Plaintext is buffered if
/// written during handshake.
async fn flush_plaintext(&mut self) -> Result<(), Error> {
/// Send and encrypt any buffered plaintext. Does nothing during handshake.
pub async fn flush_plaintext(&mut self) -> Result<(), Error> {
if !self.may_send_application_data {
return Ok(());
}
while let Some(buf) = self.sendable_plaintext.pop() {
self.send_plain(&buf, Limit::No).await?;
self.send_appdata_encrypt(&buf, Limit::No).await?;
}
Ok(())

View File

@@ -35,6 +35,15 @@ impl ChunkVecBuffer {
self.chunks.is_empty()
}
/// If the buffer has reached limit.
pub(crate) fn is_full(&self) -> bool {
if let Some(limit) = self.limit {
self.len() >= limit
} else {
false
}
}
/// How many bytes we're storing
pub(crate) fn len(&self) -> usize {
let mut len = 0;

View File

@@ -247,7 +247,8 @@ async fn servered_client_data_sent() {
let (mut client, mut server) =
make_pair_for_arc_configs(&Arc::new(client_config), &server_config).await;
assert_eq!(5, client.write_plaintext(b"hello").await.unwrap());
assert_eq!(5, client.write_plaintext(b"hello").unwrap());
client.flush_plaintext().await.unwrap();
do_handshake(&mut client, &mut server).await;
send(&mut client, &mut server);
@@ -286,7 +287,7 @@ async fn servered_both_data_sent() {
make_pair_for_arc_configs(&Arc::new(client_config), &server_config).await;
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
do_handshake(&mut client, &mut server).await;
@@ -432,7 +433,7 @@ async fn server_close_notify() {
// check that alerts don't overtake appdata
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
server.send_close_notify();
receive(&mut server, &mut client);
@@ -460,7 +461,8 @@ async fn client_close_notify() {
// check that alerts don't overtake appdata
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
client.flush_plaintext().await.unwrap();
client.send_close_notify().await.unwrap();
send(&mut client, &mut server);
@@ -487,7 +489,7 @@ async fn server_closes_uncleanly() {
// check that unclean EOF reporting does not overtake appdata
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
receive(&mut server, &mut client);
transfer_eof(&mut client);
@@ -518,7 +520,7 @@ async fn client_closes_uncleanly() {
// check that unclean EOF reporting does not overtake appdata
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
client.process_new_packets().await.unwrap();
send(&mut client, &mut server);
@@ -900,20 +902,9 @@ async fn client_respects_buffer_limit_pre_handshake() {
client.set_buffer_limit(Some(32));
assert_eq!(
client
.write_plaintext(b"01234567890123456789")
.await
.unwrap(),
20
);
assert_eq!(
client
.write_plaintext(b"01234567890123456789")
.await
.unwrap(),
12
);
assert_eq!(client.write_plaintext(b"01234567890123456789").unwrap(), 20);
assert_eq!(client.write_plaintext(b"01234567890123456789").unwrap(), 12);
client.flush_plaintext().await.unwrap();
do_handshake(&mut client, &mut server).await;
send(&mut client, &mut server);
@@ -953,20 +944,9 @@ async fn client_respects_buffer_limit_post_handshake() {
do_handshake(&mut client, &mut server).await;
client.set_buffer_limit(Some(48));
assert_eq!(
client
.write_plaintext(b"01234567890123456789")
.await
.unwrap(),
20
);
assert_eq!(
client
.write_plaintext(b"01234567890123456789")
.await
.unwrap(),
6
);
assert_eq!(client.write_plaintext(b"01234567890123456789").unwrap(), 20);
assert_eq!(client.write_plaintext(b"01234567890123456789").unwrap(), 6);
client.flush_plaintext().await.unwrap();
send(&mut client, &mut server);
server.process_new_packets().unwrap();
@@ -1211,14 +1191,8 @@ async fn client_complete_io_for_write() {
do_handshake(&mut client, &mut server).await;
client
.write_plaintext(b"01234567890123456789")
.await
.unwrap();
client
.write_plaintext(b"01234567890123456789")
.await
.unwrap();
client.write_plaintext(b"01234567890123456789").unwrap();
client.write_plaintext(b"01234567890123456789").unwrap();
{
let mut pipe = ServerSession::new(&mut server);
let (rdlen, wrlen) = client
@@ -1350,7 +1324,8 @@ async fn server_stream_read() {
for kt in ALL_KEY_TYPES.iter() {
let (mut client, mut server) = make_pair(*kt).await;
client.write_all_plaintext(b"world").await.unwrap();
client.write_all_plaintext(b"world").unwrap();
client.process_new_packets().await.unwrap();
{
let mut pipe = ClientSession::new(&mut client);
@@ -1366,7 +1341,8 @@ async fn server_streamowned_read() {
for kt in ALL_KEY_TYPES.iter() {
let (mut client, server) = make_pair(*kt).await;
client.write_all_plaintext(b"world").await.unwrap();
client.write_all_plaintext(b"world").unwrap();
client.process_new_packets().await.unwrap();
{
let pipe = ClientSession::new(&mut client);
@@ -1385,7 +1361,9 @@ async fn server_streamowned_read() {
// errkind: io::ErrorKind::ConnectionAborted,
// after: 0,
// };
// client.write_all_plaintext(b"hello").await.unwrap();
// client.write_all_plaintext(b"hello").unwrap();
// client.process_new_packets().await.unwrap();
//
// let mut client_stream = Stream::new(&mut client, &mut pipe);
// let rc = client_stream.write(b"world");
// assert!(rc.is_err());
@@ -1402,7 +1380,9 @@ async fn server_streamowned_read() {
// errkind: io::ErrorKind::ConnectionAborted,
// after: 1,
// };
// client.write_all_plaintext(b"hello").await.unwrap();
// client.write_all_plaintext(b"hello").unwrap();
// client.process_new_packets().await.unwrap();
//
// let mut client_stream = Stream::new(&mut client, &mut pipe);
// let rc = client_stream.write(b"world");
// assert_eq!(format!("{:?}", rc), "Ok(5)");
@@ -1900,14 +1880,9 @@ async fn servered_write_for_client_appdata() {
let (mut client, mut server) = make_pair(KeyType::Rsa).await;
do_handshake(&mut client, &mut server).await;
client
.write_all_plaintext(b"01234567890123456789")
.await
.unwrap();
client
.write_all_plaintext(b"01234567890123456789")
.await
.unwrap();
client.write_all_plaintext(b"01234567890123456789").unwrap();
client.write_all_plaintext(b"01234567890123456789").unwrap();
client.process_new_packets().await.unwrap();
{
let mut pipe = ServerSession::new(&mut server);
let wrlen = client.write_tls(&mut pipe).unwrap();
@@ -2019,11 +1994,10 @@ async fn servered_write_for_server_handshake_no_half_rtt_by_default() {
async fn servered_write_for_client_handshake() {
let (mut client, mut server) = make_pair(KeyType::Rsa).await;
client
.write_all_plaintext(b"01234567890123456789")
.await
.unwrap();
client.write_all_plaintext(b"0123456789").await.unwrap();
client.write_all_plaintext(b"01234567890123456789").unwrap();
client.write_all_plaintext(b"0123456789").unwrap();
client.process_new_packets().await.unwrap();
{
let mut pipe = ServerSession::new(&mut server);
let wrlen = client.write_tls(&mut pipe).unwrap();

View File

@@ -21,11 +21,11 @@ tlsn-attestation = { workspace = true }
tlsn-core = { workspace = true }
tlsn-deap = { workspace = true }
tlsn-tls-client = { workspace = true }
tlsn-tls-client-async = { workspace = true }
tlsn-tls-core = { workspace = true }
tlsn-mpc-tls = { workspace = true }
tlsn-cipher = { workspace = true }
futures-plex = { workspace = true }
serio = { workspace = true, features = ["compat"] }
uid-mux = { workspace = true, features = ["serio"] }
web-spawn = { workspace = true, optional = true }

View File

@@ -22,6 +22,9 @@ use std::sync::LazyLock;
use semver::Version;
// Size for internal buffers.
const BUF_CAP: usize = 8 * 1024;
// Package version.
pub(crate) static VERSION: LazyLock<Version> = LazyLock::new(|| {
Version::parse(env!("CARGO_PKG_VERSION")).expect("cargo pkg version should be a valid semver")

View File

@@ -1,31 +1,31 @@
//! Prover.
mod client;
mod error;
mod future;
mod prove;
pub mod state;
pub use error::ProverError;
pub use future::ProverFuture;
pub use tlsn_core::ProverOutput;
use crate::{
Role,
BUF_CAP, Role,
context::build_mt_context,
mpz::{ProverDeps, build_prover_deps, translate_keys},
msg::{ProveRequestMsg, Response, TlsCommitRequestMsg},
mux::attach_mux,
tag::verify_tags,
prover::client::{MpcTlsClient, TlsOutput},
};
use futures::{AsyncRead, AsyncWrite, TryFutureExt};
use mpc_tls::LeaderCtrl;
use mpz_vm_core::prelude::*;
use futures::{FutureExt, TryFutureExt};
use rustls_pki_types::CertificateDer;
use serio::{SinkExt, stream::IoStreamExt};
use std::sync::Arc;
use std::{
io::{Read, Write},
sync::Arc,
task::{Context, Poll},
};
use tls_client::{ClientConnection, ServerName as TlsServerName};
use tls_client_async::{TlsConnection, bind_client};
use tlsn_core::{
config::{
prove::ProveConfig,
@@ -36,10 +36,9 @@ use tlsn_core::{
connection::{HandshakeData, ServerName},
transcript::{TlsTranscript, Transcript},
};
use tracing::{Span, debug, info_span, instrument};
use webpki::anchor_from_trusted_cert;
use tracing::{Instrument, Span, debug, info, info_span, instrument};
/// A prover instance.
#[derive(Debug)]
pub struct Prover<T: state::ProverState = state::Initialized> {
@@ -71,15 +70,16 @@ impl Prover<state::Initialized> {
/// # Arguments
///
/// * `config` - The TLS commitment configuration.
/// * `socket` - The socket to the TLS verifier.
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
pub async fn commit<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
pub async fn commit(
self,
config: TlsCommitConfig,
socket: S,
) -> Result<Prover<state::CommitAccepted>, ProverError> {
let (mut mux_fut, mux_ctrl) = attach_mux(socket, Role::Prover);
let (duplex_a, duplex_b) = futures_plex::duplex(BUF_CAP);
let (mut mux_fut, mux_ctrl) = attach_mux(duplex_b, Role::Prover);
let mut mt = build_mt_context(mux_ctrl.clone());
let mut ctx = mux_fut.poll_with(mt.new_context()).await?;
// Sends protocol configuration to verifier for compatibility check.
@@ -118,47 +118,48 @@ impl Prover<state::Initialized> {
debug!("mpc-tls setup complete");
Ok(Prover {
let prover = Prover {
config: self.config,
span: self.span,
state: state::CommitAccepted {
mpc_duplex: duplex_a,
mux_ctrl,
mux_fut,
mpc_tls,
keys,
vm,
},
})
};
Ok(prover)
}
}
impl Prover<state::CommitAccepted> {
/// Connects to the server using the provided socket.
/// Connects the prover.
///
/// Returns a handle to the TLS connection, a future which returns the
/// prover once the connection is closed and the TLS transcript is
/// committed.
/// Returns a connected prover, which can be used to read and write from/to
/// the active TLS connection.
///
/// # Arguments
///
/// * `config` - The TLS client configuration.
/// * `socket` - The socket to the server.
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
pub async fn connect<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
pub async fn connect(
self,
config: TlsClientConfig,
socket: S,
) -> Result<(TlsConnection, ProverFuture), ProverError> {
) -> Result<Prover<state::Connected>, ProverError> {
let state::CommitAccepted {
mpc_duplex,
mux_ctrl,
mut mux_fut,
mux_fut,
mpc_tls,
keys,
vm,
..
} = self.state;
let (mpc_ctrl, mpc_fut) = mpc_tls.run();
let decrypt = mpc_tls.is_decrypting();
let ServerName::Dns(server_name) = config.server_name();
let server_name =
@@ -195,102 +196,194 @@ impl Prover<state::CommitAccepted> {
rustls_config.with_no_client_auth()
};
let client = ClientConnection::new(
Arc::new(rustls_config),
Box::new(mpc_ctrl.clone()),
server_name,
)
.map_err(ProverError::config)?;
let client = ClientConnection::new(Arc::new(rustls_config), Box::new(mpc_tls), server_name)
.map_err(ProverError::config)?;
let (conn, conn_fut) = bind_client(socket, client);
let span = self.span.clone();
let fut = Box::pin({
let span = self.span.clone();
let mpc_ctrl = mpc_ctrl.clone();
async move {
let conn_fut = async {
mux_fut
.poll_with(conn_fut.map_err(ProverError::from))
.await?;
let mpc_tls = MpcTlsClient::new(keys, vm, span, client, decrypt);
mpc_ctrl.stop().await?;
Ok::<_, ProverError>(())
};
info!("starting MPC-TLS");
let (_, (mut ctx, tls_transcript)) = futures::try_join!(
conn_fut,
mpc_fut.in_current_span().map_err(ProverError::from)
)?;
info!("finished MPC-TLS");
{
let mut vm = vm.try_lock().expect("VM should not be locked");
debug!("finalizing mpc");
// Finalize DEAP.
mux_fut
.poll_with(vm.finalize(&mut ctx))
.await
.map_err(ProverError::mpc)?;
debug!("mpc finalized");
}
// Pull out ZK VM.
let (_, mut vm) = Arc::into_inner(vm)
.expect("vm should have only 1 reference")
.into_inner()
.into_inner();
// Prove tag verification of received records.
// The prover drops the proof output.
let _ = verify_tags(
&mut vm,
(keys.server_write_key, keys.server_write_iv),
keys.server_write_mac_key,
*tls_transcript.version(),
tls_transcript.recv().to_vec(),
)
.map_err(ProverError::zk)?;
mux_fut
.poll_with(vm.execute_all(&mut ctx).map_err(ProverError::zk))
.await?;
let transcript = tls_transcript
.to_transcript()
.expect("transcript is complete");
Ok(Prover {
config: self.config,
span: self.span,
state: state::Committed {
mux_ctrl,
mux_fut,
ctx,
vm,
server_name: config.server_name().clone(),
keys,
tls_transcript,
transcript,
},
})
}
.instrument(span)
});
Ok((
conn,
ProverFuture {
fut,
ctrl: ProverControl { mpc_ctrl },
let prover = Prover {
config: self.config,
span: self.span,
state: state::Connected {
mpc_duplex,
mux_ctrl,
mux_fut,
server_name: config.server_name().clone(),
tls_client: Box::new(mpc_tls),
output: None,
},
))
};
Ok(prover)
}
/// Writes bytes for the verifier into a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn write_mpc(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.read(buf)
}
/// Reads bytes for the prover from a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn read_mpc(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.write(buf)
}
}
impl Prover<state::Connected> {
/// Returns `true` if the prover wants to read TLS data from the server.
pub fn wants_read_tls(&self) -> bool {
self.state.tls_client.wants_read_tls()
}
/// Returns `true` if the prover wants to write TLS data to the server.
pub fn wants_write_tls(&self) -> bool {
self.state.tls_client.wants_write_tls()
}
/// Reads TLS data from the server.
///
/// # Arguments
///
/// * `buf` - The buffer to read the TLS data from.
pub fn read_tls(&mut self, buf: &[u8]) -> Result<usize, ProverError> {
self.state.tls_client.read_tls(buf)
}
/// Writes TLS data for the server into the provided buffer.
///
/// # Arguments
///
/// * `buf` - The buffer to write the TLS data to.
pub fn write_tls(&mut self, buf: &mut [u8]) -> Result<usize, ProverError> {
self.state.tls_client.write_tls(buf)
}
/// Returns `true` if the prover wants to read plaintext data.
pub fn wants_read(&self) -> bool {
self.state.tls_client.wants_read()
}
/// Returns `true` if the prover wants to write plaintext data.
pub fn wants_write(&self) -> bool {
self.state.tls_client.wants_write()
}
/// Reads plaintext data from the server into the provided buffer.
///
/// # Arguments
///
/// * `buf` - The buffer where the plaintext data gets written to.
pub fn read(&mut self, buf: &mut [u8]) -> Result<usize, ProverError> {
self.state.tls_client.read(buf)
}
/// Writes plaintext data to be sent to the server.
///
/// # Arguments
///
/// * `buf` - The buffer to read the plaintext data from.
pub fn write(&mut self, buf: &[u8]) -> Result<usize, ProverError> {
self.state.tls_client.write(buf)
}
/// Writes bytes for the verifier into a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn write_mpc(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.read(buf)
}
/// Reads bytes for the prover from a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn read_mpc(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.write(buf)
}
/// Closes the connection from the client side.
pub fn client_close(&mut self) -> Result<(), ProverError> {
self.state.tls_client.client_close()
}
/// Closes the connection from the server side.
pub fn server_close(&mut self) -> Result<(), ProverError> {
self.state.tls_client.server_close()
}
/// Enables or disables the decryption of data from the server until the
/// server has closed the connection.
///
/// # Arguments
///
/// * `enable` - Whether to enable or disable decryption.
pub fn enable_decryption(&mut self, enable: bool) -> Result<(), ProverError> {
self.state.tls_client.enable_decryption(enable)
}
/// Returns `true` if decryption of TLS traffic from the server is active.
pub fn is_decrypting(&self) -> bool {
self.state.tls_client.is_decrypting()
}
/// Polls the prover to make progress.
///
/// # Arguments
///
/// * `cx` - The async context.
pub fn poll(&mut self, cx: &mut Context) -> Poll<Result<(), ProverError>> {
let _ = self.state.mux_fut.poll_unpin(cx)?;
match self.state.tls_client.poll(cx)? {
Poll::Ready(output) => {
let _ = self.state.mux_fut.poll_unpin(cx)?;
self.state.output = Some(output);
Poll::Ready(Ok(()))
}
Poll::Pending => Poll::Pending,
}
}
/// Returns a committed prover after the TLS session has completed.
pub fn finish(self) -> Result<Prover<state::Committed>, ProverError> {
let TlsOutput {
ctx,
vm,
keys,
tls_transcript,
transcript,
} = self.state.output.ok_or(ProverError::state(
"prover has not yet closed the connection",
))?;
let prover = Prover {
config: self.config,
span: self.span,
state: state::Committed {
mpc_duplex: self.state.mpc_duplex,
mux_ctrl: self.state.mux_ctrl,
mux_fut: self.state.mux_fut,
ctx,
vm,
server_name: self.state.server_name,
keys,
tls_transcript,
transcript,
},
};
Ok(prover)
}
}
@@ -305,6 +398,24 @@ impl Prover<state::Committed> {
&self.state.transcript
}
/// Writes bytes for the verifier into a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn write_mpc(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.read(buf)
}
/// Reads bytes for the prover from a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn read_mpc(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.write(buf)
}
/// Proves information to the verifier.
///
/// # Arguments
@@ -367,41 +478,19 @@ impl Prover<state::Committed> {
#[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn close(self) -> Result<(), ProverError> {
let state::Committed {
mux_ctrl, mux_fut, ..
mut mpc_duplex,
mux_ctrl,
mux_fut,
..
} = self.state;
// Wait for the verifier to correctly close the connection.
if !mux_fut.is_complete() {
mux_ctrl.close();
mux_fut.await?;
futures::AsyncWriteExt::close(&mut mpc_duplex).await?;
}
Ok(())
}
}
/// A controller for the prover.
#[derive(Clone)]
pub struct ProverControl {
mpc_ctrl: LeaderCtrl,
}
impl ProverControl {
/// Defers decryption of data from the server until the server has closed
/// the connection.
///
/// This is a performance optimization which will significantly reduce the
/// amount of upload bandwidth used by the prover.
///
/// # Notes
///
/// * The prover may need to close the connection to the server in order for
/// it to close the connection on its end. If neither the prover or server
/// close the connection this will cause a deadlock.
pub async fn defer_decryption(&self) -> Result<(), ProverError> {
self.mpc_ctrl
.defer_decryption()
.await
.map_err(ProverError::from)
}
}

View File

@@ -0,0 +1,63 @@
//! Provides a TLS client.
use crate::mpz::ProverZk;
use mpc_tls::SessionKeys;
use std::task::{Context, Poll};
use tlsn_core::transcript::{TlsTranscript, Transcript};
mod mpc;
pub(crate) use mpc::MpcTlsClient;
/// TLS client for MPC and proxy-based TLS implementations.
pub(crate) trait TlsClient {
type Error: std::error::Error + Send + Sync + Unpin + 'static;
/// Returns `true` if the client wants to read TLS data from the server.
fn wants_read_tls(&self) -> bool;
/// Returns `true` if the client wants to write TLS data to the server.
fn wants_write_tls(&self) -> bool;
/// Reads TLS data from the server.
fn read_tls(&mut self, buf: &[u8]) -> Result<usize, Self::Error>;
/// Writes TLS data for the server into the provided buffer.
fn write_tls(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error>;
/// Returns `true` if the client wants to read plaintext data.
fn wants_read(&self) -> bool;
/// Returns `true` if the client wants to write plaintext data.
fn wants_write(&self) -> bool;
/// Reads plaintext data from the server into the provided buffer.
fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error>;
/// Writes plaintext data to be sent to the server.
fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error>;
/// Client closes the connection.
fn client_close(&mut self) -> Result<(), Self::Error>;
/// Server closes the connection.
fn server_close(&mut self) -> Result<(), Self::Error>;
/// Enables or disables decryption of TLS traffic sent by the server.
fn enable_decryption(&mut self, enable: bool) -> Result<(), Self::Error>;
/// Returns `true` if decryption of TLS traffic from the server is active.
fn is_decrypting(&self) -> bool;
/// Polls the client to make progress.
fn poll(&mut self, cx: &mut Context) -> Poll<Result<TlsOutput, Self::Error>>;
}
/// Output of a TLS session.
pub(crate) struct TlsOutput {
pub(crate) ctx: mpz_common::Context,
pub(crate) vm: ProverZk,
pub(crate) keys: SessionKeys,
pub(crate) tls_transcript: TlsTranscript,
pub(crate) transcript: Transcript,
}

View File

@@ -0,0 +1,391 @@
//! Implementation of an MPC-TLS client.
use crate::{
mpz::{ProverMpc, ProverZk},
prover::{
ProverError,
client::{TlsClient, TlsOutput},
},
tag::verify_tags,
};
use futures::{Future, FutureExt};
use mpc_tls::{MpcTlsLeader, SessionKeys};
use mpz_common::Context;
use mpz_vm_core::Execute;
use std::{collections::VecDeque, pin::Pin, sync::Arc, task::Poll};
use tls_client::ClientConnection;
use tlsn_core::transcript::TlsTranscript;
use tlsn_deap::Deap;
use tokio::sync::Mutex;
use tracing::{Span, debug, instrument, trace, warn};
pub(crate) type MpcFuture =
Box<dyn Future<Output = Result<(Context, TlsTranscript), ProverError>> + Send>;
type FinalizeFuture =
Box<dyn Future<Output = Result<(InnerState, Context, TlsTranscript), ProverError>> + Send>;
pub(crate) struct MpcTlsClient {
state: State,
decrypt: bool,
cmds: VecDeque<Command>,
}
#[derive(Debug, Clone, Copy)]
pub(crate) enum Command {
ClientClose,
ServerClose,
Decrypt(bool),
}
enum State {
Start {
inner: Box<InnerState>,
},
Active {
inner: Box<InnerState>,
},
Busy {
fut: Pin<Box<dyn Future<Output = Result<Box<InnerState>, ProverError>> + Send>>,
},
CloseActive {
inner: Box<InnerState>,
},
CloseBusy {
fut: Pin<Box<dyn Future<Output = Result<Box<InnerState>, ProverError>> + Send>>,
},
Finalizing {
fut: Pin<FinalizeFuture>,
},
Finished,
Error,
}
impl MpcTlsClient {
pub(crate) fn new(
keys: SessionKeys,
vm: Arc<Mutex<Deap<ProverMpc, ProverZk>>>,
span: Span,
tls: ClientConnection,
) -> Self {
let inner = InnerState {
span,
tls,
vm,
keys,
mpc_stopped: false,
};
let decrypt = tls.backend().is_decrypting();
Self {
state: State::Start {
inner: Box::new(inner),
},
decrypt,
cmds: VecDeque::default(),
}
}
fn inner_client_mut(&mut self) -> Option<&mut ClientConnection> {
if let State::Active { inner, .. } | State::CloseActive { inner, .. } = &mut self.state {
Some(&mut inner.tls)
} else {
None
}
}
fn inner_client(&self) -> Option<&ClientConnection> {
if let State::Active { inner, .. } | State::CloseActive { inner, .. } = &self.state {
Some(&inner.tls)
} else {
None
}
}
}
impl TlsClient for MpcTlsClient {
type Error = ProverError;
fn wants_read_tls(&self) -> bool {
if let Some(client) = self.inner_client() {
client.wants_read()
} else {
false
}
}
fn wants_write_tls(&self) -> bool {
if let Some(client) = self.inner_client() {
client.wants_write()
} else {
false
}
}
fn read_tls(&mut self, mut buf: &[u8]) -> Result<usize, Self::Error> {
if let Some(client) = self.inner_client_mut()
&& client.wants_read()
{
client.read_tls(&mut buf).map_err(ProverError::from)
} else {
Ok(0)
}
}
fn write_tls(&mut self, mut buf: &mut [u8]) -> Result<usize, Self::Error> {
if let Some(client) = self.inner_client_mut()
&& client.wants_write()
{
client.write_tls(&mut buf).map_err(ProverError::from)
} else {
Ok(0)
}
}
fn wants_read(&self) -> bool {
if let Some(client) = self.inner_client() {
!client.plaintext_is_empty()
} else {
false
}
}
fn wants_write(&self) -> bool {
if let Some(client) = self.inner_client() {
!client.sendable_plaintext_is_full()
} else {
false
}
}
fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
if let Some(client) = self.inner_client_mut()
&& !client.plaintext_is_empty()
{
client.read_plaintext(buf).map_err(ProverError::from)
} else {
Ok(0)
}
}
fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
if let Some(client) = self.inner_client_mut()
&& !client.sendable_plaintext_is_full()
{
client.write_plaintext(buf).map_err(ProverError::from)
} else {
Ok(0)
}
}
fn client_close(&mut self) -> Result<(), Self::Error> {
self.cmds.push_back(Command::ClientClose);
Ok(())
}
fn server_close(&mut self) -> Result<(), Self::Error> {
self.cmds.push_back(Command::ServerClose);
Ok(())
}
fn enable_decryption(&mut self, enable: bool) -> Result<(), Self::Error> {
self.cmds.push_back(Command::Decrypt(enable));
Ok(())
}
fn is_decrypting(&self) -> bool {
self.decrypt
}
fn poll(&mut self, cx: &mut std::task::Context) -> Poll<Result<TlsOutput, Self::Error>> {
match std::mem::replace(&mut self.state, State::Error) {
State::Start { inner } => {
trace!("inner client is starting");
self.state = State::Busy {
fut: Box::pin(inner.start()),
};
self.poll(cx)
}
State::Active { mut inner } => {
trace!("inner client is active");
if !inner.tls.is_handshaking()
&& let Some(cmd) = self.cmds.pop_front()
{
match cmd {
Command::ClientClose => {
self.state = State::Busy {
fut: Box::pin(inner.client_close()),
};
}
Command::ServerClose => {
self.state = State::CloseBusy {
fut: Box::pin(inner.server_close()),
};
}
Command::Decrypt(enable) => {
inner.tls.backend_mut().enable_decryption(enable)?;
self.decrypt = enable;
self.state = State::Busy {
fut: Box::pin(inner.run()),
};
}
}
} else {
self.state = State::Busy {
fut: Box::pin(inner.run()),
};
}
self.poll(cx)
}
State::Busy { mut fut } => {
trace!("inner client is busy");
match fut.as_mut().poll(cx)? {
Poll::Ready(inner) => {
self.state = State::Active { inner };
}
Poll::Pending => self.state = State::Busy { fut },
}
Poll::Pending
}
State::CloseActive { mut inner } => {
trace!("inner client is close active");
if let Some((ctx, transcript)) = inner.tls.backend_mut().finish() {
self.state = State::Finalizing {
fut: Box::pin(inner.finalize(ctx, transcript)),
};
} else {
self.state = State::CloseBusy {
fut: Box::pin(inner.server_close()),
};
}
self.poll(cx)
}
State::CloseBusy { mut fut } => {
trace!("inner client is busy closing");
match fut.as_mut().poll(cx)? {
Poll::Ready(inner) => {
self.state = State::CloseActive { inner };
}
Poll::Pending => self.state = State::CloseBusy { fut },
}
Poll::Pending
}
State::Finalizing { mut fut } => match fut.poll_unpin(cx) {
Poll::Ready(output) => {
let (inner, ctx, tls_transcript) = output?;
let InnerState { vm, keys, .. } = inner;
let transcript = tls_transcript
.to_transcript()
.expect("transcript is complete");
let (_, vm) = Arc::into_inner(vm)
.expect("vm should have only 1 reference")
.into_inner()
.into_inner();
let output = TlsOutput {
ctx,
vm,
keys,
tls_transcript,
transcript,
};
self.state = State::Finished;
Poll::Ready(Ok(output))
}
Poll::Pending => {
self.state = State::Finalizing { fut };
Poll::Pending
}
},
State::Finished => Poll::Ready(Err(ProverError::state(
"mpc tls client polled again in finished state",
))),
State::Error => {
Poll::Ready(Err(ProverError::state("mpc tls client is in error state")))
}
}
}
}
struct InnerState {
span: Span,
tls: ClientConnection,
vm: Arc<Mutex<Deap<ProverMpc, ProverZk>>>,
keys: SessionKeys,
mpc_stopped: bool,
}
impl InnerState {
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
async fn start(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
self.tls.start().await?;
Ok(self)
}
#[instrument(parent = &self.span, level = "trace", skip_all, err)]
async fn run(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
self.tls.process_new_packets().await?;
Ok(self)
}
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
async fn client_close(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
debug!("sending close notify");
if let Err(e) = self.tls.send_close_notify().await {
warn!("failed to send close_notify to server: {}", e);
}
Ok(self)
}
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
async fn server_close(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
self.tls.process_new_packets().await?;
if !self.mpc_stopped && self.tls.plaintext_is_empty() && self.tls.is_empty().await? {
self.tls.server_closed().await?;
self.mpc_stopped = true;
debug!("closed connection serverside");
}
Ok(self)
}
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
async fn finalize(
self,
mut ctx: Context,
transcript: TlsTranscript,
) -> Result<(Self, Context, TlsTranscript), ProverError> {
{
let mut vm = self.vm.try_lock().expect("VM should not be locked");
// Finalize DEAP.
vm.finalize(&mut ctx).await.map_err(ProverError::mpc)?;
debug!("mpc finalized");
// Pull out ZK VM.
let mut zk = vm.zk();
// Prove tag verification of received records.
// The prover drops the proof output.
let _ = verify_tags(
&mut *zk,
(self.keys.server_write_key, self.keys.server_write_iv),
self.keys.server_write_mac_key,
*transcript.version(),
transcript.recv().to_vec(),
)
.map_err(ProverError::zk)?;
debug!("verified tags from server");
zk.execute_all(&mut ctx).await.map_err(ProverError::zk)?
}
debug!("MPC-TLS done");
Ok((self, ctx, transcript))
}
}

View File

@@ -49,6 +49,13 @@ impl ProverError {
{
Self::new(ErrorKind::Commit, source)
}
pub(crate) fn state<E>(source: E) -> Self
where
E: Into<Box<dyn Error + Send + Sync + 'static>>,
{
Self::new(ErrorKind::State, source)
}
}
#[derive(Debug)]
@@ -58,6 +65,7 @@ enum ErrorKind {
Zk,
Config,
Commit,
State,
}
impl fmt::Display for ProverError {
@@ -70,6 +78,7 @@ impl fmt::Display for ProverError {
ErrorKind::Zk => f.write_str("zk error")?,
ErrorKind::Config => f.write_str("config error")?,
ErrorKind::Commit => f.write_str("commit error")?,
ErrorKind::State => f.write_str("state error")?,
}
if let Some(source) = &self.source {
@@ -86,8 +95,8 @@ impl From<std::io::Error> for ProverError {
}
}
impl From<tls_client_async::ConnectionError> for ProverError {
fn from(e: tls_client_async::ConnectionError) -> Self {
impl From<tls_client::Error> for ProverError {
fn from(e: tls_client::Error) -> Self {
Self::new(ErrorKind::Io, e)
}
}
@@ -115,3 +124,9 @@ impl From<EncodingError> for ProverError {
Self::new(ErrorKind::Commit, e)
}
}
impl From<ProverError> for std::io::Error {
fn from(value: ProverError) -> Self {
Self::other(value)
}
}

View File

@@ -1,32 +0,0 @@
//! This module collects futures which are used by the [Prover].
use super::{Prover, ProverControl, ProverError, state};
use futures::Future;
use std::pin::Pin;
/// Prover future which must be polled for the TLS connection to make progress.
pub struct ProverFuture {
#[allow(clippy::type_complexity)]
pub(crate) fut: Pin<
Box<dyn Future<Output = Result<Prover<state::Committed>, ProverError>> + Send + 'static>,
>,
pub(crate) ctrl: ProverControl,
}
impl ProverFuture {
/// Returns a controller for the prover for advanced functionality.
pub fn control(&self) -> ProverControl {
self.ctrl.clone()
}
}
impl Future for ProverFuture {
type Output = Result<Prover<state::Committed>, ProverError>;
fn poll(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
self.fut.as_mut().poll(cx)
}
}

View File

@@ -2,6 +2,7 @@
use std::sync::Arc;
use futures_plex::DuplexStream;
use mpc_tls::{MpcTlsLeader, SessionKeys};
use mpz_common::Context;
use tlsn_core::{
@@ -14,6 +15,10 @@ use tokio::sync::Mutex;
use crate::{
mpz::{ProverMpc, ProverZk},
mux::{MuxControl, MuxFuture},
prover::{
ProverError,
client::{TlsClient, TlsOutput},
},
};
/// Entry state
@@ -24,6 +29,7 @@ opaque_debug::implement!(Initialized);
/// State after the verifier has accepted the proposed TLS commitment protocol
/// configuration and preprocessing has completed.
pub struct CommitAccepted {
pub(crate) mpc_duplex: DuplexStream,
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) mpc_tls: MpcTlsLeader,
@@ -33,8 +39,21 @@ pub struct CommitAccepted {
opaque_debug::implement!(CommitAccepted);
/// State during the MPC-TLS connection.
pub struct Connected {
pub(crate) mpc_duplex: DuplexStream,
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) server_name: ServerName,
pub(crate) tls_client: Box<dyn TlsClient<Error = ProverError> + Send>,
pub(crate) output: Option<TlsOutput>,
}
opaque_debug::implement!(Connected);
/// State after the TLS transcript has been committed.
pub struct Committed {
pub(crate) mpc_duplex: DuplexStream,
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) ctx: Context,
@@ -52,11 +71,13 @@ pub trait ProverState: sealed::Sealed {}
impl ProverState for Initialized {}
impl ProverState for CommitAccepted {}
impl ProverState for Connected {}
impl ProverState for Committed {}
mod sealed {
pub trait Sealed {}
impl Sealed for super::Initialized {}
impl Sealed for super::CommitAccepted {}
impl Sealed for super::Connected {}
impl Sealed for super::Committed {}
}

View File

@@ -10,16 +10,17 @@ pub use error::VerifierError;
pub use tlsn_core::{VerifierOutput, webpki::ServerCertVerifier};
use crate::{
Role,
BUF_CAP, Role,
context::build_mt_context,
mpz::{VerifierDeps, build_verifier_deps, translate_keys},
msg::{ProveRequestMsg, Response, TlsCommitRequestMsg},
mux::attach_mux,
tag::verify_tags,
};
use futures::{AsyncRead, AsyncWrite, TryFutureExt};
use futures::TryFutureExt;
use mpz_vm_core::prelude::*;
use serio::{SinkExt, stream::IoStreamExt};
use std::io::{Read, Write};
use tlsn_core::{
config::{
prove::ProveRequest,
@@ -68,11 +69,10 @@ impl Verifier<state::Initialized> {
///
/// * `socket` - The socket to the prover.
#[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn commit<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
self,
socket: S,
) -> Result<Verifier<state::CommitStart>, VerifierError> {
let (mut mux_fut, mux_ctrl) = attach_mux(socket, Role::Verifier);
pub async fn commit(self) -> Result<Verifier<state::CommitStart>, VerifierError> {
let (duplex_a, duplex_b) = futures_plex::duplex(BUF_CAP);
let (mut mux_fut, mux_ctrl) = attach_mux(duplex_b, Role::Verifier);
let mut mt = build_mt_context(mux_ctrl.clone());
let mut ctx = mux_fut.poll_with(mt.new_context()).await?;
@@ -102,6 +102,7 @@ impl Verifier<state::Initialized> {
config: self.config,
span: self.span,
state: state::CommitStart {
mpc_duplex: duplex_a,
mux_ctrl,
mux_fut,
ctx,
@@ -121,6 +122,7 @@ impl Verifier<state::CommitStart> {
#[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn accept(self) -> Result<Verifier<state::CommitAccepted>, VerifierError> {
let state::CommitStart {
mpc_duplex,
mux_ctrl,
mut mux_fut,
mut ctx,
@@ -151,6 +153,7 @@ impl Verifier<state::CommitStart> {
config: self.config,
span: self.span,
state: state::CommitAccepted {
mpc_duplex,
mux_ctrl,
mux_fut,
mpc_tls,
@@ -182,6 +185,24 @@ impl Verifier<state::CommitStart> {
Ok(())
}
/// Writes bytes for the prover into a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn write_mpc(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.read(buf)
}
/// Reads bytes for the verifier from a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn read_mpc(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.write(buf)
}
}
impl Verifier<state::CommitAccepted> {
@@ -189,6 +210,7 @@ impl Verifier<state::CommitAccepted> {
#[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn run(self) -> Result<Verifier<state::Committed>, VerifierError> {
let state::CommitAccepted {
mpc_duplex,
mux_ctrl,
mut mux_fut,
mpc_tls,
@@ -245,6 +267,7 @@ impl Verifier<state::CommitAccepted> {
config: self.config,
span: self.span,
state: state::Committed {
mpc_duplex,
mux_ctrl,
mux_fut,
ctx,
@@ -254,6 +277,24 @@ impl Verifier<state::CommitAccepted> {
},
})
}
/// Writes bytes for the prover into a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn write_mpc(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.read(buf)
}
/// Reads bytes for the verifier from a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn read_mpc(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.write(buf)
}
}
impl Verifier<state::Committed> {
@@ -266,6 +307,7 @@ impl Verifier<state::Committed> {
#[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn verify(self) -> Result<Verifier<state::Verify>, VerifierError> {
let state::Committed {
mpc_duplex,
mux_ctrl,
mut mux_fut,
mut ctx,
@@ -286,6 +328,7 @@ impl Verifier<state::Committed> {
config: self.config,
span: self.span,
state: state::Verify {
mpc_duplex,
mux_ctrl,
mux_fut,
ctx,
@@ -299,17 +342,39 @@ impl Verifier<state::Committed> {
})
}
/// Writes bytes for the prover into a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn write_mpc(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.read(buf)
}
/// Reads bytes for the verifier from a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn read_mpc(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.write(buf)
}
/// Closes the connection with the prover.
#[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn close(self) -> Result<(), VerifierError> {
let state::Committed {
mux_ctrl, mux_fut, ..
mut mpc_duplex,
mux_ctrl,
mux_fut,
..
} = self.state;
// Wait for the prover to correctly close the connection.
if !mux_fut.is_complete() {
mux_ctrl.close();
mux_fut.await?;
futures::AsyncWriteExt::close(&mut mpc_duplex).await?;
}
Ok(())
@@ -327,6 +392,7 @@ impl Verifier<state::Verify> {
self,
) -> Result<(VerifierOutput, Verifier<state::Committed>), VerifierError> {
let state::Verify {
mpc_duplex,
mux_ctrl,
mut mux_fut,
mut ctx,
@@ -362,6 +428,7 @@ impl Verifier<state::Verify> {
config: self.config,
span: self.span,
state: state::Committed {
mpc_duplex,
mux_ctrl,
mux_fut,
ctx,
@@ -379,6 +446,7 @@ impl Verifier<state::Verify> {
msg: Option<&str>,
) -> Result<Verifier<state::Committed>, VerifierError> {
let state::Verify {
mpc_duplex,
mux_ctrl,
mut mux_fut,
mut ctx,
@@ -396,6 +464,7 @@ impl Verifier<state::Verify> {
config: self.config,
span: self.span,
state: state::Committed {
mpc_duplex,
mux_ctrl,
mux_fut,
ctx,
@@ -405,4 +474,22 @@ impl Verifier<state::Verify> {
},
})
}
/// Writes bytes for the prover into a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn write_mpc(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.read(buf)
}
/// Reads bytes for the verifier from a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn read_mpc(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.write(buf)
}
}

View File

@@ -3,6 +3,7 @@
use std::sync::Arc;
use crate::mux::{MuxControl, MuxFuture};
use futures_plex::DuplexStream;
use mpc_tls::{MpcTlsFollower, SessionKeys};
use mpz_common::Context;
use tlsn_core::{
@@ -25,6 +26,7 @@ opaque_debug::implement!(Initialized);
/// State after receiving protocol configuration from the prover.
pub struct CommitStart {
pub(crate) mpc_duplex: DuplexStream,
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) ctx: Context,
@@ -36,6 +38,7 @@ opaque_debug::implement!(CommitStart);
/// State after accepting the proposed TLS commitment protocol configuration and
/// performing preprocessing.
pub struct CommitAccepted {
pub(crate) mpc_duplex: DuplexStream,
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) mpc_tls: MpcTlsFollower,
@@ -47,6 +50,7 @@ opaque_debug::implement!(CommitAccepted);
/// State after the TLS transcript has been committed.
pub struct Committed {
pub(crate) mpc_duplex: DuplexStream,
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) ctx: Context,
@@ -59,6 +63,7 @@ opaque_debug::implement!(Committed);
/// State after receiving a proving request.
pub struct Verify {
pub(crate) mpc_duplex: DuplexStream,
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) ctx: Context,

View File

@@ -23,7 +23,6 @@ no-bundler = ["web-spawn/no-bundler"]
tlsn-core = { workspace = true }
tlsn = { workspace = true, features = ["web", "mozilla-certs"] }
tlsn-server-fixture-certs = { workspace = true }
tlsn-tls-client-async = { workspace = true }
tlsn-tls-core = { workspace = true }
bincode = { workspace = true }