Compare commits

..

4 Commits

Author SHA1 Message Date
Hendrik Eeckhaut
b76775fc7c correction + legend placement 2025-12-23 15:11:35 +01:00
Hendrik Eeckhaut
72041d1f07 export dark svg 2025-12-23 14:47:19 +01:00
Hendrik Eeckhaut
ac1df8fc75 Allow plotting multiple data runs 2025-12-23 14:31:54 +01:00
Hendrik Eeckhaut
3cb7c5c0b4 Working on benchmark plots 2025-12-23 14:07:39 +01:00
56 changed files with 4033 additions and 2616 deletions

1705
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -13,6 +13,7 @@ members = [
"crates/server-fixture/server",
"crates/tls/backend",
"crates/tls/client",
"crates/tls/client-async",
"crates/tls/core",
"crates/mpc-tls",
"crates/tls/server-fixture",
@@ -56,6 +57,7 @@ tlsn-server-fixture = { path = "crates/server-fixture/server" }
tlsn-server-fixture-certs = { path = "crates/server-fixture/certs" }
tlsn-tls-backend = { path = "crates/tls/backend" }
tlsn-tls-client = { path = "crates/tls/client" }
tlsn-tls-client-async = { path = "crates/tls/client-async" }
tlsn-tls-core = { path = "crates/tls/core" }
tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6168663" }
tlsn-harness-core = { path = "crates/harness/core" }
@@ -80,7 +82,6 @@ mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-ideal-vm = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
futures-plex = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "c210f2f" }
rangeset = { version = "0.4" }
serio = { version = "0.2" }
spansy = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6f1a934" }

View File

@@ -87,15 +87,13 @@ async fn main() -> Result<()> {
}
async fn prover<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
verifier_socket: S,
socket: S,
req_tx: Sender<AttestationRequest>,
resp_rx: Receiver<Attestation>,
uri: &str,
extra_headers: Vec<(&str, &str)>,
example_type: &ExampleType,
) -> Result<()> {
let mut verifier_socket = verifier_socket.compat();
let server_host: String = env::var("SERVER_HOST").unwrap_or("127.0.0.1".into());
let server_port: u16 = env::var("SERVER_PORT")
.map(|port| port.parse().expect("port should be valid integer"))
@@ -117,36 +115,37 @@ async fn prover<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
.build()?,
)
.build()?,
&mut verifier_socket,
socket.compat(),
)
.await?;
// Open a TCP connection to the server.
let client_socket = tokio::net::TcpStream::connect((server_host, server_port))
.await?
.compat();
let client_socket = tokio::net::TcpStream::connect((server_host, server_port)).await?;
// Bind the prover to the server connection.
let (tls_connection, prover) = prover.setup(
TlsClientConfig::builder()
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?))
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
// (Optional) Set up TLS client authentication if required by the server.
.client_auth((
vec![CertificateDer(CLIENT_CERT_DER.to_vec())],
PrivateKeyDer(CLIENT_KEY_DER.to_vec()),
))
.build()?,
)?;
let (tls_connection, prover_fut) = prover
.connect(
TlsClientConfig::builder()
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?))
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
// (Optional) Set up TLS client authentication if required by the server.
.client_auth((
vec![CertificateDer(CLIENT_CERT_DER.to_vec())],
PrivateKeyDer(CLIENT_KEY_DER.to_vec()),
))
.build()?,
client_socket.compat(),
)
.await?;
let tls_connection = TokioIo::new(tls_connection.compat());
// Spawn the prover task to be run concurrently in the background.
let prover_task = tokio::spawn(prover.run(client_socket, verifier_socket));
let prover_task = tokio::spawn(prover_fut);
// Attach the hyper HTTP client to the connection.
let (mut request_sender, connection) =
@@ -181,7 +180,7 @@ async fn prover<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
assert!(response.status() == StatusCode::OK);
// The prover task should be done now, so we can await it.
let (prover, _, verifier_socket) = prover_task.await??;
let prover = prover_task.await??;
// Parse the HTTP transcript.
let transcript = HttpTranscript::parse(prover.transcript())?;
@@ -223,8 +222,7 @@ async fn prover<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
let request_config = builder.build()?;
let (attestation, secrets) =
notarize(prover, &request_config, verifier_socket, req_tx, resp_rx).await?;
let (attestation, secrets) = notarize(prover, &request_config, req_tx, resp_rx).await?;
// Write the attestation to disk.
let attestation_path = tlsn_examples::get_file_path(example_type, "attestation");
@@ -244,10 +242,9 @@ async fn prover<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
Ok(())
}
async fn notarize<S: futures::AsyncRead + futures::AsyncWrite + Send + Unpin>(
async fn notarize(
mut prover: Prover<Committed>,
config: &RequestConfig,
mut verifier_socket: S,
request_tx: Sender<AttestationRequest>,
attestation_rx: Receiver<Attestation>,
) -> Result<(Attestation, Secrets)> {
@@ -263,13 +260,11 @@ async fn notarize<S: futures::AsyncRead + futures::AsyncWrite + Send + Unpin>(
transcript_commitments,
transcript_secrets,
..
} = prover
.prove(&disclosure_config, &mut verifier_socket)
.await?;
} = prover.prove(&disclosure_config).await?;
let transcript = prover.transcript().clone();
let tls_transcript = prover.tls_transcript().clone();
prover.close(&mut verifier_socket).await?;
prover.close().await?;
// Build an attestation request.
let mut builder = AttestationRequest::builder(config);
@@ -312,12 +307,10 @@ async fn notarize<S: futures::AsyncRead + futures::AsyncWrite + Send + Unpin>(
}
async fn notary<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
prover_socket: S,
socket: S,
request_rx: Receiver<AttestationRequest>,
attestation_tx: Sender<Attestation>,
) -> Result<()> {
let mut prover_socket = prover_socket.compat();
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
@@ -329,11 +322,11 @@ async fn notary<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
.unwrap();
let verifier = Verifier::new(verifier_config)
.commit(&mut prover_socket)
.commit(socket.compat())
.await?
.accept(&mut prover_socket)
.accept()
.await?
.run(&mut prover_socket)
.run()
.await?;
let (
@@ -343,15 +336,11 @@ async fn notary<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
..
},
verifier,
) = verifier
.verify(&mut prover_socket)
.await?
.accept(&mut prover_socket)
.await?;
) = verifier.verify().await?.accept().await?;
let tls_transcript = verifier.tls_transcript().clone();
verifier.close(&mut prover_socket).await?;
verifier.close().await?;
let sent_len = tls_transcript
.sent()

View File

@@ -73,8 +73,6 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
server_addr: &SocketAddr,
uri: &str,
) -> Result<()> {
let mut verifier_socket = verifier_socket.compat();
let uri = uri.parse::<Uri>().unwrap();
assert_eq!(uri.scheme().unwrap().as_str(), "https");
let server_domain = uri.authority().unwrap().host();
@@ -95,30 +93,32 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
.build()?,
)
.build()?,
&mut verifier_socket,
verifier_socket.compat(),
)
.await?;
// Open a TCP connection to the server.
let client_socket = tokio::net::TcpStream::connect(server_addr).await?.compat();
let client_socket = tokio::net::TcpStream::connect(server_addr).await?;
// Bind the prover to the server connection.
let (tls_connection, prover) = prover.setup(
TlsClientConfig::builder()
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?))
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.build()?,
)?;
let (tls_connection, prover_fut) = prover
.connect(
TlsClientConfig::builder()
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?))
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.build()?,
client_socket.compat(),
)
.await?;
let tls_connection = TokioIo::new(tls_connection.compat());
// Spawn the Prover to run in the background.
let prover_task = tokio::spawn(prover.run(client_socket, verifier_socket));
let prover_task = tokio::spawn(prover_fut);
// MPC-TLS Handshake.
let (mut request_sender, connection) =
@@ -140,7 +140,7 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
assert!(response.status() == StatusCode::OK);
// Create proof for the Verifier.
let (mut prover, _, mut verifier_socket) = prover_task.await??;
let mut prover = prover_task.await??;
let mut builder = ProveConfig::builder(prover.transcript());
@@ -173,8 +173,8 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
let config = builder.build()?;
prover.prove(&config, &mut verifier_socket).await?;
prover.close(&mut verifier_socket).await?;
prover.prove(&config).await?;
prover.close().await?;
Ok(())
}
@@ -183,8 +183,6 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
socket: T,
) -> Result<PartialTranscript> {
let mut socket = socket.compat();
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
@@ -196,7 +194,7 @@ async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
let verifier = Verifier::new(verifier_config);
// Validate the proposed configuration and then run the TLS commitment protocol.
let verifier = verifier.commit(&mut socket).await?;
let verifier = verifier.commit(socket.compat()).await?;
// This is the opportunity to ensure the prover does not attempt to overload the
// verifier.
@@ -214,21 +212,21 @@ async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
};
if reject.is_some() {
verifier.reject(&mut socket, reject).await?;
verifier.reject(reject).await?;
return Err(anyhow::anyhow!("protocol configuration rejected"));
}
// Runs the TLS commitment protocol to completion.
let verifier = verifier.accept(&mut socket).await?.run(&mut socket).await?;
let verifier = verifier.accept().await?.run().await?;
// Validate the proving request and then verify.
let verifier = verifier.verify(&mut socket).await?;
let verifier = verifier.verify().await?;
if !verifier.request().server_identity() {
let verifier = verifier
.reject(&mut socket, Some("expecting to verify the server name"))
.reject(Some("expecting to verify the server name"))
.await?;
verifier.close(&mut socket).await?;
verifier.close().await?;
return Err(anyhow::anyhow!("prover did not reveal the server name"));
}
@@ -239,9 +237,9 @@ async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
..
},
verifier,
) = verifier.accept(&mut socket).await?;
) = verifier.accept().await?;
verifier.close(&mut socket).await?;
verifier.close().await?;
let server_name = server_name.expect("prover should have revealed server name");
let transcript = transcript.expect("prover should have revealed transcript data");

View File

@@ -31,10 +31,11 @@ async fn main() -> Result<()> {
// Connect prover and verifier.
let (prover_socket, verifier_socket) = tokio::io::duplex(1 << 23);
let (prover_extra_socket, verifier_extra_socket) = tokio::io::duplex(1 << 23);
let (_, transcript) = tokio::try_join!(
prover(prover_socket, &server_addr, &uri),
verifier(verifier_socket)
prover(prover_socket, prover_extra_socket, &server_addr, &uri),
verifier(verifier_socket, verifier_extra_socket)
)?;
println!("---");

View File

@@ -46,15 +46,15 @@ use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
use tracing::instrument;
#[instrument(skip(verifier_socket))]
#[instrument(skip(verifier_socket, verifier_extra_socket))]
pub async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
verifier_socket: T,
mut verifier_extra_socket: T,
server_addr: &SocketAddr,
uri: &str,
) -> Result<()> {
let mut verifier_socket = verifier_socket.compat();
let uri = uri.parse::<Uri>()?;
if uri.scheme().map(|s| s.as_str()) != Some("https") {
return Err(anyhow::anyhow!("URI must use HTTPS scheme"));
}
@@ -80,29 +80,32 @@ pub async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
.build()?,
)
.build()?,
&mut verifier_socket,
verifier_socket.compat(),
)
.await?;
// Open a TCP connection to the server.
let client_socket = tokio::net::TcpStream::connect(server_addr).await?.compat();
let client_socket = tokio::net::TcpStream::connect(server_addr).await?;
// Bind the prover to the server connection.
let (tls_connection, prover) = prover.setup(
TlsClientConfig::builder()
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?))
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.build()?,
)?;
let (tls_connection, prover_fut) = prover
.connect(
TlsClientConfig::builder()
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?))
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.build()?,
client_socket.compat(),
)
.await?;
let tls_connection = TokioIo::new(tls_connection.compat());
// Spawn the Prover to run in the background.
let prover_task = tokio::spawn(prover.run(client_socket, verifier_socket));
let prover_task = tokio::spawn(prover_fut);
// MPC-TLS Handshake.
let (mut request_sender, connection) =
@@ -130,7 +133,7 @@ pub async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
}
// Create proof for the Verifier.
let (mut prover, _, mut verifier_socket) = prover_task.await??;
let mut prover = prover_task.await??;
let transcript = prover.transcript().clone();
let mut prove_config_builder = ProveConfig::builder(&transcript);
@@ -164,8 +167,8 @@ pub async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
let prove_config = prove_config_builder.build()?;
// MPC-TLS prove
let prover_output = prover.prove(&prove_config, &mut verifier_socket).await?;
prover.close(&mut verifier_socket).await?;
let prover_output = prover.prove(&prove_config).await?;
prover.close().await?;
// Prove birthdate is more than 18 years ago.
let received_commitments = received_commitments(&prover_output.transcript_commitments);
@@ -181,10 +184,8 @@ pub async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
// Sent zk proof bundle to verifier
let serialized_proof = bincode::serialize(&proof_bundle)?;
let mut verifier_socket = verifier_socket.into_inner();
verifier_socket.write_all(&serialized_proof).await?;
verifier_socket.shutdown().await?;
verifier_extra_socket.write_all(&serialized_proof).await?;
verifier_extra_socket.shutdown().await?;
Ok(())
}

View File

@@ -20,12 +20,11 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use tokio_util::compat::TokioAsyncReadCompatExt;
use tracing::instrument;
#[instrument(skip(prover_socket))]
#[instrument(skip(socket, extra_socket))]
pub async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
prover_socket: T,
socket: T,
mut extra_socket: T,
) -> Result<PartialTranscript> {
let mut prover_socket = prover_socket.compat();
let verifier = Verifier::new(
VerifierConfig::builder()
// Create a root certificate store with the server-fixture's self-signed
@@ -38,7 +37,7 @@ pub async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>
);
// Validate the proposed configuration and then run the TLS commitment protocol.
let verifier = verifier.commit(&mut prover_socket).await?;
let verifier = verifier.commit(socket.compat()).await?;
// This is the opportunity to ensure the prover does not attempt to overload the
// verifier.
@@ -56,29 +55,24 @@ pub async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>
};
if reject.is_some() {
verifier.reject(&mut prover_socket, reject).await?;
verifier.reject(reject).await?;
return Err(anyhow::anyhow!("protocol configuration rejected"));
}
// Runs the TLS commitment protocol to completion.
let verifier = verifier
.accept(&mut prover_socket)
.await?
.run(&mut prover_socket)
.await?;
let verifier = verifier.accept().await?.run().await?;
// Validate the proving request and then verify.
let verifier = verifier.verify(&mut prover_socket).await?;
let verifier = verifier.verify().await?;
let request = verifier.request();
if !request.server_identity() || request.reveal().is_none() {
let verifier = verifier
.reject(
&mut prover_socket,
Some("expecting to verify the server name and transcript data"),
)
.reject(Some(
"expecting to verify the server name and transcript data",
))
.await?;
verifier.close(&mut prover_socket).await?;
verifier.close().await?;
return Err(anyhow::anyhow!(
"prover did not reveal the server name and transcript data"
));
@@ -92,9 +86,10 @@ pub async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>
..
},
verifier,
) = verifier.accept(&mut prover_socket).await?;
) = verifier.accept().await?;
verifier.close().await?;
verifier.close(&mut prover_socket).await?;
let server_name = server_name.expect("server name should be present");
let transcript = transcript.expect("transcript should be present");
@@ -131,9 +126,7 @@ pub async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>
// Receive ZKProof information from prover
let mut buf = Vec::new();
let mut prover_socket = prover_socket.into_inner();
prover_socket.read_to_end(&mut buf).await?;
extra_socket.read_to_end(&mut buf).await?;
if buf.is_empty() {
return Err(anyhow::anyhow!("No ZK proof data received from prover"));

View File

@@ -1,59 +1,51 @@
#### Default Representative Benchmarks ####
#
# This benchmark measures TLSNotary performance on three representative network scenarios.
# Each scenario is run multiple times to produce statistical metrics (median, std dev, etc.)
# rather than plots. Use this for quick performance checks and CI regression testing.
#
# Payload sizes:
# - upload-size: 1KB (typical HTTP request)
# - download-size: 2KB (typical HTTP response/API data)
#
# Network scenarios are chosen to represent real-world user conditions where
# TLSNotary is primarily bottlenecked by upload bandwidth.
#### Cable/DSL Home Internet ####
# Most common residential internet connection
# - Asymmetric: high download, limited upload (typical bottleneck)
# - Upload bandwidth: 20 Mbps (realistic cable/DSL upload speed)
# - Latency: 20ms (typical ISP latency)
#### Latency ####
[[group]]
name = "cable"
bandwidth = 20
protocol_latency = 20
upload-size = 1024
download-size = 2048
name = "latency"
bandwidth = 1000
[[bench]]
group = "cable"
#### Mobile 5G ####
# Modern mobile connection with good coverage
# - Upload bandwidth: 30 Mbps (typical 5G upload in good conditions)
# - Latency: 30ms (higher than wired due to mobile tower hops)
[[group]]
name = "mobile_5g"
bandwidth = 30
protocol_latency = 30
upload-size = 1024
download-size = 2048
group = "latency"
protocol_latency = 10
[[bench]]
group = "mobile_5g"
group = "latency"
protocol_latency = 25
#### Fiber Home Internet ####
# High-end residential connection (best case scenario)
# - Symmetric: equal upload/download bandwidth
# - Upload bandwidth: 100 Mbps (typical fiber upload)
# - Latency: 15ms (lower latency than cable)
[[bench]]
group = "latency"
protocol_latency = 50
[[bench]]
group = "latency"
protocol_latency = 100
[[bench]]
group = "latency"
protocol_latency = 200
#### Bandwidth ####
[[group]]
name = "fiber"
name = "bandwidth"
protocol_latency = 25
[[bench]]
group = "bandwidth"
bandwidth = 10
[[bench]]
group = "bandwidth"
bandwidth = 50
[[bench]]
group = "bandwidth"
bandwidth = 100
protocol_latency = 15
upload-size = 1024
download-size = 2048
[[bench]]
group = "fiber"
group = "bandwidth"
bandwidth = 250
[[bench]]
group = "bandwidth"
bandwidth = 1000

View File

@@ -1,52 +0,0 @@
#### Bandwidth Sweep Benchmark ####
#
# Measures how network bandwidth affects TLSNotary runtime.
# Keeps latency and payload sizes fixed while varying upload bandwidth.
#
# Fixed parameters:
# - Latency: 25ms (typical internet latency)
# - Upload: 1KB (typical request)
# - Download: 2KB (typical response)
#
# Variable: Bandwidth from 5 Mbps to 1000 Mbps
#
# Use this to plot "Bandwidth vs Runtime" and understand bandwidth sensitivity.
# Focus on upload bandwidth as TLSNotary is primarily upload-bottlenecked
[[group]]
name = "bandwidth_sweep"
protocol_latency = 25
upload-size = 1024
download-size = 2048
[[bench]]
group = "bandwidth_sweep"
bandwidth = 5
[[bench]]
group = "bandwidth_sweep"
bandwidth = 10
[[bench]]
group = "bandwidth_sweep"
bandwidth = 20
[[bench]]
group = "bandwidth_sweep"
bandwidth = 50
[[bench]]
group = "bandwidth_sweep"
bandwidth = 100
[[bench]]
group = "bandwidth_sweep"
bandwidth = 250
[[bench]]
group = "bandwidth_sweep"
bandwidth = 500
[[bench]]
group = "bandwidth_sweep"
bandwidth = 1000

View File

@@ -1,53 +0,0 @@
#### Download Size Sweep Benchmark ####
#
# Measures how download payload size affects TLSNotary runtime.
# Keeps network conditions fixed while varying the response size.
#
# Fixed parameters:
# - Bandwidth: 100 Mbps (typical good connection)
# - Latency: 25ms (typical internet latency)
# - Upload: 1KB (typical request size)
#
# Variable: Download size from 1KB to 100KB
#
# Use this to plot "Download Size vs Runtime" and understand how much data
# TLSNotary can efficiently notarize. Useful for determining optimal
# chunking strategies for large responses.
[[group]]
name = "download_sweep"
bandwidth = 100
protocol_latency = 25
upload-size = 1024
[[bench]]
group = "download_sweep"
download-size = 1024
[[bench]]
group = "download_sweep"
download-size = 2048
[[bench]]
group = "download_sweep"
download-size = 5120
[[bench]]
group = "download_sweep"
download-size = 10240
[[bench]]
group = "download_sweep"
download-size = 20480
[[bench]]
group = "download_sweep"
download-size = 30720
[[bench]]
group = "download_sweep"
download-size = 40960
[[bench]]
group = "download_sweep"
download-size = 51200

View File

@@ -1,47 +0,0 @@
#### Latency Sweep Benchmark ####
#
# Measures how network latency affects TLSNotary runtime.
# Keeps bandwidth and payload sizes fixed while varying protocol latency.
#
# Fixed parameters:
# - Bandwidth: 100 Mbps (typical good connection)
# - Upload: 1KB (typical request)
# - Download: 2KB (typical response)
#
# Variable: Protocol latency from 10ms to 200ms
#
# Use this to plot "Latency vs Runtime" and understand latency sensitivity.
[[group]]
name = "latency_sweep"
bandwidth = 100
upload-size = 1024
download-size = 2048
[[bench]]
group = "latency_sweep"
protocol_latency = 10
[[bench]]
group = "latency_sweep"
protocol_latency = 25
[[bench]]
group = "latency_sweep"
protocol_latency = 50
[[bench]]
group = "latency_sweep"
protocol_latency = 75
[[bench]]
group = "latency_sweep"
protocol_latency = 100
[[bench]]
group = "latency_sweep"
protocol_latency = 150
[[bench]]
group = "latency_sweep"
protocol_latency = 200

View File

@@ -23,8 +23,7 @@ use crate::{
};
pub async fn bench_prover(provider: &IoProvider, config: &Bench) -> Result<ProverMetrics> {
let mut verifier_io = Meter::new(provider.provide_proto_io().await?);
let mut server_io = provider.provide_server_io().await?;
let verifier_io = Meter::new(provider.provide_proto_io().await?);
let sent = verifier_io.sent();
let recv = verifier_io.recv();
@@ -50,7 +49,7 @@ pub async fn bench_prover(provider: &IoProvider, config: &Bench) -> Result<Prove
.build()
}?)
.build()?,
&mut verifier_io,
verifier_io,
)
.await?;
@@ -59,18 +58,19 @@ pub async fn bench_prover(provider: &IoProvider, config: &Bench) -> Result<Prove
let uploaded_preprocess = sent.load(Ordering::Relaxed);
let downloaded_preprocess = recv.load(Ordering::Relaxed);
let (mut conn, prover) = prover.setup(
TlsClientConfig::builder()
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?))
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.build()?,
)?;
let (mut conn, prover_fut) = prover
.connect(
TlsClientConfig::builder()
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?))
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.build()?,
provider.provide_server_io().await?,
)
.await?;
let mut prover = prover.connect(&mut server_io, &mut verifier_io);
futures::try_join!(
let (_, mut prover) = futures::try_join!(
async {
let request = format!(
"GET /bytes?size={} HTTP/1.1\r\nConnection: close\r\nData: {}\r\n\r\n",
@@ -87,9 +87,8 @@ pub async fn bench_prover(provider: &IoProvider, config: &Bench) -> Result<Prove
Ok(())
},
(&mut prover).map_err(anyhow::Error::from)
prover_fut.map_err(anyhow::Error::from)
)?;
let mut prover = prover.finish()?;
let time_online = time_start_online.elapsed().as_millis();
let uploaded_online = sent.load(Ordering::Relaxed) - uploaded_preprocess;
@@ -119,8 +118,8 @@ pub async fn bench_prover(provider: &IoProvider, config: &Bench) -> Result<Prove
let prove_config = builder.build()?;
prover.prove(&prove_config, &mut verifier_io).await?;
prover.close(&mut verifier_io).await?;
prover.prove(&prove_config).await?;
prover.close().await?;
let time_total = time_start.elapsed().as_millis();

View File

@@ -11,8 +11,6 @@ use tlsn_server_fixture_certs::CA_CERT_DER;
use crate::IoProvider;
pub async fn bench_verifier(provider: &IoProvider, _config: &Bench) -> Result<()> {
let mut prover_io = provider.provide_proto_io().await?;
let verifier = Verifier::new(
VerifierConfig::builder()
.root_store(RootCertStore {
@@ -24,16 +22,12 @@ pub async fn bench_verifier(provider: &IoProvider, _config: &Bench) -> Result<()
let verifier = verifier
.commit(provider.provide_proto_io().await?)
.await?
.accept(&mut prover_io)
.accept()
.await?
.run(&mut prover_io)
.run()
.await?;
let (_, verifier) = verifier
.verify(&mut prover_io)
.await?
.accept(&mut prover_io)
.await?;
verifier.close(&mut prover_io).await?;
let (_, verifier) = verifier.verify().await?.accept().await?;
verifier.close().await?;
Ok(())
}

View File

@@ -28,8 +28,6 @@ const MAX_RECV_DATA: usize = 1 << 11;
crate::test!("basic", prover, verifier);
async fn prover(provider: &IoProvider) {
let mut verifier_io = provider.provide_proto_io().await.unwrap();
let prover = Prover::new(ProverConfig::builder().build().unwrap())
.commit(
TlsCommitConfig::builder()
@@ -43,15 +41,13 @@ async fn prover(provider: &IoProvider) {
)
.build()
.unwrap(),
&mut verifier_io,
provider.provide_proto_io().await.unwrap(),
)
.await
.unwrap();
let server_io = provider.provide_server_io().await.unwrap();
let (tls_connection, prover) = prover
.setup(
let (tls_connection, prover_fut) = prover
.connect(
TlsClientConfig::builder()
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into().unwrap()))
.root_store(RootCertStore {
@@ -59,10 +55,12 @@ async fn prover(provider: &IoProvider) {
})
.build()
.unwrap(),
provider.provide_server_io().await.unwrap(),
)
.await
.unwrap();
let prover_task = spawn(prover.run(server_io, verifier_io));
let prover_task = spawn(prover_fut);
let (mut request_sender, connection) =
hyper::client::conn::http1::handshake(FuturesIo::new(tls_connection))
@@ -89,7 +87,7 @@ async fn prover(provider: &IoProvider) {
let _ = response.into_body().collect().await.unwrap().to_bytes();
let (mut prover, _, mut verifier_io) = prover_task.await.unwrap().unwrap();
let mut prover = prover_task.await.unwrap().unwrap();
let (sent_len, recv_len) = prover.transcript().len();
@@ -116,13 +114,11 @@ async fn prover(provider: &IoProvider) {
let config = builder.build().unwrap();
prover.prove(&config, &mut verifier_io).await.unwrap();
prover.close(&mut verifier_io).await.unwrap();
prover.prove(&config).await.unwrap();
prover.close().await.unwrap();
}
async fn verifier(provider: &IoProvider) {
let mut prover_io = provider.provide_proto_io().await.unwrap();
let config = VerifierConfig::builder()
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
@@ -131,13 +127,13 @@ async fn verifier(provider: &IoProvider) {
.unwrap();
let verifier = Verifier::new(config)
.commit(&mut prover_io)
.commit(provider.provide_proto_io().await.unwrap())
.await
.unwrap()
.accept(&mut prover_io)
.accept()
.await
.unwrap()
.run(&mut prover_io)
.run()
.await
.unwrap();
@@ -148,15 +144,9 @@ async fn verifier(provider: &IoProvider) {
..
},
verifier,
) = verifier
.verify(&mut prover_io)
.await
.unwrap()
.accept(&mut prover_io)
.await
.unwrap();
) = verifier.verify().await.unwrap().accept().await.unwrap();
verifier.close(&mut prover_io).await.unwrap();
verifier.close().await.unwrap();
let ServerName::Dns(server_name) = server_name.unwrap();

View File

@@ -7,10 +7,9 @@ publish = false
[dependencies]
tlsn-harness-core = { workspace = true }
# tlsn-server-fixture = { workspace = true }
charming = { version = "0.5.1", features = ["ssr"] }
csv = "1.3.0"
charming = { version = "0.6.0", features = ["ssr"] }
clap = { workspace = true, features = ["derive", "env"] }
itertools = "0.14.0"
polars = { version = "0.44", features = ["csv", "lazy"] }
toml = { workspace = true }

View File

@@ -0,0 +1,111 @@
# TLSNotary Benchmark Plot Tool
Generates interactive HTML and SVG plots from TLSNotary benchmark results. Supports comparing multiple benchmark runs (e.g., before/after optimization, native vs browser).
## Usage
```bash
tlsn-harness-plot <TOML> <CSV>... [OPTIONS]
```
### Arguments
- `<TOML>` - Path to Bench.toml file defining benchmark structure
- `<CSV>...` - One or more CSV files with benchmark results
### Options
- `-l, --labels <LABEL>...` - Labels for each dataset (optional)
- If omitted, datasets are labeled "Dataset 1", "Dataset 2", etc.
- Number of labels must match number of CSV files
- `--min-max-band` - Add min/max bands to plots showing variance
- `-h, --help` - Print help information
## Examples
### Single Dataset
```bash
tlsn-harness-plot bench.toml results.csv
```
Generates plots from a single benchmark run.
### Compare Two Runs
```bash
tlsn-harness-plot bench.toml before.csv after.csv \
--labels "Before Optimization" "After Optimization"
```
Overlays two datasets to compare performance improvements.
### Multiple Datasets
```bash
tlsn-harness-plot bench.toml native.csv browser.csv wasm.csv \
--labels "Native" "Browser" "WASM"
```
Compare three different runtime environments.
### With Min/Max Bands
```bash
tlsn-harness-plot bench.toml run1.csv run2.csv \
--labels "Config A" "Config B" \
--min-max-band
```
Shows variance ranges for each dataset.
## Output Files
The tool generates two files per benchmark group:
- `<output>.html` - Interactive HTML chart (zoomable, hoverable)
- `<output>.svg` - Static SVG image for documentation
Default output filenames:
- `runtime_vs_bandwidth.{html,svg}` - When `protocol_latency` is defined in group
- `runtime_vs_latency.{html,svg}` - When `bandwidth` is defined in group
## Plot Format
Each dataset displays:
- **Solid line** - Total runtime (preprocessing + online phase)
- **Dashed line** - Online phase only
- **Shaded area** (optional) - Min/max variance bands
Different datasets automatically use distinct colors for easy comparison.
## CSV Format
Expected columns in each CSV file:
- `group` - Benchmark group name (must match TOML)
- `bandwidth` - Network bandwidth in Kbps (for bandwidth plots)
- `latency` - Network latency in ms (for latency plots)
- `time_preprocess` - Preprocessing time in ms
- `time_online` - Online phase time in ms
- `time_total` - Total runtime in ms
## TOML Format
The benchmark TOML file defines groups with either:
```toml
[[group]]
name = "my_benchmark"
protocol_latency = 50 # Fixed latency for bandwidth plots
# OR
bandwidth = 10000 # Fixed bandwidth for latency plots
```
All datasets must use the same TOML file to ensure consistent benchmark structure.
## Tips
- Use descriptive labels to make plots self-documenting
- Keep CSV files from the same benchmark configuration for valid comparisons
- Min/max bands are useful for showing stability but can clutter plots with many datasets
- Interactive HTML plots support zooming and hovering for detailed values

View File

@@ -1,17 +1,18 @@
use std::f32;
use charming::{
Chart, HtmlRenderer,
Chart, HtmlRenderer, ImageRenderer,
component::{Axis, Legend, Title},
element::{AreaStyle, LineStyle, NameLocation, Orient, TextStyle, Tooltip, Trigger},
element::{
AreaStyle, ItemStyle, LineStyle, LineStyleType, NameLocation, Orient, TextStyle, Tooltip,
Trigger,
},
series::Line,
theme::Theme,
};
use clap::Parser;
use harness_core::bench::{BenchItems, Measurement};
use itertools::Itertools;
const THEME: Theme = Theme::Default;
use harness_core::bench::BenchItems;
use polars::prelude::*;
#[derive(Parser, Debug)]
#[command(author, version, about)]
@@ -19,72 +20,131 @@ struct Cli {
/// Path to the Bench.toml file with benchmark spec
toml: String,
/// Path to the CSV file with benchmark results
csv: String,
/// Paths to CSV files with benchmark results (one or more)
csv: Vec<String>,
/// Prover kind: native or browser
#[arg(short, long, value_enum, default_value = "native")]
prover_kind: ProverKind,
/// Labels for each dataset (optional, defaults to "Dataset 1", "Dataset 2", etc.)
#[arg(short, long, num_args = 0..)]
labels: Vec<String>,
/// Add min/max bands to plots
#[arg(long, default_value_t = false)]
min_max_band: bool,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum ProverKind {
Native,
Browser,
}
impl std::fmt::Display for ProverKind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ProverKind::Native => write!(f, "Native"),
ProverKind::Browser => write!(f, "Browser"),
}
}
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
let cli = Cli::parse();
let mut rdr = csv::Reader::from_path(&cli.csv)?;
if cli.csv.is_empty() {
return Err("At least one CSV file must be provided".into());
}
// Generate labels if not provided
let labels: Vec<String> = if cli.labels.is_empty() {
cli.csv
.iter()
.enumerate()
.map(|(i, _)| format!("Dataset {}", i + 1))
.collect()
} else if cli.labels.len() != cli.csv.len() {
return Err(format!(
"Number of labels ({}) must match number of CSV files ({})",
cli.labels.len(),
cli.csv.len()
)
.into());
} else {
cli.labels.clone()
};
// Load all CSVs and add dataset label
let mut dfs = Vec::new();
for (csv_path, label) in cli.csv.iter().zip(labels.iter()) {
let mut df = CsvReadOptions::default()
.try_into_reader_with_file_path(Some(csv_path.clone().into()))?
.finish()?;
let label_series = Series::new("dataset_label".into(), vec![label.as_str(); df.height()]);
df.with_column(label_series)?;
dfs.push(df);
}
// Combine all dataframes
let df = dfs
.into_iter()
.reduce(|acc, df| acc.vstack(&df).unwrap())
.unwrap();
let items: BenchItems = toml::from_str(&std::fs::read_to_string(&cli.toml)?)?;
let groups = items.group;
// Prepare data for plotting.
let all_data: Vec<Measurement> = rdr
.deserialize::<Measurement>()
.collect::<Result<Vec<_>, _>>()?;
for group in groups {
if group.protocol_latency.is_some() {
let latency = group.protocol_latency.unwrap();
plot_runtime_vs(
&all_data,
cli.min_max_band,
&group.name,
|r| r.bandwidth as f32 / 1000.0, // Kbps to Mbps
"Runtime vs Bandwidth",
format!("{} ms Latency, {} mode", latency, cli.prover_kind),
"runtime_vs_bandwidth.html",
"Bandwidth (Mbps)",
)?;
// Determine which field varies in benches for this group
let benches_in_group: Vec<_> = items
.bench
.iter()
.filter(|b| b.group.as_deref() == Some(&group.name))
.collect();
if benches_in_group.is_empty() {
continue;
}
if group.bandwidth.is_some() {
let bandwidth = group.bandwidth.unwrap();
// Check which field has varying values
let bandwidth_varies = benches_in_group
.windows(2)
.any(|w| w[0].bandwidth != w[1].bandwidth);
let latency_varies = benches_in_group
.windows(2)
.any(|w| w[0].protocol_latency != w[1].protocol_latency);
let download_size_varies = benches_in_group
.windows(2)
.any(|w| w[0].download_size != w[1].download_size);
if download_size_varies {
let upload_size = group.upload_size.unwrap_or(1024);
plot_runtime_vs(
&all_data,
&df,
&labels,
cli.min_max_band,
&group.name,
|r| r.latency as f32,
"download_size",
1.0 / 1024.0, // bytes to KB
"Runtime vs Response Size",
format!("{} bytes upload size", upload_size),
"runtime_vs_download_size",
"Response Size (KB)",
true, // legend on left
)?;
} else if bandwidth_varies {
let latency = group.protocol_latency.unwrap_or(50);
plot_runtime_vs(
&df,
&labels,
cli.min_max_band,
&group.name,
"bandwidth",
1.0 / 1000.0, // Kbps to Mbps
"Runtime vs Bandwidth",
format!("{} ms Latency", latency),
"runtime_vs_bandwidth",
"Bandwidth (Mbps)",
false, // legend on right
)?;
} else if latency_varies {
let bandwidth = group.bandwidth.unwrap_or(1000);
plot_runtime_vs(
&df,
&labels,
cli.min_max_band,
&group.name,
"latency",
1.0,
"Runtime vs Latency",
format!("{} bps bandwidth, {} mode", bandwidth, cli.prover_kind),
"runtime_vs_latency.html",
format!("{} bps bandwidth", bandwidth),
"runtime_vs_latency",
"Latency (ms)",
true, // legend on left
)?;
}
}
@@ -92,83 +152,51 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}
struct DataPoint {
min: f32,
mean: f32,
max: f32,
}
struct Points {
preprocess: DataPoint,
online: DataPoint,
total: DataPoint,
}
#[allow(clippy::too_many_arguments)]
fn plot_runtime_vs<Fx>(
all_data: &[Measurement],
fn plot_runtime_vs(
df: &DataFrame,
labels: &[String],
show_min_max: bool,
group: &str,
x_value: Fx,
x_col: &str,
x_scale: f32,
title: &str,
subtitle: String,
output_file: &str,
x_axis_label: &str,
) -> Result<Chart, Box<dyn std::error::Error>>
where
Fx: Fn(&Measurement) -> f32,
{
fn data_point(values: &[f32]) -> DataPoint {
let mean = values.iter().copied().sum::<f32>() / values.len() as f32;
let max = values.iter().copied().reduce(f32::max).unwrap_or_default();
let min = values.iter().copied().reduce(f32::min).unwrap_or_default();
DataPoint { min, mean, max }
}
legend_left: bool,
) -> Result<Chart, Box<dyn std::error::Error>> {
let stats_df = df
.clone()
.lazy()
.filter(col("group").eq(lit(group)))
.with_column((col(x_col).cast(DataType::Float32) * lit(x_scale)).alias("x"))
.with_columns([
(col("time_preprocess").cast(DataType::Float32) / lit(1000.0)).alias("preprocess"),
(col("time_online").cast(DataType::Float32) / lit(1000.0)).alias("online"),
(col("time_total").cast(DataType::Float32) / lit(1000.0)).alias("total"),
])
.group_by([col("x"), col("dataset_label")])
.agg([
col("preprocess").min().alias("preprocess_min"),
col("preprocess").mean().alias("preprocess_mean"),
col("preprocess").max().alias("preprocess_max"),
col("online").min().alias("online_min"),
col("online").mean().alias("online_mean"),
col("online").max().alias("online_max"),
col("total").min().alias("total_min"),
col("total").mean().alias("total_mean"),
col("total").max().alias("total_max"),
])
.sort(["dataset_label", "x"], Default::default())
.collect()?;
let stats: Vec<(f32, Points)> = all_data
.iter()
.filter(|r| r.group.as_deref() == Some(group))
.map(|r| {
(
x_value(r),
r.time_preprocess as f32 / 1000.0, // ms to s
r.time_online as f32 / 1000.0,
r.time_total as f32 / 1000.0,
)
})
.sorted_by(|a, b| a.0.partial_cmp(&b.0).unwrap())
.chunk_by(|entry| entry.0)
.into_iter()
.map(|(x, group)| {
let group_vec: Vec<_> = group.collect();
let preprocess = data_point(
&group_vec
.iter()
.map(|(_, t, _, _)| *t)
.collect::<Vec<f32>>(),
);
let online = data_point(
&group_vec
.iter()
.map(|(_, _, t, _)| *t)
.collect::<Vec<f32>>(),
);
let total = data_point(
&group_vec
.iter()
.map(|(_, _, _, t)| *t)
.collect::<Vec<f32>>(),
);
(
x,
Points {
preprocess,
online,
total,
},
)
})
.collect();
// Build legend entries
let mut legend_data = Vec::new();
for label in labels {
legend_data.push(format!("Total Mean ({})", label));
legend_data.push(format!("Online Mean ({})", label));
}
let mut chart = Chart::new()
.title(
@@ -179,14 +207,6 @@ where
.subtext_style(TextStyle::new().font_size(16)),
)
.tooltip(Tooltip::new().trigger(Trigger::Axis))
.legend(
Legend::new()
.data(vec!["Preprocess Mean", "Online Mean", "Total Mean"])
.top("80")
.right("110")
.orient(Orient::Vertical)
.item_gap(10),
)
.x_axis(
Axis::new()
.name(x_axis_label)
@@ -205,73 +225,156 @@ where
.name_text_style(TextStyle::new().font_size(21)),
);
chart = add_mean_series(chart, &stats, "Preprocess Mean", |p| p.preprocess.mean);
chart = add_mean_series(chart, &stats, "Online Mean", |p| p.online.mean);
chart = add_mean_series(chart, &stats, "Total Mean", |p| p.total.mean);
// Add legend with conditional positioning
let legend = Legend::new()
.data(legend_data)
.top("80")
.orient(Orient::Vertical)
.item_gap(10);
if show_min_max {
chart = add_min_max_band(
chart,
&stats,
"Preprocess Min/Max",
|p| &p.preprocess,
"#ccc",
);
chart = add_min_max_band(chart, &stats, "Online Min/Max", |p| &p.online, "#ccc");
chart = add_min_max_band(chart, &stats, "Total Min/Max", |p| &p.total, "#ccc");
let legend = if legend_left {
legend.left("110")
} else {
legend.right("110")
};
chart = chart.legend(legend);
// Define colors for each dataset
let colors = vec![
"#5470c6", "#91cc75", "#fac858", "#ee6666", "#73c0de", "#3ba272", "#fc8452", "#9a60b4",
];
for (idx, label) in labels.iter().enumerate() {
let color = colors.get(idx % colors.len()).unwrap();
// Total time - solid line
chart = add_dataset_series(
&chart,
&stats_df,
label,
&format!("Total Mean ({})", label),
"total_mean",
false,
color,
)?;
// Online time - dashed line (same color as total)
chart = add_dataset_series(
&chart,
&stats_df,
label,
&format!("Online Mean ({})", label),
"online_mean",
true,
color,
)?;
if show_min_max {
chart = add_dataset_min_max_band(
&chart,
&stats_df,
label,
&format!("Total Min/Max ({})", label),
"total",
color,
)?;
}
}
// Save the chart as HTML file.
// Save the chart as HTML file (no theme)
HtmlRenderer::new(title, 1000, 800)
.theme(THEME)
.save(&chart, output_file)
.save(&chart, &format!("{}.html", output_file))
.unwrap();
// Save SVG with default theme
ImageRenderer::new(1000, 800)
.theme(Theme::Default)
.save(&chart, &format!("{}.svg", output_file))
.unwrap();
// Save SVG with dark theme
ImageRenderer::new(1000, 800)
.theme(Theme::Dark)
.save(&chart, &format!("{}_dark.svg", output_file))
.unwrap();
Ok(chart)
}
fn add_mean_series(
chart: Chart,
stats: &[(f32, Points)],
name: &str,
extract: impl Fn(&Points) -> f32,
) -> Chart {
chart.series(
Line::new()
.name(name)
.data(
stats
.iter()
.map(|(x, points)| vec![*x, extract(points)])
.collect(),
)
.symbol_size(6),
)
fn add_dataset_series(
chart: &Chart,
df: &DataFrame,
dataset_label: &str,
series_name: &str,
col_name: &str,
dashed: bool,
color: &str,
) -> Result<Chart, Box<dyn std::error::Error>> {
// Filter for specific dataset
let mask = df.column("dataset_label")?.str()?.equal(dataset_label);
let filtered = df.filter(&mask)?;
let x = filtered.column("x")?.f32()?;
let y = filtered.column(col_name)?.f32()?;
let data: Vec<Vec<f32>> = x
.into_iter()
.zip(y.into_iter())
.filter_map(|(x, y)| Some(vec![x?, y?]))
.collect();
let mut line = Line::new()
.name(series_name)
.data(data)
.symbol_size(6)
.item_style(ItemStyle::new().color(color));
let mut line_style = LineStyle::new();
if dashed {
line_style = line_style.type_(LineStyleType::Dashed);
}
line = line.line_style(line_style.color(color));
Ok(chart.clone().series(line))
}
fn add_min_max_band(
chart: Chart,
stats: &[(f32, Points)],
fn add_dataset_min_max_band(
chart: &Chart,
df: &DataFrame,
dataset_label: &str,
name: &str,
extract: impl Fn(&Points) -> &DataPoint,
col_prefix: &str,
color: &str,
) -> Chart {
chart.series(
) -> Result<Chart, Box<dyn std::error::Error>> {
// Filter for specific dataset
let mask = df.column("dataset_label")?.str()?.equal(dataset_label);
let filtered = df.filter(&mask)?;
let x = filtered.column("x")?.f32()?;
let min_col = filtered.column(&format!("{}_min", col_prefix))?.f32()?;
let max_col = filtered.column(&format!("{}_max", col_prefix))?.f32()?;
let max_data: Vec<Vec<f32>> = x
.into_iter()
.zip(max_col.into_iter())
.filter_map(|(x, y)| Some(vec![x?, y?]))
.collect();
let min_data: Vec<Vec<f32>> = x
.into_iter()
.zip(min_col.into_iter())
.filter_map(|(x, y)| Some(vec![x?, y?]))
.rev()
.collect();
let data: Vec<Vec<f32>> = max_data.into_iter().chain(min_data).collect();
Ok(chart.clone().series(
Line::new()
.name(name)
.data(
stats
.iter()
.map(|(x, points)| vec![*x, extract(points).max])
.chain(
stats
.iter()
.rev()
.map(|(x, points)| vec![*x, extract(points).min]),
)
.collect(),
)
.data(data)
.show_symbol(false)
.line_style(LineStyle::new().opacity(0.0))
.area_style(AreaStyle::new().opacity(0.3).color(color)),
)
))
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -22,7 +22,6 @@ clap = { workspace = true, features = ["derive", "env"] }
csv = { version = "1.3" }
duct = { version = "1" }
futures = { workspace = true }
indicatif = { version = "0.17" }
ipnet = { workspace = true }
serio = { workspace = true }
serde_json = { workspace = true }

View File

@@ -16,10 +16,6 @@ pub struct Cli {
/// Subnet to assign harness network interfaces.
#[arg(long, default_value = "10.250.0.0/24", env = "SUBNET")]
pub subnet: Ipv4Net,
/// Run browser in headed mode (visible window) for debugging.
/// Works with both X11 and Wayland.
#[arg(long)]
pub headed: bool,
}
#[derive(Subcommand)]
@@ -35,13 +31,10 @@ pub enum Command {
},
/// runs benchmarks.
Bench {
/// Configuration path. Defaults to bench.toml which contains
/// representative scenarios (cable, 5G, fiber) for quick performance
/// checks. Use bench_*_sweep.toml files for parametric
/// analysis.
/// Configuration path.
#[arg(short, long, default_value = "bench.toml")]
config: PathBuf,
/// Output CSV file path for detailed metrics and post-processing.
/// Output file path.
#[arg(short, long, default_value = "metrics.csv")]
output: PathBuf,
/// Number of samples to measure per benchmark. This is overridden by

View File

@@ -28,9 +28,6 @@ pub struct Executor {
ns: Namespace,
config: ExecutorConfig,
target: Target,
/// Display environment variables for headed mode (X11/Wayland).
/// Empty means headless mode.
display_env: Vec<String>,
state: State,
}
@@ -52,17 +49,11 @@ impl State {
}
impl Executor {
pub fn new(
ns: Namespace,
config: ExecutorConfig,
target: Target,
display_env: Vec<String>,
) -> Self {
pub fn new(ns: Namespace, config: ExecutorConfig, target: Target) -> Self {
Self {
ns,
config,
target,
display_env,
state: State::Init,
}
}
@@ -129,49 +120,23 @@ impl Executor {
let tmp = duct::cmd!("mktemp", "-d").read()?;
let tmp = tmp.trim();
let headed = !self.display_env.is_empty();
// Build command args based on headed/headless mode
let mut args: Vec<String> = vec![
"ip".into(),
"netns".into(),
"exec".into(),
self.ns.name().into(),
];
if headed {
// For headed mode: drop back to the current user and pass display env vars
// This allows the browser to connect to X11/Wayland while in the namespace
let user =
std::env::var("USER").context("USER environment variable not set")?;
args.extend(["sudo".into(), "-E".into(), "-u".into(), user, "env".into()]);
args.extend(self.display_env.clone());
}
args.push(chrome_path.to_string_lossy().into());
args.push(format!("--remote-debugging-port={PORT_BROWSER}"));
if headed {
// Headed mode: no headless, add flags to suppress first-run dialogs
args.extend(["--no-first-run".into(), "--no-default-browser-check".into()]);
} else {
// Headless mode: original flags
args.extend([
"--headless".into(),
"--disable-dev-shm-usage".into(),
"--disable-gpu".into(),
"--disable-cache".into(),
"--disable-application-cache".into(),
]);
}
args.extend([
"--no-sandbox".into(),
let process = duct::cmd!(
"sudo",
"ip",
"netns",
"exec",
self.ns.name(),
chrome_path,
format!("--remote-debugging-port={PORT_BROWSER}"),
"--headless",
"--disable-dev-shm-usage",
"--disable-gpu",
"--disable-cache",
"--disable-application-cache",
"--no-sandbox",
format!("--user-data-dir={tmp}"),
"--allowed-ips=10.250.0.1".into(),
]);
let process = duct::cmd("sudo", &args);
format!("--allowed-ips=10.250.0.1"),
);
let process = if !cfg!(feature = "debug") {
process.stderr_capture().stdout_capture().start()?

View File

@@ -9,7 +9,7 @@ mod ws_proxy;
#[cfg(feature = "debug")]
mod debug_prelude;
use std::{collections::HashMap, time::Duration};
use std::time::Duration;
use anyhow::Result;
use clap::Parser;
@@ -22,7 +22,6 @@ use harness_core::{
rpc::{BenchCmd, TestCmd},
test::TestStatus,
};
use indicatif::{ProgressBar, ProgressStyle};
use cli::{Cli, Command};
use executor::Executor;
@@ -33,60 +32,6 @@ use crate::debug_prelude::*;
use crate::{cli::Route, network::Network, wasm_server::WasmServer, ws_proxy::WsProxy};
/// Statistics for a benchmark configuration
#[derive(Debug, Clone)]
struct BenchStats {
group: Option<String>,
bandwidth: usize,
latency: usize,
upload_size: usize,
download_size: usize,
times: Vec<u64>,
}
impl BenchStats {
fn median(&self) -> f64 {
let mut sorted = self.times.clone();
sorted.sort();
let len = sorted.len();
if len == 0 {
return 0.0;
}
if len.is_multiple_of(2) {
(sorted[len / 2 - 1] + sorted[len / 2]) as f64 / 2.0
} else {
sorted[len / 2] as f64
}
}
}
/// Print summary table of benchmark results
fn print_bench_summary(stats: &[BenchStats]) {
if stats.is_empty() {
println!("\nNo benchmark results to display (only warmup was run).");
return;
}
println!("\n{}", "=".repeat(80));
println!("TLSNotary Benchmark Results");
println!("{}", "=".repeat(80));
println!();
for stat in stats {
let group_name = stat.group.as_deref().unwrap_or("unnamed");
println!(
"{} ({} Mbps, {}ms latency, {}KB↑ {}KB↓):",
group_name,
stat.bandwidth,
stat.latency,
stat.upload_size / 1024,
stat.download_size / 1024
);
println!(" Median: {:.2}s", stat.median() / 1000.0);
println!();
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum, Default)]
pub enum Target {
#[default]
@@ -105,46 +50,14 @@ struct Runner {
started: bool,
}
/// Collects display-related environment variables for headed browser mode.
/// Works with both X11 and Wayland by collecting whichever vars are present.
fn collect_display_env_vars() -> Vec<String> {
const DISPLAY_VARS: &[&str] = &[
"DISPLAY", // X11
"XAUTHORITY", // X11 auth
"WAYLAND_DISPLAY", // Wayland
"XDG_RUNTIME_DIR", // Wayland runtime dir
];
DISPLAY_VARS
.iter()
.filter_map(|&var| {
std::env::var(var)
.ok()
.map(|val| format!("{}={}", var, val))
})
.collect()
}
impl Runner {
fn new(cli: &Cli) -> Result<Self> {
let Cli {
target,
subnet,
headed,
..
} = cli;
let Cli { target, subnet, .. } = cli;
let current_path = std::env::current_exe().unwrap();
let fixture_path = current_path.parent().unwrap().join("server-fixture");
let network_config = NetworkConfig::new(*subnet);
let network = Network::new(network_config.clone())?;
// Collect display env vars once if headed mode is enabled
let display_env = if *headed {
collect_display_env_vars()
} else {
Vec::new()
};
let server_fixture =
ServerFixture::new(fixture_path, network.ns_app().clone(), network_config.app);
let wasm_server = WasmServer::new(
@@ -162,7 +75,6 @@ impl Runner {
.network_config(network_config.clone())
.build(),
*target,
display_env.clone(),
);
let exec_v = Executor::new(
network.ns_1().clone(),
@@ -172,7 +84,6 @@ impl Runner {
.network_config(network_config.clone())
.build(),
Target::Native,
Vec::new(), // Verifier doesn't need display env
);
Ok(Self {
@@ -207,12 +118,6 @@ pub async fn main() -> Result<()> {
tracing_subscriber::fmt::init();
let cli = Cli::parse();
// Validate --headed requires --target browser
if cli.headed && cli.target != Target::Browser {
anyhow::bail!("--headed can only be used with --target browser");
}
let mut runner = Runner::new(&cli)?;
let mut exit_code = 0;
@@ -301,12 +206,6 @@ pub async fn main() -> Result<()> {
samples_override,
skip_warmup,
} => {
// Print configuration info
println!("TLSNotary Benchmark Harness");
println!("Running benchmarks from: {}", config.display());
println!("Output will be written to: {}", output.display());
println!();
let items: BenchItems = toml::from_str(&std::fs::read_to_string(config)?)?;
let output_file = std::fs::File::create(output)?;
let mut writer = WriterBuilder::new().from_writer(output_file);
@@ -321,34 +220,7 @@ pub async fn main() -> Result<()> {
runner.exec_p.start().await?;
runner.exec_v.start().await?;
// Create progress bar
let pb = ProgressBar::new(benches.len() as u64);
pb.set_style(
ProgressStyle::default_bar()
.template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} {msg}")
.expect("valid template")
.progress_chars("█▓▒░ "),
);
// Collect measurements for stats
let mut measurements_by_config: HashMap<String, Vec<u64>> = HashMap::new();
let warmup_count = if skip_warmup { 0 } else { 3 };
for (idx, config) in benches.iter().enumerate() {
let is_warmup = idx < warmup_count;
let group_name = if is_warmup {
format!("Warmup {}/{}", idx + 1, warmup_count)
} else {
config.group.as_deref().unwrap_or("unnamed").to_string()
};
pb.set_message(format!(
"{} ({} Mbps, {}ms)",
group_name, config.bandwidth, config.protocol_latency
));
for config in benches {
runner
.network
.set_proto_config(config.bandwidth, config.protocol_latency.div_ceil(2))?;
@@ -377,73 +249,11 @@ pub async fn main() -> Result<()> {
panic!("expected prover output");
};
// Collect metrics for stats (skip warmup benches)
if !is_warmup {
let config_key = format!(
"{:?}|{}|{}|{}|{}",
config.group,
config.bandwidth,
config.protocol_latency,
config.upload_size,
config.download_size
);
measurements_by_config
.entry(config_key)
.or_default()
.push(metrics.time_total);
}
let measurement = Measurement::new(config.clone(), metrics);
let measurement = Measurement::new(config, metrics);
writer.serialize(measurement)?;
writer.flush()?;
pb.inc(1);
}
pb.finish_with_message("Benchmarks complete");
// Compute and print statistics
let mut all_stats: Vec<BenchStats> = Vec::new();
for (key, times) in measurements_by_config {
// Parse back the config from the key
let parts: Vec<&str> = key.split('|').collect();
if parts.len() >= 5 {
let group = if parts[0] == "None" {
None
} else {
Some(
parts[0]
.trim_start_matches("Some(\"")
.trim_end_matches("\")")
.to_string(),
)
};
let bandwidth: usize = parts[1].parse().unwrap_or(0);
let latency: usize = parts[2].parse().unwrap_or(0);
let upload_size: usize = parts[3].parse().unwrap_or(0);
let download_size: usize = parts[4].parse().unwrap_or(0);
all_stats.push(BenchStats {
group,
bandwidth,
latency,
upload_size,
download_size,
times,
});
}
}
// Sort stats by group name for consistent output
all_stats.sort_by(|a, b| {
a.group
.cmp(&b.group)
.then(a.latency.cmp(&b.latency))
.then(a.bandwidth.cmp(&b.bandwidth))
});
print_bench_summary(&all_stats);
}
Command::Serve {} => {
runner.start_services().await?;

View File

@@ -0,0 +1,25 @@
#### Bandwidth ####
[[group]]
name = "bandwidth"
protocol_latency = 25
[[bench]]
group = "bandwidth"
bandwidth = 10
[[bench]]
group = "bandwidth"
bandwidth = 50
[[bench]]
group = "bandwidth"
bandwidth = 100
[[bench]]
group = "bandwidth"
bandwidth = 250
[[bench]]
group = "bandwidth"
bandwidth = 1000

View File

@@ -0,0 +1,37 @@
[[group]]
name = "download_size"
protocol_latency = 10
bandwidth = 200
upload-size = 2048
[[bench]]
group = "download_size"
download-size = 1024
[[bench]]
group = "download_size"
download-size = 2048
[[bench]]
group = "download_size"
download-size = 4096
[[bench]]
group = "download_size"
download-size = 8192
[[bench]]
group = "download_size"
download-size = 16384
[[bench]]
group = "download_size"
download-size = 32768
[[bench]]
group = "download_size"
download-size = 65536
[[bench]]
group = "download_size"
download-size = 131072

View File

@@ -0,0 +1,25 @@
#### Latency ####
[[group]]
name = "latency"
bandwidth = 1000
[[bench]]
group = "latency"
protocol_latency = 10
[[bench]]
group = "latency"
protocol_latency = 25
[[bench]]
group = "latency"
protocol_latency = 50
[[bench]]
group = "latency"
protocol_latency = 100
[[bench]]
group = "latency"
protocol_latency = 200

View File

@@ -66,6 +66,7 @@ rand_chacha = { workspace = true }
rstest = { workspace = true }
tls-server-fixture = { workspace = true }
tlsn-tls-client = { workspace = true }
tlsn-tls-client-async = { workspace = true }
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }
tokio-util = { workspace = true, features = ["compat"] }
tracing-subscriber = { workspace = true }

View File

@@ -378,29 +378,15 @@ impl MpcTlsLeader {
Ok(())
}
/// Enables or disables the decryption of any incoming messages.
///
/// # Arguments
///
/// * `enable` - Whether to enable or disable decryption.
/// Defers decryption of any incoming messages.
#[instrument(level = "debug", skip_all, err)]
pub fn enable_decryption(&mut self, enable: bool) -> Result<(), MpcTlsError> {
self.is_decrypting = enable;
if enable {
self.notifier.set();
} else {
self.notifier.clear();
}
pub async fn defer_decryption(&mut self) -> Result<(), MpcTlsError> {
self.is_decrypting = false;
self.notifier.clear();
Ok(())
}
/// Returns if incoming messages are decrypted.
pub fn is_decrypting(&self) -> bool {
self.is_decrypting
}
/// Stops the actor.
pub fn stop(&mut self, ctx: &mut LudiContext<Self>) {
ctx.stop();

View File

@@ -32,14 +32,10 @@ impl MpcTlsLeaderCtrl {
Self { address }
}
/// Enables or disables the decryption of any incoming messages.
///
/// # Arguments
///
/// * `enable` - Whether to enable or disable decryption.
pub async fn enable_decryption(&self, enable: bool) -> Result<(), MpcTlsError> {
/// Defers decryption of any incoming messages.
pub async fn defer_decryption(&self) -> Result<(), MpcTlsError> {
self.address
.send(EnableDecryption { enable })
.send(DeferDecryption)
.await
.map_err(MpcTlsError::actor)?
}
@@ -985,7 +981,7 @@ impl Handler<BackendMsgServerClosed> for MpcTlsLeader {
}
}
impl Dispatch<MpcTlsLeader> for EnableDecryption {
impl Dispatch<MpcTlsLeader> for DeferDecryption {
fn dispatch<R: FnOnce(Self::Return) + Send>(
self,
actor: &mut MpcTlsLeader,
@@ -996,13 +992,13 @@ impl Dispatch<MpcTlsLeader> for EnableDecryption {
}
}
impl Handler<EnableDecryption> for MpcTlsLeader {
impl Handler<DeferDecryption> for MpcTlsLeader {
async fn handle(
&mut self,
msg: EnableDecryption,
_msg: DeferDecryption,
_ctx: &mut LudiCtx<Self>,
) -> <EnableDecryption as Message>::Return {
self.enable_decryption(msg.enable)
) -> <DeferDecryption as Message>::Return {
self.defer_decryption().await
}
}
@@ -1052,7 +1048,7 @@ pub enum MpcTlsLeaderMsg {
BackendMsgGetNotify(BackendMsgGetNotify),
BackendMsgIsEmpty(BackendMsgIsEmpty),
BackendMsgServerClosed(BackendMsgServerClosed),
DeferDecryption(EnableDecryption),
DeferDecryption(DeferDecryption),
Stop(Stop),
}
@@ -1087,7 +1083,7 @@ pub enum MpcTlsLeaderMsgReturn {
BackendMsgGetNotify(<BackendMsgGetNotify as Message>::Return),
BackendMsgIsEmpty(<BackendMsgIsEmpty as Message>::Return),
BackendMsgServerClosed(<BackendMsgServerClosed as Message>::Return),
DeferDecryption(<EnableDecryption as Message>::Return),
DeferDecryption(<DeferDecryption as Message>::Return),
Stop(<Stop as Message>::Return),
}
@@ -1736,25 +1732,23 @@ impl Wrap<BackendMsgServerClosed> for MpcTlsLeaderMsg {
}
}
/// Message to enable or disable the decryption of messages.
/// Message to start deferring the decryption.
#[allow(missing_docs)]
#[derive(Debug)]
pub struct EnableDecryption {
pub enable: bool,
}
pub struct DeferDecryption;
impl Message for EnableDecryption {
impl Message for DeferDecryption {
type Return = Result<(), MpcTlsError>;
}
impl From<EnableDecryption> for MpcTlsLeaderMsg {
fn from(value: EnableDecryption) -> Self {
impl From<DeferDecryption> for MpcTlsLeaderMsg {
fn from(value: DeferDecryption) -> Self {
MpcTlsLeaderMsg::DeferDecryption(value)
}
}
impl Wrap<EnableDecryption> for MpcTlsLeaderMsg {
fn unwrap_return(ret: Self::Return) -> Result<<EnableDecryption as Message>::Return, Error> {
impl Wrap<DeferDecryption> for MpcTlsLeaderMsg {
fn unwrap_return(ret: Self::Return) -> Result<<DeferDecryption as Message>::Return, Error> {
match ret {
Self::Return::DeferDecryption(value) => Ok(value),
_ => Err(Error::Wrapper),

View File

@@ -0,0 +1,160 @@
use std::sync::Arc;
use futures::{AsyncReadExt, AsyncWriteExt};
use mpc_tls::{Config, MpcTlsFollower, MpcTlsLeader};
use mpz_common::context::test_mt_context;
use mpz_core::Block;
use mpz_ideal_vm::IdealVm;
use mpz_memory_core::correlated::Delta;
use mpz_ot::{
ideal::rcot::ideal_rcot,
rcot::shared::{SharedRCOTReceiver, SharedRCOTSender},
};
use rand::{rngs::StdRng, SeedableRng};
use rustls_pki_types::CertificateDer;
use tls_client::RootCertStore;
use tls_client_async::bind_client;
use tls_server_fixture::{bind_test_server_hyper, CA_CERT_DER, SERVER_DOMAIN};
use tokio::sync::Mutex;
use tokio_util::compat::TokioAsyncReadCompatExt;
use webpki::anchor_from_trusted_cert;
const CA_CERT: CertificateDer = CertificateDer::from_slice(CA_CERT_DER);
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn mpc_tls_test() {
tracing_subscriber::fmt::init();
let config = Config::builder()
.defer_decryption(false)
.max_sent(1 << 13)
.max_recv_online(1 << 13)
.max_recv(1 << 13)
.build()
.unwrap();
let (leader, follower) = build_pair(config);
tokio::try_join!(
tokio::spawn(leader_task(leader)),
tokio::spawn(follower_task(follower))
)
.unwrap();
}
async fn leader_task(mut leader: MpcTlsLeader) {
leader.alloc().unwrap();
leader.preprocess().await.unwrap();
let (leader_ctrl, leader_fut) = leader.run();
tokio::spawn(async { leader_fut.await.unwrap() });
let config = tls_client::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(RootCertStore {
roots: vec![anchor_from_trusted_cert(&CA_CERT).unwrap().to_owned()],
})
.with_no_client_auth();
let server_name = SERVER_DOMAIN.try_into().unwrap();
let client = tls_client::ClientConnection::new(
Arc::new(config),
Box::new(leader_ctrl.clone()),
server_name,
)
.unwrap();
let (client_socket, server_socket) = tokio::io::duplex(1 << 16);
tokio::spawn(bind_test_server_hyper(server_socket.compat()));
let (mut conn, conn_fut) = bind_client(client_socket.compat(), client);
let handle = tokio::spawn(async { conn_fut.await.unwrap() });
let msg = concat!(
"POST /echo HTTP/1.1\r\n",
"Host: test-server.io\r\n",
"Connection: keep-alive\r\n",
"Accept-Encoding: identity\r\n",
"Content-Length: 5\r\n",
"\r\n",
"hello",
"\r\n"
);
conn.write_all(msg.as_bytes()).await.unwrap();
let mut buf = vec![0u8; 48];
conn.read_exact(&mut buf).await.unwrap();
leader_ctrl.defer_decryption().await.unwrap();
let msg = concat!(
"POST /echo HTTP/1.1\r\n",
"Host: test-server.io\r\n",
"Connection: close\r\n",
"Accept-Encoding: identity\r\n",
"Content-Length: 5\r\n",
"\r\n",
"hello",
"\r\n"
);
conn.write_all(msg.as_bytes()).await.unwrap();
conn.close().await.unwrap();
let mut buf = vec![0u8; 1024];
conn.read_to_end(&mut buf).await.unwrap();
leader_ctrl.stop().await.unwrap();
handle.await.unwrap();
}
async fn follower_task(mut follower: MpcTlsFollower) {
follower.alloc().unwrap();
follower.preprocess().await.unwrap();
follower.run().await.unwrap();
}
fn build_pair(config: Config) -> (MpcTlsLeader, MpcTlsFollower) {
let mut rng = StdRng::seed_from_u64(0);
let (mut mt_a, mut mt_b) = test_mt_context(8);
let ctx_a = futures::executor::block_on(mt_a.new_context()).unwrap();
let ctx_b = futures::executor::block_on(mt_b.new_context()).unwrap();
let delta_a = Delta::new(Block::random(&mut rng));
let delta_b = Delta::new(Block::random(&mut rng));
let (rcot_send_a, rcot_recv_b) = ideal_rcot(Block::random(&mut rng), delta_a.into_inner());
let (rcot_send_b, rcot_recv_a) = ideal_rcot(Block::random(&mut rng), delta_b.into_inner());
let rcot_send_a = SharedRCOTSender::new(rcot_send_a);
let rcot_send_b = SharedRCOTSender::new(rcot_send_b);
let rcot_recv_a = SharedRCOTReceiver::new(rcot_recv_a);
let rcot_recv_b = SharedRCOTReceiver::new(rcot_recv_b);
let mpc_a = Arc::new(Mutex::new(IdealVm::new()));
let mpc_b = Arc::new(Mutex::new(IdealVm::new()));
let leader = MpcTlsLeader::new(
config.clone(),
ctx_a,
mpc_a,
(rcot_send_a.clone(), rcot_send_a.clone(), rcot_send_a),
rcot_recv_a,
);
let follower = MpcTlsFollower::new(
config,
ctx_b,
mpc_b,
rcot_send_b,
(rcot_recv_b.clone(), rcot_recv_b.clone(), rcot_recv_b),
);
(leader, follower)
}

View File

@@ -0,0 +1,39 @@
[package]
name = "tlsn-tls-client-async"
authors = ["TLSNotary Team"]
description = "An async TLS client for TLSNotary"
keywords = ["tls", "mpc", "2pc", "client", "async"]
categories = ["cryptography"]
license = "MIT OR Apache-2.0"
version = "0.1.0-alpha.14-pre"
edition = "2021"
[lints]
workspace = true
[lib]
name = "tls_client_async"
[features]
default = ["tracing"]
tracing = ["dep:tracing"]
[dependencies]
tlsn-tls-client = { workspace = true }
bytes = { workspace = true }
futures = { workspace = true }
thiserror = { workspace = true }
tokio-util = { workspace = true, features = ["io", "compat"] }
tracing = { workspace = true, optional = true }
[dev-dependencies]
tls-server-fixture = { workspace = true }
http-body-util = { workspace = true }
hyper = { workspace = true, features = ["client", "http1"] }
hyper-util = { workspace = true, features = ["full"] }
rstest = { workspace = true }
tokio = { workspace = true, features = ["rt", "rt-multi-thread", "macros"] }
rustls-webpki = { workspace = true }
rustls-pki-types = { workspace = true }

View File

@@ -0,0 +1,89 @@
use bytes::Bytes;
use futures::{
channel::mpsc::{Receiver, SendError, Sender},
sink::SinkMapErr,
AsyncRead, AsyncWrite, SinkExt,
};
use std::{
io::{Error as IoError, ErrorKind as IoErrorKind},
pin::Pin,
task::{Context, Poll},
};
use tokio_util::{
compat::{Compat, TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt},
io::{CopyToBytes, SinkWriter, StreamReader},
};
type CompatSinkWriter =
Compat<SinkWriter<CopyToBytes<SinkMapErr<Sender<Bytes>, fn(SendError) -> IoError>>>>;
/// A TLS connection to a server.
///
/// This type implements `AsyncRead` and `AsyncWrite` and can be used to
/// communicate with a server using TLS.
///
/// # Note
///
/// This connection is closed on a best-effort basis if this is dropped. To
/// ensure a clean close, you should call
/// [`AsyncWriteExt::close`](futures::io::AsyncWriteExt::close) to close the
/// connection.
#[derive(Debug)]
pub struct TlsConnection {
/// The data to be transmitted to the server is sent to this sink.
tx_sender: CompatSinkWriter,
/// The data to be received from the server is received from this stream.
rx_receiver: Compat<StreamReader<Receiver<Result<Bytes, IoError>>, Bytes>>,
}
impl TlsConnection {
/// Creates a new TLS connection.
pub(crate) fn new(
tx_sender: Sender<Bytes>,
rx_receiver: Receiver<Result<Bytes, IoError>>,
) -> Self {
fn convert_error(err: SendError) -> IoError {
if err.is_disconnected() {
IoErrorKind::BrokenPipe.into()
} else {
IoErrorKind::WouldBlock.into()
}
}
Self {
tx_sender: SinkWriter::new(CopyToBytes::new(
tx_sender.sink_map_err(convert_error as fn(SendError) -> IoError),
))
.compat_write(),
rx_receiver: StreamReader::new(rx_receiver).compat(),
}
}
}
impl AsyncRead for TlsConnection {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<Result<usize, IoError>> {
Pin::new(&mut self.rx_receiver).poll_read(cx, buf)
}
}
impl AsyncWrite for TlsConnection {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, IoError>> {
Pin::new(&mut self.tx_sender).poll_write(cx, buf)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
Pin::new(&mut self.tx_sender).poll_flush(cx)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
Pin::new(&mut self.tx_sender).poll_close(cx)
}
}

View File

@@ -0,0 +1,269 @@
//! Provides a TLS client which exposes an async socket.
//!
//! This library provides the [bind_client] function which attaches a TLS client
//! to a socket connection and then exposes a [TlsConnection] object, which
//! provides an async socket API for reading and writing cleartext. The TLS
//! client will then automatically encrypt and decrypt traffic and forward that
//! to the provided socket.
#![deny(missing_docs, unreachable_pub, unused_must_use)]
#![deny(clippy::all)]
#![forbid(unsafe_code)]
mod conn;
use bytes::{Buf, Bytes};
use futures::{
channel::mpsc, future::Fuse, select_biased, stream::Next, AsyncRead, AsyncReadExt, AsyncWrite,
AsyncWriteExt, Future, FutureExt, SinkExt, StreamExt,
};
use std::{
pin::Pin,
task::{Context, Poll},
};
#[cfg(feature = "tracing")]
use tracing::{debug, debug_span, trace, warn, Instrument};
use tls_client::ClientConnection;
pub use conn::TlsConnection;
const RX_TLS_BUF_SIZE: usize = 1 << 13; // 8 KiB
const RX_BUF_SIZE: usize = 1 << 13; // 8 KiB
/// An error that can occur during a TLS connection.
#[allow(missing_docs)]
#[derive(Debug, thiserror::Error)]
pub enum ConnectionError {
#[error(transparent)]
TlsError(#[from] tls_client::Error),
#[error(transparent)]
IOError(#[from] std::io::Error),
}
/// Closed connection data.
#[derive(Debug)]
pub struct ClosedConnection {
/// The connection for the client
pub client: ClientConnection,
/// Sent plaintext bytes
pub sent: Vec<u8>,
/// Received plaintext bytes
pub recv: Vec<u8>,
}
/// A future which runs the TLS connection to completion.
///
/// This future must be polled in order for the connection to make progress.
#[must_use = "futures do nothing unless polled"]
pub struct ConnectionFuture {
fut: Pin<Box<dyn Future<Output = Result<ClosedConnection, ConnectionError>> + Send>>,
}
impl Future for ConnectionFuture {
type Output = Result<ClosedConnection, ConnectionError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.fut.poll_unpin(cx)
}
}
/// Binds a client connection to the provided socket.
///
/// Returns a connection handle and a future which runs the connection to
/// completion.
///
/// # Errors
///
/// Any connection errors that occur will be returned from the future, not
/// [`TlsConnection`].
pub fn bind_client<T: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
socket: T,
mut client: ClientConnection,
) -> (TlsConnection, ConnectionFuture) {
let (tx_sender, mut tx_receiver) = mpsc::channel(1 << 14);
let (mut rx_sender, rx_receiver) = mpsc::channel(1 << 14);
let conn = TlsConnection::new(tx_sender, rx_receiver);
let fut = async move {
client.start().await?;
let mut notify = client.get_notify().await?;
let (mut server_rx, mut server_tx) = socket.split();
let mut rx_tls_buf = [0u8; RX_TLS_BUF_SIZE];
let mut rx_buf = [0u8; RX_BUF_SIZE];
let mut handshake_done = false;
let mut client_closed = false;
let mut server_closed = false;
let mut sent = Vec::with_capacity(1024);
let mut recv = Vec::with_capacity(1024);
let mut rx_tls_fut = server_rx.read(&mut rx_tls_buf).fuse();
// We don't start writing application data until the handshake is complete.
let mut tx_recv_fut: Fuse<Next<'_, mpsc::Receiver<Bytes>>> = Fuse::terminated();
// Runs both the tx and rx halves of the connection to completion.
// This loop does not terminate until the *SERVER* closes the connection and
// we've processed all received data. If an error occurs, the `TlsConnection`
// channels will be closed and the error will be returned from this future.
'conn: loop {
// Write all pending TLS data to the server.
if client.wants_write() && !client_closed {
#[cfg(feature = "tracing")]
trace!("client wants to write");
while client.wants_write() {
let _sent = client.write_tls_async(&mut server_tx).await?;
#[cfg(feature = "tracing")]
trace!("sent {} tls bytes to server", _sent);
}
server_tx.flush().await?;
}
// Forward received plaintext to `TlsConnection`.
while !client.plaintext_is_empty() {
let read = client.read_plaintext(&mut rx_buf)?;
recv.extend(&rx_buf[..read]);
// Ignore if the receiver has hung up.
_ = rx_sender
.send(Ok(Bytes::copy_from_slice(&rx_buf[..read])))
.await;
#[cfg(feature = "tracing")]
trace!("forwarded {} plaintext bytes to conn", read);
}
if !client.is_handshaking() && !handshake_done {
#[cfg(feature = "tracing")]
debug!("handshake complete");
handshake_done = true;
// Start reading application data that needs to be transmitted from the
// `TlsConnection`.
tx_recv_fut = tx_receiver.next().fuse();
}
if server_closed && client.plaintext_is_empty() && client.is_empty().await? {
break 'conn;
}
select_biased! {
// Reads TLS data from the server and writes it into the client.
received = &mut rx_tls_fut => {
let received = received?;
#[cfg(feature = "tracing")]
trace!("received {} tls bytes from server", received);
// Loop until we've processed all the data we received in this read.
// Note that we must make one iteration even if `received == 0`.
let mut processed = 0;
let mut reader = rx_tls_buf[..received].reader();
loop {
processed += client.read_tls(&mut reader)?;
client.process_new_packets().await?;
debug_assert!(processed <= received);
if processed >= received {
break;
}
}
#[cfg(feature = "tracing")]
trace!("processed {} tls bytes from server", processed);
// By convention if `AsyncRead::read` returns 0, it means EOF, i.e. the peer
// has closed the socket.
if received == 0 {
#[cfg(feature = "tracing")]
debug!("server closed connection");
server_closed = true;
client.server_closed().await?;
// Do not read from the socket again.
rx_tls_fut = Fuse::terminated();
} else {
// Reset the read future so next iteration we can read again.
rx_tls_fut = server_rx.read(&mut rx_tls_buf).fuse();
}
}
// If we receive None from `TlsConnection`, it has closed, so we
// send a close_notify to the server.
data = &mut tx_recv_fut => {
if let Some(data) = data {
#[cfg(feature = "tracing")]
trace!("writing {} plaintext bytes to client", data.len());
sent.extend(&data);
client
.write_all_plaintext(&data)
.await?;
tx_recv_fut = tx_receiver.next().fuse();
} else {
if !server_closed {
if let Err(e) = send_close_notify(&mut client, &mut server_tx).await {
#[cfg(feature = "tracing")]
warn!("failed to send close_notify to server: {}", e);
}
}
client_closed = true;
tx_recv_fut = Fuse::terminated();
}
}
// Waits for a notification from the backend that it is ready to decrypt data.
_ = &mut notify => {
#[cfg(feature = "tracing")]
trace!("backend is ready to decrypt");
client.process_new_packets().await?;
}
}
}
#[cfg(feature = "tracing")]
debug!("client shutdown");
_ = server_tx.close().await;
tx_receiver.close();
rx_sender.close_channel();
#[cfg(feature = "tracing")]
trace!(
"server close notify: {}, sent: {}, recv: {}",
client.received_close_notify(),
sent.len(),
recv.len()
);
Ok(ClosedConnection { client, sent, recv })
};
#[cfg(feature = "tracing")]
let fut = fut.instrument(debug_span!("tls_connection"));
let fut = ConnectionFuture { fut: Box::pin(fut) };
(conn, fut)
}
async fn send_close_notify(
client: &mut ClientConnection,
server_tx: &mut (impl AsyncWrite + Unpin),
) -> Result<(), ConnectionError> {
#[cfg(feature = "tracing")]
trace!("sending close_notify to server");
client.send_close_notify().await?;
client.process_new_packets().await?;
// Flush all remaining plaintext
while client.wants_write() {
client.write_tls_async(server_tx).await?;
}
server_tx.flush().await?;
Ok(())
}

View File

@@ -0,0 +1,438 @@
use std::{str, sync::Arc};
use core::future::Future;
use futures::{AsyncReadExt, AsyncWriteExt};
use http_body_util::{BodyExt as _, Full};
use hyper::{body::Bytes, Request, StatusCode};
use hyper_util::rt::TokioIo;
use rstest::{fixture, rstest};
use rustls_pki_types::CertificateDer;
use tls_client::{ClientConfig, ClientConnection, RustCryptoBackend, ServerName};
use tls_client_async::{bind_client, ClosedConnection, ConnectionError, TlsConnection};
use tls_server_fixture::{
bind_test_server, bind_test_server_hyper, APP_RECORD_LENGTH, CA_CERT_DER, CLOSE_DELAY,
SERVER_DOMAIN,
};
use tokio::task::JoinHandle;
use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
use webpki::anchor_from_trusted_cert;
const CA_CERT: CertificateDer = CertificateDer::from_slice(CA_CERT_DER);
// An established client TLS connection
struct TlsFixture {
client_tls_conn: TlsConnection,
// a handle that must be `.await`ed to get the result of a TLS connection
closed_tls_task: JoinHandle<Result<ClosedConnection, ConnectionError>>,
}
// Sets up a TLS connection between client and server and sends a hello message
#[fixture]
async fn set_up_tls() -> TlsFixture {
let (client_socket, server_socket) = tokio::io::duplex(1 << 16);
let _server_task = tokio::spawn(bind_test_server(server_socket.compat()));
let mut root_store = tls_client::RootCertStore::empty();
root_store
.roots
.push(anchor_from_trusted_cert(&CA_CERT).unwrap().to_owned());
let config = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_no_client_auth();
let client = ClientConnection::new(
Arc::new(config),
Box::new(RustCryptoBackend::new()),
ServerName::try_from(SERVER_DOMAIN).unwrap(),
)
.unwrap();
let (mut client_tls_conn, tls_fut) = bind_client(client_socket.compat(), client);
let closed_tls_task = tokio::spawn(tls_fut);
client_tls_conn
.write_all(&pad("expecting you to send back hello".to_string()))
.await
.unwrap();
// give the server some time to respond
std::thread::sleep(std::time::Duration::from_millis(10));
let mut plaintext = vec![0u8; 320];
let n = client_tls_conn.read(&mut plaintext).await.unwrap();
let s = str::from_utf8(&plaintext[0..n]).unwrap();
assert_eq!(s, "hello");
TlsFixture {
client_tls_conn,
closed_tls_task,
}
}
// Expect the async tls client wrapped in `hyper::client` to make a successful
// request and receive the expected response
#[tokio::test]
async fn test_hyper_ok() {
let (client_socket, server_socket) = tokio::io::duplex(1 << 16);
let server_task = tokio::spawn(bind_test_server_hyper(server_socket.compat()));
let mut root_store = tls_client::RootCertStore::empty();
root_store
.roots
.push(anchor_from_trusted_cert(&CA_CERT).unwrap().to_owned());
let config = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_no_client_auth();
let client = ClientConnection::new(
Arc::new(config),
Box::new(RustCryptoBackend::new()),
ServerName::try_from(SERVER_DOMAIN).unwrap(),
)
.unwrap();
let (conn, tls_fut) = bind_client(client_socket.compat(), client);
let closed_tls_task = tokio::spawn(tls_fut);
let (mut request_sender, connection) =
hyper::client::conn::http1::handshake(TokioIo::new(conn.compat()))
.await
.unwrap();
tokio::spawn(connection);
let request = Request::builder()
.uri(format!("https://{SERVER_DOMAIN}/echo"))
.header("Host", SERVER_DOMAIN)
.header("Connection", "close")
.method("POST")
.body(Full::<Bytes>::new("hello".into()))
.unwrap();
let response = request_sender.send_request(request).await.unwrap();
assert!(response.status() == StatusCode::OK);
// Process the response body
response.into_body().collect().await.unwrap().to_bytes();
let _ = server_task.await.unwrap();
let closed_conn = closed_tls_task.await.unwrap().unwrap();
assert!(closed_conn.client.received_close_notify());
}
// Expect a clean TLS connection closure when server responds to the client's
// close_notify but doesn't close the socket
#[rstest]
#[tokio::test]
async fn test_ok_server_no_socket_close(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
// instruct the server to send close_notify back to us after 10 ms
client_tls_conn
.write_all(&pad("send_close_notify".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// closing `client_tls_conn` will cause close_notify to be sent by the client;
client_tls_conn.close().await.unwrap();
let closed_conn = closed_tls_task.await.unwrap().unwrap();
assert!(closed_conn.client.received_close_notify());
}
// Expect a clean TLS connection closure when server responds to the client's
// close_notify AND also closes the socket
#[rstest]
#[tokio::test]
async fn test_ok_server_socket_close(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
// instruct the server to send close_notify back to us AND close the socket
// after 10 ms
client_tls_conn
.write_all(&pad("send_close_notify_and_close_socket".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// closing `client_tls_conn` will cause close_notify to be sent by the client;
client_tls_conn.close().await.unwrap();
let closed_conn = closed_tls_task.await.unwrap().unwrap();
assert!(closed_conn.client.received_close_notify());
}
// Expect a clean TLS connection closure when server sends close_notify first
// but doesn't close the socket
#[rstest]
#[tokio::test]
async fn test_ok_server_close_notify(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
// instruct the server to send close_notify back to us after 10 ms
client_tls_conn
.write_all(&pad("send_close_notify".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// give enough time for server's close_notify to arrive
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
client_tls_conn.close().await.unwrap();
let closed_conn = closed_tls_task.await.unwrap().unwrap();
assert!(closed_conn.client.received_close_notify());
}
// Expect a clean TLS connection closure when server sends close_notify first
// AND also closes the socket
#[rstest]
#[tokio::test]
async fn test_ok_server_close_notify_and_socket_close(
set_up_tls: impl Future<Output = TlsFixture>,
) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
// instruct the server to send close_notify back to us after 10 ms
client_tls_conn
.write_all(&pad("send_close_notify_and_close_socket".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// give enough time for server's close_notify to arrive
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
client_tls_conn.close().await.unwrap();
let closed_conn = closed_tls_task.await.unwrap().unwrap();
assert!(closed_conn.client.received_close_notify());
}
// Expect to be able to read the data after server closes the socket abruptly
#[rstest]
#[tokio::test]
async fn test_ok_read_after_close(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
..
} = set_up_tls.await;
// instruct the server to send us a hello message
client_tls_conn
.write_all(&pad("send a hello message".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// instruct the server to close the socket
client_tls_conn
.write_all(&pad("close_socket".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// give enough time to close the socket
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
// try to read some more data
let mut buf = vec![0u8; 10];
let n = client_tls_conn.read(&mut buf).await.unwrap();
assert_eq!(std::str::from_utf8(&buf[0..n]).unwrap(), "hello");
}
// Expect there to be no error when server DOES NOT send close_notify but just
// closes the socket
#[rstest]
#[tokio::test]
async fn test_ok_server_no_close_notify(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
// instruct the server to close the socket
client_tls_conn
.write_all(&pad("close_socket".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// give enough time to close the socket
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
client_tls_conn.close().await.unwrap();
let closed_conn = closed_tls_task.await.unwrap().unwrap();
assert!(!closed_conn.client.received_close_notify());
}
// Expect to register a delay when the server delays closing the socket
#[rstest]
#[tokio::test]
async fn test_ok_delay_close(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
client_tls_conn
.write_all(&pad("must_delay_when_closing".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// closing `client_tls_conn` will cause close_notify to be sent by the client
client_tls_conn.close().await.unwrap();
use std::time::Instant;
let now = Instant::now();
// this will resolve when the server stops delaying closing the socket
let closed_conn = closed_tls_task.await.unwrap().unwrap();
let elapsed = now.elapsed();
// the elapsed time must be roughly equal to the server's delay
// (give or take timing variations)
assert!(elapsed.as_millis() as u64 > CLOSE_DELAY - 50);
assert!(!closed_conn.client.received_close_notify());
}
// Expect client to error when server sends a corrupted message
#[rstest]
#[tokio::test]
async fn test_err_corrupted(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
// instruct the server to send a corrupted message
client_tls_conn
.write_all(&pad("send_corrupted_message".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
client_tls_conn.close().await.unwrap();
assert_eq!(
closed_tls_task.await.unwrap().err().unwrap().to_string(),
"received corrupt message"
);
}
// Expect client to error when server sends a TLS record with a bad MAC
#[rstest]
#[tokio::test]
async fn test_err_bad_mac(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
// instruct the server to send us a TLS record with a bad MAC
client_tls_conn
.write_all(&pad("send_record_with_bad_mac".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
client_tls_conn.close().await.unwrap();
assert_eq!(
closed_tls_task.await.unwrap().err().unwrap().to_string(),
"backend error: Decryption error: \"aead::Error\""
);
}
// Expect client to error when server sends a fatal alert
#[rstest]
#[tokio::test]
async fn test_err_alert(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
// instruct the server to send us a TLS record with a bad MAC
client_tls_conn
.write_all(&pad("send_alert".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
client_tls_conn.close().await.unwrap();
assert_eq!(
closed_tls_task.await.unwrap().err().unwrap().to_string(),
"received fatal alert: BadRecordMac"
);
}
// Expect an error when trying to write data to a connection which server closed
// abruptly
#[rstest]
#[tokio::test]
async fn test_err_write_after_close(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
..
} = set_up_tls.await;
// instruct the server to close the socket
client_tls_conn
.write_all(&pad("close_socket".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// give enough time to close the socket
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
// try to send some more data
let res = client_tls_conn
.write_all(&pad("more data".to_string()))
.await;
assert_eq!(res.err().unwrap().kind(), std::io::ErrorKind::BrokenPipe);
}
// Converts a string into a slice zero-padded to APP_RECORD_LENGTH
fn pad(s: String) -> Vec<u8> {
assert!(s.len() <= APP_RECORD_LENGTH);
let mut buf = vec![0u8; APP_RECORD_LENGTH];
buf[..s.len()].copy_from_slice(s.as_bytes());
buf
}

View File

@@ -457,9 +457,6 @@ impl ConnectionCommon {
return Err(Error::CorruptMessage);
}
// Process outgoing plaintext buffer and encrypt messages.
self.flush_plaintext().await?;
// Process new messages.
while let Some(msg) = self.message_deframer.frames.pop_front() {
// If we're not decrypting yet, we process it immediately. Otherwise it will be
@@ -511,22 +508,25 @@ impl ConnectionCommon {
Ok(state)
}
/// Writes plaintext `buf` into an internal buffer. May not fully process the
/// whole buffer and returns the processed length.
pub fn write_plaintext(&mut self, buf: &[u8]) -> Result<usize, Error> {
if buf.is_empty() {
// Don't send empty fragments.
return Ok(0);
/// Write buffer into connection.
pub async fn write_plaintext(&mut self, buf: &[u8]) -> Result<usize, Error> {
if let Ok(st) = &mut self.state {
st.perhaps_write_key_update(&mut self.common_state).await;
}
let len = self.sendable_plaintext.append_limited_copy(buf);
Ok(len)
self.common_state.send_some_plaintext(buf).await
}
/// Writes the entire plaintext `buf` into an internal buffer.
pub fn write_all_plaintext(&mut self, buf: &[u8]) -> Result<(), Error> {
self.sendable_plaintext.append(buf.to_vec());
Ok(())
/// Write entire buffer into connection.
pub async fn write_all_plaintext(&mut self, buf: &[u8]) -> Result<usize, Error> {
let mut pos = 0;
while pos < buf.len() {
pos += self.write_plaintext(&buf[pos..]).await?;
}
self.backend.flush().await?;
while let Some(msg) = self.backend.next_outgoing().await? {
self.queue_tls_message(msg);
}
Ok(pos)
}
/// Read TLS content from `rd`. This method does internal
@@ -690,11 +690,6 @@ impl CommonState {
self.received_plaintext.is_empty()
}
/// Returns true if the buffer for sendable plaintext is full.
pub fn sendable_plaintext_is_full(&self) -> bool {
self.sendable_plaintext.is_full()
}
/// Returns true if the connection is currently performing the TLS
/// handshake.
///
@@ -787,6 +782,15 @@ impl CommonState {
}
}
/// Send plaintext application data, fragmenting and
/// encrypting it as it goes out.
///
/// If internal buffers are too small, this function will not accept
/// all the data.
pub(crate) async fn send_some_plaintext(&mut self, data: &[u8]) -> Result<usize, Error> {
self.send_plain(data, Limit::Yes).await
}
// Changing the keys must not span any fragmented handshake
// messages. Otherwise the defragmented messages will have
// been protected with two different record layer protections,
@@ -927,6 +931,32 @@ impl CommonState {
self.sendable_tls.write_to_async(wr).await
}
/// Encrypt and send some plaintext `data`. `limit` controls
/// whether the per-connection buffer limits apply.
///
/// Returns the number of bytes written from `data`: this might
/// be less than `data.len()` if buffer limits were exceeded.
async fn send_plain(&mut self, data: &[u8], limit: Limit) -> Result<usize, Error> {
if !self.may_send_application_data {
// If we haven't completed handshaking, buffer
// plaintext to send once we do.
let len = match limit {
Limit::Yes => self.sendable_plaintext.append_limited_copy(data),
Limit::No => self.sendable_plaintext.append(data.to_vec()),
};
return Ok(len);
}
debug_assert!(self.record_layer.is_encrypting());
if data.is_empty() {
// Don't send empty fragments.
return Ok(0);
}
self.send_appdata_encrypt(data, limit).await
}
pub(crate) async fn start_outgoing_traffic(&mut self) -> Result<(), Error> {
self.may_send_application_data = true;
self.flush_plaintext().await
@@ -982,14 +1012,15 @@ impl CommonState {
self.sendable_tls.set_limit(limit);
}
/// Send and encrypt any buffered plaintext. Does nothing during handshake.
pub async fn flush_plaintext(&mut self) -> Result<(), Error> {
/// Send any buffered plaintext. Plaintext is buffered if
/// written during handshake.
async fn flush_plaintext(&mut self) -> Result<(), Error> {
if !self.may_send_application_data {
return Ok(());
}
while let Some(buf) = self.sendable_plaintext.pop() {
self.send_appdata_encrypt(&buf, Limit::No).await?;
self.send_plain(&buf, Limit::No).await?;
}
Ok(())

View File

@@ -35,15 +35,6 @@ impl ChunkVecBuffer {
self.chunks.is_empty()
}
/// If the buffer has reached limit.
pub(crate) fn is_full(&self) -> bool {
if let Some(limit) = self.limit {
self.len() >= limit
} else {
false
}
}
/// How many bytes we're storing
pub(crate) fn len(&self) -> usize {
let mut len = 0;

View File

@@ -247,8 +247,7 @@ async fn servered_client_data_sent() {
let (mut client, mut server) =
make_pair_for_arc_configs(&Arc::new(client_config), &server_config).await;
assert_eq!(5, client.write_plaintext(b"hello").unwrap());
client.flush_plaintext().await.unwrap();
assert_eq!(5, client.write_plaintext(b"hello").await.unwrap());
do_handshake(&mut client, &mut server).await;
send(&mut client, &mut server);
@@ -287,7 +286,7 @@ async fn servered_both_data_sent() {
make_pair_for_arc_configs(&Arc::new(client_config), &server_config).await;
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
do_handshake(&mut client, &mut server).await;
@@ -433,7 +432,7 @@ async fn server_close_notify() {
// check that alerts don't overtake appdata
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
server.send_close_notify();
receive(&mut server, &mut client);
@@ -461,8 +460,7 @@ async fn client_close_notify() {
// check that alerts don't overtake appdata
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
client.flush_plaintext().await.unwrap();
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
client.send_close_notify().await.unwrap();
send(&mut client, &mut server);
@@ -489,7 +487,7 @@ async fn server_closes_uncleanly() {
// check that unclean EOF reporting does not overtake appdata
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
receive(&mut server, &mut client);
transfer_eof(&mut client);
@@ -520,7 +518,7 @@ async fn client_closes_uncleanly() {
// check that unclean EOF reporting does not overtake appdata
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
client.process_new_packets().await.unwrap();
send(&mut client, &mut server);
@@ -902,9 +900,20 @@ async fn client_respects_buffer_limit_pre_handshake() {
client.set_buffer_limit(Some(32));
assert_eq!(client.write_plaintext(b"01234567890123456789").unwrap(), 20);
assert_eq!(client.write_plaintext(b"01234567890123456789").unwrap(), 12);
client.flush_plaintext().await.unwrap();
assert_eq!(
client
.write_plaintext(b"01234567890123456789")
.await
.unwrap(),
20
);
assert_eq!(
client
.write_plaintext(b"01234567890123456789")
.await
.unwrap(),
12
);
do_handshake(&mut client, &mut server).await;
send(&mut client, &mut server);
@@ -944,9 +953,20 @@ async fn client_respects_buffer_limit_post_handshake() {
do_handshake(&mut client, &mut server).await;
client.set_buffer_limit(Some(48));
assert_eq!(client.write_plaintext(b"01234567890123456789").unwrap(), 20);
assert_eq!(client.write_plaintext(b"01234567890123456789").unwrap(), 6);
client.flush_plaintext().await.unwrap();
assert_eq!(
client
.write_plaintext(b"01234567890123456789")
.await
.unwrap(),
20
);
assert_eq!(
client
.write_plaintext(b"01234567890123456789")
.await
.unwrap(),
6
);
send(&mut client, &mut server);
server.process_new_packets().unwrap();
@@ -1191,8 +1211,14 @@ async fn client_complete_io_for_write() {
do_handshake(&mut client, &mut server).await;
client.write_plaintext(b"01234567890123456789").unwrap();
client.write_plaintext(b"01234567890123456789").unwrap();
client
.write_plaintext(b"01234567890123456789")
.await
.unwrap();
client
.write_plaintext(b"01234567890123456789")
.await
.unwrap();
{
let mut pipe = ServerSession::new(&mut server);
let (rdlen, wrlen) = client
@@ -1324,8 +1350,7 @@ async fn server_stream_read() {
for kt in ALL_KEY_TYPES.iter() {
let (mut client, mut server) = make_pair(*kt).await;
client.write_all_plaintext(b"world").unwrap();
client.process_new_packets().await.unwrap();
client.write_all_plaintext(b"world").await.unwrap();
{
let mut pipe = ClientSession::new(&mut client);
@@ -1341,8 +1366,7 @@ async fn server_streamowned_read() {
for kt in ALL_KEY_TYPES.iter() {
let (mut client, server) = make_pair(*kt).await;
client.write_all_plaintext(b"world").unwrap();
client.process_new_packets().await.unwrap();
client.write_all_plaintext(b"world").await.unwrap();
{
let pipe = ClientSession::new(&mut client);
@@ -1361,9 +1385,7 @@ async fn server_streamowned_read() {
// errkind: io::ErrorKind::ConnectionAborted,
// after: 0,
// };
// client.write_all_plaintext(b"hello").unwrap();
// client.process_new_packets().await.unwrap();
//
// client.write_all_plaintext(b"hello").await.unwrap();
// let mut client_stream = Stream::new(&mut client, &mut pipe);
// let rc = client_stream.write(b"world");
// assert!(rc.is_err());
@@ -1380,9 +1402,7 @@ async fn server_streamowned_read() {
// errkind: io::ErrorKind::ConnectionAborted,
// after: 1,
// };
// client.write_all_plaintext(b"hello").unwrap();
// client.process_new_packets().await.unwrap();
//
// client.write_all_plaintext(b"hello").await.unwrap();
// let mut client_stream = Stream::new(&mut client, &mut pipe);
// let rc = client_stream.write(b"world");
// assert_eq!(format!("{:?}", rc), "Ok(5)");
@@ -1880,9 +1900,14 @@ async fn servered_write_for_client_appdata() {
let (mut client, mut server) = make_pair(KeyType::Rsa).await;
do_handshake(&mut client, &mut server).await;
client.write_all_plaintext(b"01234567890123456789").unwrap();
client.write_all_plaintext(b"01234567890123456789").unwrap();
client.process_new_packets().await.unwrap();
client
.write_all_plaintext(b"01234567890123456789")
.await
.unwrap();
client
.write_all_plaintext(b"01234567890123456789")
.await
.unwrap();
{
let mut pipe = ServerSession::new(&mut server);
let wrlen = client.write_tls(&mut pipe).unwrap();
@@ -1994,10 +2019,11 @@ async fn servered_write_for_server_handshake_no_half_rtt_by_default() {
async fn servered_write_for_client_handshake() {
let (mut client, mut server) = make_pair(KeyType::Rsa).await;
client.write_all_plaintext(b"01234567890123456789").unwrap();
client.write_all_plaintext(b"0123456789").unwrap();
client.process_new_packets().await.unwrap();
client
.write_all_plaintext(b"01234567890123456789")
.await
.unwrap();
client.write_all_plaintext(b"0123456789").await.unwrap();
{
let mut pipe = ServerSession::new(&mut server);
let wrlen = client.write_tls(&mut pipe).unwrap();

View File

@@ -21,11 +21,11 @@ tlsn-attestation = { workspace = true }
tlsn-core = { workspace = true }
tlsn-deap = { workspace = true }
tlsn-tls-client = { workspace = true }
tlsn-tls-client-async = { workspace = true }
tlsn-tls-core = { workspace = true }
tlsn-mpc-tls = { workspace = true }
tlsn-cipher = { workspace = true }
futures-plex = { workspace = true }
serio = { workspace = true, features = ["compat"] }
uid-mux = { workspace = true, features = ["serio"] }
web-spawn = { workspace = true, optional = true }
@@ -57,7 +57,6 @@ serde = { workspace = true, features = ["derive"] }
ghash = { workspace = true }
semver = { workspace = true, features = ["serde"] }
once_cell = { workspace = true }
pin-project-lite = { workspace = true }
rangeset = { workspace = true }
webpki-roots = { workspace = true }

View File

@@ -0,0 +1,21 @@
//! Execution context.
use mpz_common::context::Multithread;
use crate::mux::MuxControl;
/// Maximum concurrency for multi-threaded context.
pub(crate) const MAX_CONCURRENCY: usize = 8;
/// Builds a multi-threaded context with the given muxer.
pub(crate) fn build_mt_context(mux: MuxControl) -> Multithread {
let builder = Multithread::builder().mux(mux).concurrency(MAX_CONCURRENCY);
#[cfg(all(feature = "web", target_arch = "wasm32"))]
let builder = builder.spawn_handler(|f| {
let _ = web_spawn::spawn(f);
Ok(())
});
builder.build().unwrap()
}

View File

@@ -4,6 +4,7 @@
#![deny(clippy::all)]
#![forbid(unsafe_code)]
pub(crate) mod context;
pub(crate) mod ghash;
pub(crate) mod map;
pub(crate) mod mpz;
@@ -12,7 +13,6 @@ pub(crate) mod mux;
pub mod prover;
pub(crate) mod tag;
pub(crate) mod transcript_internal;
pub(crate) mod utils;
pub mod verifier;
pub use tlsn_attestation as attestation;
@@ -27,8 +27,6 @@ pub(crate) static VERSION: LazyLock<Version> = LazyLock::new(|| {
Version::parse(env!("CARGO_PKG_VERSION")).expect("cargo pkg version should be a valid semver")
});
const BUF_CAP: usize = 16 * 1024;
/// The party's role in the TLSN protocol.
///
/// A Notary is classified as a Verifier.

View File

@@ -1,38 +1,31 @@
//! Prover.
mod client;
mod conn;
mod control;
mod error;
mod future;
mod prove;
pub mod state;
pub use conn::TlsConnection;
pub use control::ProverControl;
pub use error::ProverError;
pub use future::ProverFuture;
pub use tlsn_core::ProverOutput;
use crate::{
BUF_CAP, Role,
Role,
context::build_mt_context,
mpz::{ProverDeps, build_prover_deps, translate_keys},
msg::{ProveRequestMsg, Response, TlsCommitRequestMsg},
mux::attach_mux,
prover::{
client::{MpcTlsClient, TlsOutput},
state::ConnectedProj,
},
utils::{CopyIo, await_with_copy_io, build_mt_context},
tag::verify_tags,
};
use futures::{AsyncRead, AsyncReadExt, AsyncWrite, FutureExt, TryFutureExt, ready};
use futures::{AsyncRead, AsyncWrite, TryFutureExt};
use mpc_tls::LeaderCtrl;
use mpz_vm_core::prelude::*;
use rustls_pki_types::CertificateDer;
use serio::{SinkExt, stream::IoStreamExt};
use std::{
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use std::sync::Arc;
use tls_client::{ClientConnection, ServerName as TlsServerName};
use tls_client_async::{TlsConnection, bind_client};
use tlsn_core::{
config::{
prove::ProveConfig,
@@ -43,9 +36,10 @@ use tlsn_core::{
connection::{HandshakeData, ServerName},
transcript::{TlsTranscript, Transcript},
};
use tracing::{Span, debug, info_span, instrument};
use webpki::anchor_from_trusted_cert;
use tracing::{Instrument, Span, debug, info, info_span, instrument};
/// A prover instance.
#[derive(Debug)]
pub struct Prover<T: state::ProverState = state::Initialized> {
@@ -77,27 +71,14 @@ impl Prover<state::Initialized> {
/// # Arguments
///
/// * `config` - The TLS commitment configuration.
/// * `verifier_io` - The IO to the TLS verifier.
pub async fn commit<S: AsyncWrite + AsyncRead + Send + Unpin>(
self,
config: TlsCommitConfig,
verifier_io: S,
) -> Result<Prover<state::CommitAccepted>, ProverError> {
let (duplex_a, mut duplex_b) = futures_plex::duplex(BUF_CAP);
let fut = Box::pin(self.commit_inner(config, duplex_a).fuse());
let mut prover = await_with_copy_io(fut, verifier_io, &mut duplex_b).await?;
prover.state.verifier_io = Some(duplex_b);
Ok(prover)
}
/// * `socket` - The socket to the TLS verifier.
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
async fn commit_inner<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
pub async fn commit<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
self,
config: TlsCommitConfig,
verifier_io: S,
socket: S,
) -> Result<Prover<state::CommitAccepted>, ProverError> {
let (mut mux_fut, mux_ctrl) = attach_mux(verifier_io, Role::Prover);
let (mut mux_fut, mux_ctrl) = attach_mux(socket, Role::Prover);
let mut mt = build_mt_context(mux_ctrl.clone());
let mut ctx = mux_fut.poll_with(mt.new_context()).await?;
@@ -141,7 +122,6 @@ impl Prover<state::Initialized> {
config: self.config,
span: self.span,
state: state::CommitAccepted {
verifier_io: None,
mux_ctrl,
mux_fut,
mpc_tls,
@@ -153,29 +133,31 @@ impl Prover<state::Initialized> {
}
impl Prover<state::CommitAccepted> {
/// Sets up the prover with the client configuration.
/// Connects to the server using the provided socket.
///
/// Returns a set up prover, and a [`TlsConnection`] which can be used to
/// read and write bytes from/to the server.
/// Returns a handle to the TLS connection, a future which returns the
/// prover once the connection is closed and the TLS transcript is
/// committed.
///
/// # Arguments
///
/// * `config` - The TLS client configuration.
/// * `socket` - The socket to the server.
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
pub fn setup(
pub async fn connect<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
self,
config: TlsClientConfig,
) -> Result<(TlsConnection, Prover<state::Setup>), ProverError> {
socket: S,
) -> Result<(TlsConnection, ProverFuture), ProverError> {
let state::CommitAccepted {
verifier_io,
mux_ctrl,
mux_fut,
mut mux_fut,
mpc_tls,
keys,
vm,
..
} = self.state;
let decrypt = mpc_tls.is_decrypting();
let (mpc_ctrl, mpc_fut) = mpc_tls.run();
let ServerName::Dns(server_name) = config.server_name();
@@ -220,296 +202,95 @@ impl Prover<state::CommitAccepted> {
)
.map_err(ProverError::config)?;
let span = self.span.clone();
let (conn, conn_fut) = bind_client(socket, client);
let mpc_tls = MpcTlsClient::new(
Box::new(mpc_fut.map_err(ProverError::from)),
keys,
vm,
span,
mpc_ctrl,
client,
decrypt,
);
let fut = Box::pin({
let span = self.span.clone();
let mpc_ctrl = mpc_ctrl.clone();
async move {
let conn_fut = async {
mux_fut
.poll_with(conn_fut.map_err(ProverError::from))
.await?;
let (duplex_a, duplex_b) = futures_plex::duplex(BUF_CAP);
let prover = Prover {
config: self.config,
span: self.span,
state: state::Setup {
mux_ctrl,
mux_fut,
server_name: config.server_name().clone(),
tls_client: Box::new(mpc_tls),
client_io: duplex_a,
verifier_io,
},
};
mpc_ctrl.stop().await?;
let conn = TlsConnection::new(duplex_b);
Ok((conn, prover))
}
}
Ok::<_, ProverError>(())
};
impl Prover<state::Setup> {
/// Returns a handle to control the prover.
pub fn handle(&self) -> ProverControl {
let handle = self.state.tls_client.handle();
ProverControl { handle }
}
info!("starting MPC-TLS");
/// Attaches IO to the prover.
///
/// # Arguments
///
/// * `server_io` - The IO to the server.
/// * `verifier_io` - The IO to the TLS verifier.
pub fn connect<S, T>(self, server_io: S, verifier_io: T) -> Prover<state::Connected<S, T>>
where
S: AsyncRead + AsyncWrite + Send + Unpin,
T: AsyncRead + AsyncWrite + Send + Unpin,
{
let (client_to_server, server_to_client) = futures_plex::duplex(BUF_CAP);
let (_, (mut ctx, tls_transcript)) = futures::try_join!(
conn_fut,
mpc_fut.in_current_span().map_err(ProverError::from)
)?;
Prover {
config: self.config,
span: self.span,
state: state::Connected {
verifier_io: self.state.verifier_io,
mux_ctrl: self.state.mux_ctrl,
mux_fut: self.state.mux_fut,
server_name: self.state.server_name,
tls_client: self.state.tls_client,
client_io: self.state.client_io,
output: None,
server_socket: server_io,
verifier_socket: verifier_io,
tls_client_to_server_buf: client_to_server,
server_to_tls_client_buf: server_to_client,
client_closed: false,
server_closed: false,
},
}
}
info!("finished MPC-TLS");
/// This is a convenience method which attaches IO, runs the prover and
/// returns a committed prover together with the IO.
///
/// # Arguments
///
/// * `server_io` - The IO to the server.
/// * `verifier_io` - The IO to the TLS verifier.
pub async fn run<S, T>(
self,
mut server_io: S,
mut verifier_io: T,
) -> Result<(Prover<state::Committed>, S, T), ProverError>
where
S: AsyncRead + AsyncWrite + Send + Unpin + 'static,
T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
let mut prover = self.connect(&mut server_io, &mut verifier_io);
(&mut prover).await?;
{
let mut vm = vm.try_lock().expect("VM should not be locked");
let prover = prover.finish()?;
Ok((prover, server_io, verifier_io))
}
}
debug!("finalizing mpc");
impl<S, T> Future for Prover<state::Connected<S, T>>
where
S: AsyncRead + AsyncWrite + Send + Unpin,
T: AsyncRead + AsyncWrite + Send + Unpin,
{
type Output = Result<(), ProverError>;
// Finalize DEAP.
mux_fut
.poll_with(vm.finalize(&mut ctx))
.await
.map_err(ProverError::mpc)?;
fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut state = Pin::new(&mut self.state).project();
loop {
let mut progress = false;
if state.output.is_none()
&& let Poll::Ready(output) = state.tls_client.poll(cx)?
{
*state.output = Some(output);
}
progress |= Self::io_client_conn(&mut state, cx)?;
progress |= Self::io_client_server(&mut state, cx)?;
progress |= Self::io_client_verifier(&mut state, cx)?;
_ = state.mux_fut.poll_unpin(cx)?;
if *state.server_closed && state.output.is_some() {
ready!(state.client_io.poll_close(cx))?;
ready!(state.server_socket.poll_close(cx))?;
return Poll::Ready(Ok(()));
} else if !progress {
return Poll::Pending;
}
}
}
}
impl<S, T> Prover<state::Connected<S, T>>
where
S: AsyncRead + AsyncWrite + Send + Unpin,
T: AsyncRead + AsyncWrite + Send + Unpin,
{
fn io_client_conn(
state: &mut ConnectedProj<S, T>,
cx: &mut Context,
) -> Result<bool, ProverError> {
let mut progress = false;
// tls_conn -> tls_client
if state.tls_client.wants_write()
&& let Poll::Ready(mut simplex) = state.client_io.as_mut().poll_lock_read(cx)
&& let Poll::Ready(buf) = simplex.poll_get(cx)?
{
if !buf.is_empty() {
let write = state.tls_client.write(buf)?;
if write > 0 {
progress = true;
simplex.advance(write);
debug!("mpc finalized");
}
} else if !*state.client_closed && !*state.server_closed {
progress = true;
*state.client_closed = true;
state.tls_client.client_close()?;
// Pull out ZK VM.
let (_, mut vm) = Arc::into_inner(vm)
.expect("vm should have only 1 reference")
.into_inner()
.into_inner();
// Prove tag verification of received records.
// The prover drops the proof output.
let _ = verify_tags(
&mut vm,
(keys.server_write_key, keys.server_write_iv),
keys.server_write_mac_key,
*tls_transcript.version(),
tls_transcript.recv().to_vec(),
)
.map_err(ProverError::zk)?;
mux_fut
.poll_with(vm.execute_all(&mut ctx).map_err(ProverError::zk))
.await?;
let transcript = tls_transcript
.to_transcript()
.expect("transcript is complete");
Ok(Prover {
config: self.config,
span: self.span,
state: state::Committed {
mux_ctrl,
mux_fut,
ctx,
vm,
server_name: config.server_name().clone(),
keys,
tls_transcript,
transcript,
},
})
}
}
.instrument(span)
});
// tls_client -> tls_conn
if state.tls_client.wants_read()
&& let Poll::Ready(mut simplex) = state.client_io.as_mut().poll_lock_write(cx)
&& let Poll::Ready(buf) = simplex.poll_mut(cx)?
&& let read = state.tls_client.read(buf)?
&& read > 0
{
progress = true;
simplex.advance_mut(read);
}
Ok(progress)
}
fn io_client_server(
state: &mut ConnectedProj<S, T>,
cx: &mut Context,
) -> Result<bool, ProverError> {
let mut progress = false;
// server_socket -> buf
if let Poll::Ready(write) = state
.server_to_tls_client_buf
.poll_write_from(cx, state.server_socket.as_mut())?
{
if write > 0 {
progress = true;
} else if !*state.server_closed {
progress = true;
*state.server_closed = true;
state.tls_client.server_close()?;
}
}
// buf -> tls_client
if state.tls_client.wants_read_tls()
&& let Poll::Ready(mut simplex) =
state.tls_client_to_server_buf.as_mut().poll_lock_read(cx)
&& let Poll::Ready(buf) = simplex.poll_get(cx)?
&& let read = state.tls_client.read_tls(buf)?
&& read > 0
{
progress = true;
simplex.advance(read);
}
// tls_client -> buf
if state.tls_client.wants_write_tls()
&& let Poll::Ready(mut simplex) =
state.tls_client_to_server_buf.as_mut().poll_lock_write(cx)
&& let Poll::Ready(buf) = simplex.poll_mut(cx)?
&& let write = state.tls_client.write_tls(buf)?
&& write > 0
{
progress = true;
simplex.advance_mut(write);
}
// buf -> server_socket
if let Poll::Ready(read) = state
.server_to_tls_client_buf
.poll_read_to(cx, state.server_socket.as_mut())?
&& read > 0
{
progress = true;
}
Ok(progress)
}
fn io_client_verifier(
state: &mut ConnectedProj<S, T>,
cx: &mut Context,
) -> Result<bool, ProverError> {
let mut progress = false;
let verifier_io = Pin::new(
(*state.verifier_io)
.as_mut()
.expect("verifier io should be available"),
);
// mux -> verifier_socket
if let Poll::Ready(read) = verifier_io.poll_read_to(cx, state.verifier_socket.as_mut())?
&& read > 0
{
progress = true;
}
// verifier_socket -> mux
if let Poll::Ready(write) =
verifier_io.poll_write_from(cx, state.verifier_socket.as_mut())?
&& write > 0
{
progress = true;
}
Ok(progress)
}
/// Returns a committed prover after the TLS session has completed.
pub fn finish(self) -> Result<Prover<state::Committed>, ProverError> {
let TlsOutput {
ctx,
vm,
keys,
tls_transcript,
transcript,
} = self.state.output.ok_or(ProverError::state(
"prover has not yet closed the connection",
))?;
let prover = Prover {
config: self.config,
span: self.span,
state: state::Committed {
verifier_io: self.state.verifier_io,
mux_ctrl: self.state.mux_ctrl,
mux_fut: self.state.mux_fut,
ctx,
vm,
server_name: self.state.server_name,
keys,
tls_transcript,
transcript,
Ok((
conn,
ProverFuture {
fut,
ctrl: ProverControl { mpc_ctrl },
},
};
Ok(prover)
))
}
}
@@ -529,30 +310,8 @@ impl Prover<state::Committed> {
/// # Arguments
///
/// * `config` - The disclosure configuration.
/// * `verifier_io` - The IO to the TLS verifier.
pub async fn prove<S>(
&mut self,
config: &ProveConfig,
verifier_io: S,
) -> Result<ProverOutput, ProverError>
where
S: AsyncRead + AsyncWrite + Send + Unpin,
{
let mut duplex = self
.state
.verifier_io
.take()
.expect("duplex should be available");
let fut = Box::pin(self.prove_inner(config).fuse());
let output = await_with_copy_io(fut, verifier_io, &mut duplex).await?;
self.state.verifier_io = Some(duplex);
Ok(output)
}
#[instrument(parent = &self.span, level = "info", skip_all, err)]
async fn prove_inner(&mut self, config: &ProveConfig) -> Result<ProverOutput, ProverError> {
pub async fn prove(&mut self, config: &ProveConfig) -> Result<ProverOutput, ProverError> {
let state::Committed {
mux_fut,
ctx,
@@ -605,31 +364,44 @@ impl Prover<state::Committed> {
}
/// Closes the connection with the verifier.
///
/// # Arguments
///
/// * `verifier_io` - The IO to the TLS verifier.
#[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn close<S>(mut self, mut verifier_io: S) -> Result<(), ProverError>
where
S: AsyncRead + AsyncWrite + Send + Unpin,
{
pub async fn close(self) -> Result<(), ProverError> {
let state::Committed {
mux_ctrl, mux_fut, ..
} = self.state;
let mut duplex = self
.state
.verifier_io
.take()
.expect("duplex should be available");
// Wait for the verifier to correctly close the connection.
if !mux_fut.is_complete() {
mux_ctrl.close();
mux_fut.await?;
}
mux_ctrl.close();
let copy = CopyIo::new(&mut verifier_io, &mut duplex).map_err(ProverError::from);
futures::try_join!(mux_fut.map_err(ProverError::from), copy)?;
// Wait for the verifier to finish closing.
verifier_io.read_exact(&mut [0_u8; 5]).await?;
Ok(())
}
}
/// A controller for the prover.
#[derive(Clone)]
pub struct ProverControl {
mpc_ctrl: LeaderCtrl,
}
impl ProverControl {
/// Defers decryption of data from the server until the server has closed
/// the connection.
///
/// This is a performance optimization which will significantly reduce the
/// amount of upload bandwidth used by the prover.
///
/// # Notes
///
/// * The prover may need to close the connection to the server in order for
/// it to close the connection on its end. If neither the prover or server
/// close the connection this will cause a deadlock.
pub async fn defer_decryption(&self) -> Result<(), ProverError> {
self.mpc_ctrl
.defer_decryption()
.await
.map_err(ProverError::from)
}
}

View File

@@ -1,93 +0,0 @@
//! Provides a TLS client.
use crate::{mpz::ProverZk, prover::control::ControlError};
use mpc_tls::SessionKeys;
use std::{
sync::mpsc::{Sender, SyncSender, sync_channel},
task::{Context, Poll},
};
use tlsn_core::transcript::{TlsTranscript, Transcript};
mod mpc;
pub(crate) use mpc::MpcTlsClient;
/// TLS client for MPC and proxy-based TLS implementations.
pub(crate) trait TlsClient {
type Error: std::error::Error + Send + Sync + Unpin + 'static;
/// Returns `true` if the client wants to read TLS data from the server.
fn wants_read_tls(&self) -> bool;
/// Returns `true` if the client wants to write TLS data to the server.
fn wants_write_tls(&self) -> bool;
/// Reads TLS data from the server.
fn read_tls(&mut self, buf: &[u8]) -> Result<usize, Self::Error>;
/// Writes TLS data for the server into the provided buffer.
fn write_tls(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error>;
/// Returns `true` if the client wants to read plaintext data.
fn wants_read(&self) -> bool;
/// Returns `true` if the client wants to write plaintext data.
fn wants_write(&self) -> bool;
/// Reads plaintext data from the server into the provided buffer.
fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error>;
/// Writes plaintext data to be sent to the server.
fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error>;
/// Client closes the connection.
fn client_close(&mut self) -> Result<(), Self::Error>;
/// Server closes the connection.
fn server_close(&mut self) -> Result<(), Self::Error>;
/// Returns a handle to control the client.
fn handle(&self) -> ClientHandle;
/// Polls the client to make progress.
fn poll(&mut self, cx: &mut Context) -> Poll<Result<TlsOutput, Self::Error>>;
}
#[derive(Clone, Debug)]
pub(crate) struct ClientHandle {
sender: Sender<Command>,
}
#[derive(Clone, Debug)]
pub(crate) enum Command {
IsDecrypting(SyncSender<bool>),
SetDecrypt(bool),
ClientClose,
ServerClose,
}
impl ClientHandle {
pub(crate) fn enable_decryption(&self, enable: bool) -> Result<(), ControlError> {
self.sender
.send(Command::SetDecrypt(enable))
.map_err(|_| ControlError)
}
pub(crate) fn is_decrypting(&self) -> bool {
let (sender, receiver) = sync_channel(1);
let Ok(_) = self.sender.send(Command::IsDecrypting(sender)) else {
return false;
};
receiver.recv().unwrap_or(false)
}
}
/// Output of a TLS session.
pub(crate) struct TlsOutput {
pub(crate) ctx: mpz_common::Context,
pub(crate) vm: ProverZk,
pub(crate) keys: SessionKeys,
pub(crate) tls_transcript: TlsTranscript,
pub(crate) transcript: Transcript,
}

View File

@@ -1,503 +0,0 @@
//! Implementation of an MPC-TLS client.
use crate::{
mpz::{ProverMpc, ProverZk},
prover::{
ProverError,
client::{ClientHandle, Command, TlsClient, TlsOutput},
},
tag::verify_tags,
};
use futures::{Future, FutureExt};
use mpc_tls::{LeaderCtrl, SessionKeys};
use mpz_common::Context;
use mpz_vm_core::Execute;
use std::{
pin::Pin,
sync::{
Arc,
mpsc::{Receiver, Sender, channel},
},
task::Poll,
};
use tls_client::ClientConnection;
use tlsn_core::transcript::TlsTranscript;
use tlsn_deap::Deap;
use tokio::sync::Mutex;
use tracing::{Span, debug, instrument, trace, warn};
pub(crate) type MpcFuture =
Box<dyn Future<Output = Result<(Context, TlsTranscript), ProverError>> + Send>;
type FinalizeFuture =
Box<dyn Future<Output = Result<(InnerState, Context, TlsTranscript), ProverError>> + Send>;
pub(crate) struct MpcTlsClient {
sender: Sender<Command>,
state: State,
decrypt: bool,
}
enum State {
Start {
mpc: Pin<MpcFuture>,
inner: Box<InnerState>,
receiver: Receiver<Command>,
},
Active {
mpc: Pin<MpcFuture>,
inner: Box<InnerState>,
receiver: Receiver<Command>,
},
Busy {
mpc: Pin<MpcFuture>,
fut: Pin<Box<dyn Future<Output = Result<Box<InnerState>, ProverError>> + Send>>,
receiver: Receiver<Command>,
},
MpcStop {
mpc: Pin<MpcFuture>,
inner: Box<InnerState>,
},
CloseBusy {
mpc: Pin<MpcFuture>,
fut: Pin<Box<dyn Future<Output = Result<Box<InnerState>, ProverError>> + Send>>,
},
Finishing {
ctx: Context,
transcript: Box<TlsTranscript>,
fut: Pin<Box<dyn Future<Output = Result<Box<InnerState>, ProverError>> + Send>>,
},
Finalizing {
fut: Pin<FinalizeFuture>,
},
Finished,
Error,
}
impl MpcTlsClient {
pub(crate) fn new(
mpc: MpcFuture,
keys: SessionKeys,
vm: Arc<Mutex<Deap<ProverMpc, ProverZk>>>,
span: Span,
mpc_ctrl: LeaderCtrl,
tls: ClientConnection,
decrypt: bool,
) -> Self {
let inner = InnerState {
span,
tls,
vm,
keys,
mpc_ctrl,
client_closed: false,
mpc_stopped: false,
};
let (sender, receiver) = channel();
Self {
sender,
decrypt,
state: State::Start {
receiver,
mpc: Box::into_pin(mpc),
inner: Box::new(inner),
},
}
}
fn inner_client_mut(&mut self) -> Option<&mut ClientConnection> {
if let State::Active { inner, .. } | State::MpcStop { inner, .. } = &mut self.state {
Some(&mut inner.tls)
} else {
None
}
}
fn inner_client(&self) -> Option<&ClientConnection> {
if let State::Active { inner, .. } | State::MpcStop { inner, .. } = &self.state {
Some(&inner.tls)
} else {
None
}
}
}
impl TlsClient for MpcTlsClient {
type Error = ProverError;
fn wants_read_tls(&self) -> bool {
if let Some(client) = self.inner_client() {
client.wants_read()
} else {
false
}
}
fn wants_write_tls(&self) -> bool {
if let Some(client) = self.inner_client() {
client.wants_write()
} else {
false
}
}
fn read_tls(&mut self, mut buf: &[u8]) -> Result<usize, Self::Error> {
if let Some(client) = self.inner_client_mut()
&& client.wants_read()
{
client.read_tls(&mut buf).map_err(ProverError::from)
} else {
Ok(0)
}
}
fn write_tls(&mut self, mut buf: &mut [u8]) -> Result<usize, Self::Error> {
if let Some(client) = self.inner_client_mut()
&& client.wants_write()
{
client.write_tls(&mut buf).map_err(ProverError::from)
} else {
Ok(0)
}
}
fn wants_read(&self) -> bool {
if let Some(client) = self.inner_client() {
!client.plaintext_is_empty()
} else {
false
}
}
fn wants_write(&self) -> bool {
if let Some(client) = self.inner_client() {
!client.sendable_plaintext_is_full()
} else {
false
}
}
fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
if let Some(client) = self.inner_client_mut()
&& !client.plaintext_is_empty()
{
client.read_plaintext(buf).map_err(ProverError::from)
} else {
Ok(0)
}
}
fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
if let Some(client) = self.inner_client_mut()
&& !client.sendable_plaintext_is_full()
{
client.write_plaintext(buf).map_err(ProverError::from)
} else {
Ok(0)
}
}
fn client_close(&mut self) -> Result<(), Self::Error> {
self.sender
.send(Command::ClientClose)
.map_err(|_| ProverError::state("unable to close connection clientside"))
}
fn server_close(&mut self) -> Result<(), Self::Error> {
self.sender
.send(Command::ServerClose)
.map_err(|_| ProverError::state("unable to close connection serverside"))
}
fn handle(&self) -> ClientHandle {
ClientHandle {
sender: self.sender.clone(),
}
}
fn poll(&mut self, cx: &mut std::task::Context) -> Poll<Result<TlsOutput, Self::Error>> {
match std::mem::replace(&mut self.state, State::Error) {
State::Start {
mpc,
inner,
receiver,
} => {
trace!("inner client is starting");
self.state = State::Busy {
mpc,
fut: Box::pin(inner.start()),
receiver,
};
self.poll(cx)
}
State::Active {
mpc,
inner,
receiver,
} => {
trace!("inner client is active");
if !inner.tls.is_handshaking()
&& let Ok(cmd) = receiver.try_recv()
{
match cmd {
Command::ClientClose => {
self.state = State::Busy {
mpc,
fut: Box::pin(inner.client_close()),
receiver,
};
}
Command::ServerClose => {
std::mem::drop(receiver);
self.state = State::CloseBusy {
mpc,
fut: Box::pin(inner.server_close()),
};
}
Command::SetDecrypt(enable) => {
self.decrypt = enable;
self.state = State::Busy {
mpc,
fut: Box::pin(inner.set_decrypt(enable)),
receiver,
};
}
Command::IsDecrypting(sender) => {
_ = sender.send(self.decrypt);
self.state = State::Busy {
mpc,
fut: Box::pin(inner.run()),
receiver,
};
}
}
} else {
self.state = State::Busy {
mpc,
fut: Box::pin(inner.run()),
receiver,
};
}
self.poll(cx)
}
State::Busy {
mut mpc,
mut fut,
receiver,
} => {
trace!("inner client is busy");
let mpc_poll = mpc.as_mut().poll(cx)?;
assert!(
matches!(mpc_poll, Poll::Pending),
"mpc future should not be finished here"
);
match fut.as_mut().poll(cx)? {
Poll::Ready(inner) => {
self.state = State::Active {
mpc,
inner,
receiver,
};
}
Poll::Pending => self.state = State::Busy { mpc, fut, receiver },
}
Poll::Pending
}
State::MpcStop { mpc, inner } => {
trace!("inner client is stopping mpc");
self.state = State::CloseBusy {
mpc,
fut: Box::pin(inner.stop()),
};
self.poll(cx)
}
State::CloseBusy { mut mpc, mut fut } => {
trace!("inner client is busy closing");
match (fut.poll_unpin(cx)?, mpc.poll_unpin(cx)?) {
(Poll::Ready(inner), Poll::Ready((ctx, transcript))) => {
self.state = State::Finalizing {
fut: Box::pin(inner.finalize(ctx, transcript)),
};
self.poll(cx)
}
(Poll::Ready(inner), Poll::Pending) => {
self.state = State::MpcStop { mpc, inner };
Poll::Pending
}
(Poll::Pending, Poll::Ready((ctx, transcript))) => {
self.state = State::Finishing {
ctx,
transcript: Box::new(transcript),
fut,
};
Poll::Pending
}
(Poll::Pending, Poll::Pending) => {
self.state = State::CloseBusy { mpc, fut };
Poll::Pending
}
}
}
State::Finishing {
ctx,
transcript,
mut fut,
} => {
trace!("inner client is finishing");
if let Poll::Ready(inner) = fut.poll_unpin(cx)? {
self.state = State::Finalizing {
fut: Box::pin(inner.finalize(ctx, *transcript)),
};
self.poll(cx)
} else {
self.state = State::Finishing {
ctx,
transcript,
fut,
};
Poll::Pending
}
}
State::Finalizing { mut fut } => match fut.poll_unpin(cx) {
Poll::Ready(output) => {
let (inner, ctx, tls_transcript) = output?;
let InnerState { vm, keys, .. } = inner;
let transcript = tls_transcript
.to_transcript()
.expect("transcript is complete");
let (_, vm) = Arc::into_inner(vm)
.expect("vm should have only 1 reference")
.into_inner()
.into_inner();
let output = TlsOutput {
ctx,
vm,
keys,
tls_transcript,
transcript,
};
self.state = State::Finished;
Poll::Ready(Ok(output))
}
Poll::Pending => {
self.state = State::Finalizing { fut };
Poll::Pending
}
},
State::Finished => Poll::Ready(Err(ProverError::state(
"mpc tls client polled again in finished state",
))),
State::Error => {
Poll::Ready(Err(ProverError::state("mpc tls client is in error state")))
}
}
}
}
struct InnerState {
span: Span,
tls: ClientConnection,
vm: Arc<Mutex<Deap<ProverMpc, ProverZk>>>,
keys: SessionKeys,
mpc_ctrl: LeaderCtrl,
client_closed: bool,
mpc_stopped: bool,
}
impl InnerState {
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
async fn start(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
self.tls.start().await?;
Ok(self)
}
#[instrument(parent = &self.span, level = "trace", skip_all, err)]
async fn run(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
self.tls.process_new_packets().await?;
Ok(self)
}
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
async fn set_decrypt(self: Box<Self>, enable: bool) -> Result<Box<Self>, ProverError> {
self.mpc_ctrl.enable_decryption(enable).await?;
self.run().await
}
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
async fn client_close(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
if !self.client_closed {
debug!("sending close notify");
if let Err(e) = self.tls.send_close_notify().await {
warn!("failed to send close_notify to server: {}", e);
}
self.client_closed = true;
}
self.run().await
}
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
async fn server_close(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
self.tls.process_new_packets().await?;
self.tls.server_closed().await?;
debug!("closed connection serverside");
Ok(self)
}
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
async fn stop(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
self.tls.process_new_packets().await?;
if !self.mpc_stopped && self.tls.plaintext_is_empty() && self.tls.is_empty().await? {
self.mpc_ctrl.stop().await?;
self.mpc_stopped = true;
debug!("stopped mpc");
}
Ok(self)
}
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
async fn finalize(
self,
mut ctx: Context,
transcript: TlsTranscript,
) -> Result<(Self, Context, TlsTranscript), ProverError> {
{
let mut vm = self.vm.try_lock().expect("VM should not be locked");
// Finalize DEAP.
vm.finalize(&mut ctx).await.map_err(ProverError::mpc)?;
debug!("mpc finalized");
// Pull out ZK VM.
let mut zk = vm.zk();
// Prove tag verification of received records.
// The prover drops the proof output.
let _ = verify_tags(
&mut *zk,
(self.keys.server_write_key, self.keys.server_write_iv),
self.keys.server_write_mac_key,
*transcript.version(),
transcript.recv().to_vec(),
)
.map_err(ProverError::zk)?;
debug!("verified tags from server");
zk.execute_all(&mut ctx).await.map_err(ProverError::zk)?
}
debug!("MPC-TLS done");
Ok((self, ctx, transcript))
}
}

View File

@@ -1,66 +0,0 @@
use futures::{AsyncRead, AsyncWrite, AsyncWriteExt};
use futures_plex::DuplexStream;
use std::{
pin::Pin,
task::{Context, Poll},
};
/// A TLS connection to a server.
///
/// This type implements [`AsyncRead`] and [`AsyncWrite`] and can be used to
/// communicate with a server using TLS.
///
/// # Note
///
/// This connection is closed on a best-effort basis if this is dropped. To
/// ensure a clean close, you should call
/// [`AsyncWriteExt::close`](futures::io::AsyncWriteExt::close) to close the
/// connection.
pub struct TlsConnection {
duplex: DuplexStream,
}
impl TlsConnection {
pub(crate) fn new(duplex: DuplexStream) -> Self {
Self { duplex }
}
}
impl Drop for TlsConnection {
fn drop(&mut self) {
if let Err(err) = futures::executor::block_on(self.duplex.close()) {
tracing::error!("error closing connection: {}", err);
}
}
}
impl AsyncRead for TlsConnection {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
let duplex = Pin::new(&mut self.duplex);
duplex.poll_read(cx, buf)
}
}
impl AsyncWrite for TlsConnection {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
let duplex = Pin::new(&mut self.duplex);
duplex.poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
let duplex = Pin::new(&mut self.duplex);
duplex.poll_close(cx)
}
}

View File

@@ -1,29 +0,0 @@
use crate::prover::client::ClientHandle;
/// A controller for the prover.
///
/// Can be used to control the decryption of server traffic.
#[derive(Clone, Debug)]
pub struct ProverControl {
pub(crate) handle: ClientHandle,
}
impl ProverControl {
/// Returns whether the prover is decrypting the server traffic.
pub fn is_decrypting(&self) -> bool {
self.handle.is_decrypting()
}
/// Enables or disables the decryption of server traffic.
///
/// # Arguments
///
/// * `enable` - If decryption should be enabled or disabled.
pub fn enable_decryption(&self, enable: bool) -> Result<(), ControlError> {
self.handle.enable_decryption(enable)
}
}
#[derive(Debug, thiserror::Error)]
#[error("Unable to send control command to prover.")]
pub struct ControlError;

View File

@@ -49,13 +49,6 @@ impl ProverError {
{
Self::new(ErrorKind::Commit, source)
}
pub(crate) fn state<E>(source: E) -> Self
where
E: Into<Box<dyn Error + Send + Sync + 'static>>,
{
Self::new(ErrorKind::State, source)
}
}
#[derive(Debug)]
@@ -65,7 +58,6 @@ enum ErrorKind {
Zk,
Config,
Commit,
State,
}
impl fmt::Display for ProverError {
@@ -78,7 +70,6 @@ impl fmt::Display for ProverError {
ErrorKind::Zk => f.write_str("zk error")?,
ErrorKind::Config => f.write_str("config error")?,
ErrorKind::Commit => f.write_str("commit error")?,
ErrorKind::State => f.write_str("state error")?,
}
if let Some(source) = &self.source {
@@ -95,8 +86,8 @@ impl From<std::io::Error> for ProverError {
}
}
impl From<tls_client::Error> for ProverError {
fn from(e: tls_client::Error) -> Self {
impl From<tls_client_async::ConnectionError> for ProverError {
fn from(e: tls_client_async::ConnectionError) -> Self {
Self::new(ErrorKind::Io, e)
}
}

View File

@@ -0,0 +1,32 @@
//! This module collects futures which are used by the [Prover].
use super::{Prover, ProverControl, ProverError, state};
use futures::Future;
use std::pin::Pin;
/// Prover future which must be polled for the TLS connection to make progress.
pub struct ProverFuture {
#[allow(clippy::type_complexity)]
pub(crate) fut: Pin<
Box<dyn Future<Output = Result<Prover<state::Committed>, ProverError>> + Send + 'static>,
>,
pub(crate) ctrl: ProverControl,
}
impl ProverFuture {
/// Returns a controller for the prover for advanced functionality.
pub fn control(&self) -> ProverControl {
self.ctrl.clone()
}
}
impl Future for ProverFuture {
type Output = Result<Prover<state::Committed>, ProverError>;
fn poll(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
self.fut.as_mut().poll(cx)
}
}

View File

@@ -2,7 +2,6 @@
use std::sync::Arc;
use futures_plex::DuplexStream;
use mpc_tls::{MpcTlsLeader, SessionKeys};
use mpz_common::Context;
use tlsn_core::{
@@ -15,10 +14,6 @@ use tokio::sync::Mutex;
use crate::{
mpz::{ProverMpc, ProverZk},
mux::{MuxControl, MuxFuture},
prover::{
ProverError,
client::{TlsClient, TlsOutput},
},
};
/// Entry state
@@ -29,7 +24,6 @@ opaque_debug::implement!(Initialized);
/// State after the verifier has accepted the proposed TLS commitment protocol
/// configuration and preprocessing has completed.
pub struct CommitAccepted {
pub(crate) verifier_io: Option<DuplexStream>,
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) mpc_tls: MpcTlsLeader,
@@ -39,49 +33,8 @@ pub struct CommitAccepted {
opaque_debug::implement!(CommitAccepted);
/// State when the TLS client has been setup.
pub struct Setup {
pub(crate) verifier_io: Option<DuplexStream>,
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) server_name: ServerName,
pub(crate) tls_client: Box<dyn TlsClient<Error = ProverError> + Send>,
pub(crate) client_io: DuplexStream,
}
opaque_debug::implement!(Setup);
pin_project_lite::pin_project! {
/// State during the MPC-TLS connection.
#[project = ConnectedProj]
pub struct Connected<S, T> {
#[pin]
pub(crate) verifier_io: Option<DuplexStream>,
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) server_name: ServerName,
pub(crate) tls_client: Box<dyn TlsClient<Error = ProverError> + Send>,
#[pin]
pub(crate) client_io: DuplexStream,
pub(crate) output: Option<TlsOutput>,
#[pin]
pub(crate) server_socket: S,
#[pin]
pub(crate) verifier_socket: T,
#[pin]
pub(crate) tls_client_to_server_buf: DuplexStream,
#[pin]
pub(crate) server_to_tls_client_buf: DuplexStream,
pub(crate) client_closed: bool,
pub(crate) server_closed: bool
}
}
opaque_debug::implement!(Connected<S, T>);
/// State after the TLS transcript has been committed.
pub struct Committed {
pub(crate) verifier_io: Option<DuplexStream>,
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) ctx: Context,
@@ -99,15 +52,11 @@ pub trait ProverState: sealed::Sealed {}
impl ProverState for Initialized {}
impl ProverState for CommitAccepted {}
impl ProverState for Setup {}
impl<S, T> ProverState for Connected<S, T> {}
impl ProverState for Committed {}
mod sealed {
pub trait Sealed {}
impl Sealed for super::Initialized {}
impl Sealed for super::CommitAccepted {}
impl Sealed for super::Setup {}
impl<S, T> Sealed for super::Connected<S, T> {}
impl Sealed for super::Committed {}
}

View File

@@ -1,143 +0,0 @@
//! Execution context.
use std::{
io::ErrorKind,
pin::Pin,
task::{Context, Poll},
};
use futures::{AsyncRead, AsyncWrite, future::FusedFuture};
use futures_plex::DuplexStream;
use mpz_common::context::Multithread;
use crate::mux::MuxControl;
/// Maximum concurrency for multi-threaded context.
pub(crate) const MAX_CONCURRENCY: usize = 8;
/// Builds a multi-threaded context with the given muxer.
pub(crate) fn build_mt_context(mux: MuxControl) -> Multithread {
let builder = Multithread::builder().mux(mux).concurrency(MAX_CONCURRENCY);
#[cfg(all(feature = "web", target_arch = "wasm32"))]
let builder = builder.spawn_handler(|f| {
let _ = web_spawn::spawn(f);
Ok(())
});
builder.build().unwrap()
}
/// Polls the future while copying bytes between two duplex streams.
///
/// Returns as soon as the future is ready, without closing IO.
pub(crate) async fn await_with_copy_io<'a, S, T>(
mut fut: Pin<Box<dyn FusedFuture<Output = T> + Send + 'a>>,
io: S,
duplex: &mut DuplexStream,
) -> T
where
S: AsyncRead + AsyncWrite + Send + Unpin,
{
let mut copy = CopyIo::new(io, duplex);
loop {
futures::select! {
_ = copy => (),
output = fut => break output
}
}
}
pin_project_lite::pin_project! {
#[derive(Debug)]
pub(crate) struct CopyIo<'a, S> {
#[pin]
io: S,
#[pin]
duplex: &'a mut DuplexStream,
io_done: bool,
duplex_done: bool,
}
}
impl<'a, S> CopyIo<'a, S> {
pub(crate) fn new(io: S, duplex: &'a mut DuplexStream) -> Self {
Self {
io,
duplex,
io_done: false,
duplex_done: false,
}
}
}
impl<'a, S> Future for CopyIo<'a, S>
where
S: AsyncRead + AsyncWrite + Send + Unpin,
{
type Output = std::io::Result<()>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.project();
loop {
let mut is_pending = true;
if !*this.duplex_done {
match this.duplex.poll_read_to(cx, this.io.as_mut()) {
Poll::Ready(Ok(read)) if read > 0 => is_pending = false,
Poll::Ready(Ok(_)) => {
is_pending = false;
*this.duplex_done = true;
}
Poll::Ready(Err(err))
if err.kind() == ErrorKind::BrokenPipe
|| err.kind() == ErrorKind::ConnectionReset
|| err.kind() == ErrorKind::NotConnected =>
{
is_pending = false;
*this.duplex_done = true;
}
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => (),
}
}
if !*this.io_done {
match this.duplex.poll_write_from(cx, this.io.as_mut()) {
Poll::Ready(Ok(write)) if write > 0 => is_pending = false,
Poll::Ready(Ok(_)) => {
is_pending = false;
*this.io_done = true;
}
Poll::Ready(Err(err))
if err.kind() == ErrorKind::BrokenPipe
|| err.kind() == ErrorKind::ConnectionReset
|| err.kind() == ErrorKind::NotConnected =>
{
is_pending = false;
*this.io_done = true
}
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => (),
}
}
if *this.io_done || *this.duplex_done {
return Poll::Ready(Ok(()));
} else if is_pending {
return Poll::Pending;
}
}
}
}
impl<'a, S> FusedFuture for CopyIo<'a, S>
where
S: AsyncRead + AsyncWrite + Send + Unpin,
{
fn is_terminated(&self) -> bool {
self.duplex_done || self.io_done
}
}

View File

@@ -10,14 +10,14 @@ pub use error::VerifierError;
pub use tlsn_core::{VerifierOutput, webpki::ServerCertVerifier};
use crate::{
BUF_CAP, Role,
Role,
context::build_mt_context,
mpz::{VerifierDeps, build_verifier_deps, translate_keys},
msg::{ProveRequestMsg, Response, TlsCommitRequestMsg},
mux::attach_mux,
tag::verify_tags,
utils::{CopyIo, await_with_copy_io, build_mt_context},
};
use futures::{AsyncRead, AsyncWrite, AsyncWriteExt, FutureExt, TryFutureExt};
use futures::{AsyncRead, AsyncWrite, TryFutureExt};
use mpz_vm_core::prelude::*;
use serio::{SinkExt, stream::IoStreamExt};
use tlsn_core::{
@@ -66,73 +66,48 @@ impl Verifier<state::Initialized> {
///
/// # Arguments
///
/// * `prover_io` - The IO to the prover.
/// * `socket` - The socket to the prover.
#[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn commit<S: AsyncWrite + AsyncRead + Send + Unpin>(
pub async fn commit<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
self,
mut prover_io: S,
socket: S,
) -> Result<Verifier<state::CommitStart>, VerifierError> {
let (duplex_a, mut duplex_b) = futures_plex::duplex(BUF_CAP);
let (mut mux_fut, mux_ctrl) = attach_mux(duplex_a, Role::Verifier);
let (mut mux_fut, mux_ctrl) = attach_mux(socket, Role::Verifier);
let mut mt = build_mt_context(mux_ctrl.clone());
let mut ctx = mux_fut.poll_with(mt.new_context()).await?;
let fut = Box::pin(
async {
let mut ctx = mux_fut.poll_with(mt.new_context()).await?;
// Receives protocol configuration from prover to perform compatibility check.
let TlsCommitRequestMsg { request, version } =
mux_fut.poll_with(ctx.io_mut().expect_next()).await?;
Ok::<_, VerifierError>((request, version, ctx))
}
.fuse(),
);
let (request, version, mut ctx) =
await_with_copy_io(fut, &mut prover_io, &mut duplex_b).await?;
// Receives protocol configuration from prover to perform compatibility check.
let TlsCommitRequestMsg { request, version } =
mux_fut.poll_with(ctx.io_mut().expect_next()).await?;
if version != *crate::VERSION {
let msg = format!(
"prover version does not match with verifier: {version} != {}",
*crate::VERSION
);
mux_fut
.poll_with(ctx.io_mut().send(Response::err(Some(msg.clone()))))
.await?;
let fut = Box::pin(
async {
mux_fut
.poll_with(ctx.io_mut().send(Response::err(Some(msg.clone()))))
.await?;
// Wait for the prover to correctly close the connection.
if !mux_fut.is_complete() {
mux_ctrl.close();
mux_fut.await?;
}
// Wait for the prover to correctly close the connection.
if !mux_fut.is_complete() {
mux_ctrl.close();
mux_fut.await?;
}
Err(VerifierError::config(msg))
}
.fuse(),
);
let copy = CopyIo::new(prover_io, &mut duplex_b).map_err(VerifierError::from);
let (config_err, _) = futures::try_join!(fut, copy)?;
return Err(config_err);
return Err(VerifierError::config(msg));
}
let verifier = Verifier {
Ok(Verifier {
config: self.config,
span: self.span,
state: state::CommitStart {
prover_io: Some(duplex_b),
mux_ctrl,
mux_fut,
ctx,
request,
},
};
Ok(verifier)
})
}
}
@@ -143,36 +118,13 @@ impl Verifier<state::CommitStart> {
}
/// Accepts the proposed protocol configuration.
///
/// # Arguments
///
/// * `prover_io` - The IO to the prover.
pub async fn accept<S: AsyncWrite + AsyncRead + Send + Unpin>(
mut self,
prover_io: S,
) -> Result<Verifier<state::CommitAccepted>, VerifierError> {
let mut duplex = self
.state
.prover_io
.take()
.expect("duplex should be available");
let fut = Box::pin(self.accept_inner().fuse());
let mut verifier = await_with_copy_io(fut, prover_io, &mut duplex).await?;
verifier.state.prover_io = Some(duplex);
Ok(verifier)
}
#[instrument(parent = &self.span, level = "info", skip_all, err)]
async fn accept_inner(self) -> Result<Verifier<state::CommitAccepted>, VerifierError> {
pub async fn accept(self) -> Result<Verifier<state::CommitAccepted>, VerifierError> {
let state::CommitStart {
prover_io,
mux_ctrl,
mut mux_fut,
mut ctx,
request,
..
} = self.state;
mux_fut.poll_with(ctx.io_mut().send(Response::ok())).await?;
@@ -199,7 +151,6 @@ impl Verifier<state::CommitStart> {
config: self.config,
span: self.span,
state: state::CommitAccepted {
prover_io,
mux_ctrl,
mux_fut,
mpc_tls,
@@ -210,31 +161,8 @@ impl Verifier<state::CommitStart> {
}
/// Rejects the proposed protocol configuration.
///
/// # Arguments
///
/// * `prover_io` - The IO to the prover.
/// * `msg` - The optional rejection message.
pub async fn reject<S: AsyncWrite + AsyncRead + Send + Unpin>(
mut self,
prover_io: S,
msg: Option<&str>,
) -> Result<(), VerifierError> {
let mut duplex = self
.state
.prover_io
.take()
.expect("duplex should be available");
let fut = self.reject_inner(msg);
let copy = CopyIo::new(prover_io, &mut duplex).map_err(VerifierError::from);
futures::try_join!(fut, copy)?;
Ok(())
}
#[instrument(parent = &self.span, level = "info", skip_all, err)]
async fn reject_inner(self, msg: Option<&str>) -> Result<(), VerifierError> {
pub async fn reject(self, msg: Option<&str>) -> Result<(), VerifierError> {
let state::CommitStart {
mux_ctrl,
mut mux_fut,
@@ -258,31 +186,9 @@ impl Verifier<state::CommitStart> {
impl Verifier<state::CommitAccepted> {
/// Runs the verifier until the TLS connection is closed.
///
/// # Arguments
///
/// * `prover_io` - The IO to the prover.
pub async fn run<S: AsyncWrite + AsyncRead + Send + Unpin>(
mut self,
prover_io: S,
) -> Result<Verifier<state::Committed>, VerifierError> {
let mut duplex = self
.state
.prover_io
.take()
.expect("duplex should be available");
let fut = Box::pin(self.run_inner().fuse());
let mut verifier = await_with_copy_io(fut, prover_io, &mut duplex).await?;
verifier.state.prover_io = Some(duplex);
Ok(verifier)
}
#[instrument(parent = &self.span, level = "info", skip_all, err)]
async fn run_inner(self) -> Result<Verifier<state::Committed>, VerifierError> {
pub async fn run(self) -> Result<Verifier<state::Committed>, VerifierError> {
let state::CommitAccepted {
prover_io,
mux_ctrl,
mut mux_fut,
mpc_tls,
@@ -326,7 +232,6 @@ impl Verifier<state::CommitAccepted> {
)
.map_err(VerifierError::zk)?;
debug!("verifying tags");
mux_fut
.poll_with(vm.execute_all(&mut ctx).map_err(VerifierError::zk))
.await?;
@@ -336,12 +241,10 @@ impl Verifier<state::CommitAccepted> {
// authenticated from the verifier's perspective.
tag_proof.verify().map_err(VerifierError::zk)?;
debug!("MPC-TLS done");
Ok(Verifier {
config: self.config,
span: self.span,
state: state::Committed {
prover_io,
mux_ctrl,
mux_fut,
ctx,
@@ -360,31 +263,9 @@ impl Verifier<state::Committed> {
}
/// Begins verification of statements from the prover.
///
/// # Arguments
///
/// * `prover_io` - The IO to the prover.
pub async fn verify<S: AsyncWrite + AsyncRead + Send + Unpin>(
mut self,
prover_io: S,
) -> Result<Verifier<state::Verify>, VerifierError> {
let mut duplex = self
.state
.prover_io
.take()
.expect("duplex should be available");
let fut = Box::pin(self.verify_inner().fuse());
let mut verifier = await_with_copy_io(fut, prover_io, &mut duplex).await?;
verifier.state.prover_io = Some(duplex);
Ok(verifier)
}
#[instrument(parent = &self.span, level = "info", skip_all, err)]
async fn verify_inner(self) -> Result<Verifier<state::Verify>, VerifierError> {
pub async fn verify(self) -> Result<Verifier<state::Verify>, VerifierError> {
let state::Committed {
prover_io,
mux_ctrl,
mut mux_fut,
mut ctx,
@@ -405,7 +286,6 @@ impl Verifier<state::Committed> {
config: self.config,
span: self.span,
state: state::Verify {
prover_io,
mux_ctrl,
mux_fut,
ctx,
@@ -420,36 +300,18 @@ impl Verifier<state::Committed> {
}
/// Closes the connection with the prover.
///
/// # Arguments
///
/// * `prover_io` - The IO to the prover.
#[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn close<S: AsyncWrite + AsyncRead + Send + Unpin>(
mut self,
mut prover_io: S,
) -> Result<(), VerifierError> {
let state::Committed { mux_fut, .. } = self.state;
pub async fn close(self) -> Result<(), VerifierError> {
let state::Committed {
mux_ctrl, mux_fut, ..
} = self.state;
let mut duplex = self
.state
.prover_io
.take()
.expect("duplex should be available");
// Wait for the prover to correctly close the connection.
if !mux_fut.is_complete() {
mux_ctrl.close();
mux_fut.await?;
}
duplex.close().await?;
let fut: Box<dyn Future<Output = Result<(), VerifierError>> + Send + Unpin> =
if mux_fut.is_complete() {
Box::new(futures::future::ready(Ok::<_, VerifierError>(())))
} else {
Box::new(mux_fut.map_err(VerifierError::from))
};
let copy = CopyIo::new(&mut prover_io, &mut duplex).map_err(VerifierError::from);
futures::try_join!(fut, copy)?;
prover_io.write_all(b"close").await?;
Ok(())
}
}
@@ -461,32 +323,10 @@ impl Verifier<state::Verify> {
}
/// Accepts the proving request.
///
/// # Arguments
///
/// * `prover_io` - The IO to the prover.
pub async fn accept<S: AsyncWrite + AsyncRead + Send + Unpin>(
mut self,
prover_io: S,
) -> Result<(VerifierOutput, Verifier<state::Committed>), VerifierError> {
let mut duplex = self
.state
.prover_io
.take()
.expect("duplex should be available");
let fut = Box::pin(self.accept_inner().fuse());
let (output, mut verifier) = await_with_copy_io(fut, prover_io, &mut duplex).await?;
verifier.state.prover_io = Some(duplex);
Ok((output, verifier))
}
async fn accept_inner(
pub async fn accept(
self,
) -> Result<(VerifierOutput, Verifier<state::Committed>), VerifierError> {
let state::Verify {
prover_io,
mux_ctrl,
mut mux_fut,
mut ctx,
@@ -522,7 +362,6 @@ impl Verifier<state::Verify> {
config: self.config,
span: self.span,
state: state::Committed {
prover_io,
mux_ctrl,
mux_fut,
ctx,
@@ -535,35 +374,11 @@ impl Verifier<state::Verify> {
}
/// Rejects the proving request.
///
/// # Arguments
///
/// * `prover_io` - The IO to the prover.
/// * `msg` - The optional rejection message.
pub async fn reject<S: AsyncWrite + AsyncRead + Send + Unpin>(
mut self,
prover_io: S,
msg: Option<&str>,
) -> Result<Verifier<state::Committed>, VerifierError> {
let mut duplex = self
.state
.prover_io
.take()
.expect("duplex should be available");
let fut = Box::pin(self.reject_inner(msg).fuse());
let mut verifier = await_with_copy_io(fut, prover_io, &mut duplex).await?;
verifier.state.prover_io = Some(duplex);
Ok(verifier)
}
async fn reject_inner(
pub async fn reject(
self,
msg: Option<&str>,
) -> Result<Verifier<state::Committed>, VerifierError> {
let state::Verify {
prover_io,
mux_ctrl,
mut mux_fut,
mut ctx,
@@ -581,7 +396,6 @@ impl Verifier<state::Verify> {
config: self.config,
span: self.span,
state: state::Committed {
prover_io,
mux_ctrl,
mux_fut,
ctx,

View File

@@ -3,7 +3,6 @@
use std::sync::Arc;
use crate::mux::{MuxControl, MuxFuture};
use futures_plex::DuplexStream;
use mpc_tls::{MpcTlsFollower, SessionKeys};
use mpz_common::Context;
use tlsn_core::{
@@ -26,7 +25,6 @@ opaque_debug::implement!(Initialized);
/// State after receiving protocol configuration from the prover.
pub struct CommitStart {
pub(crate) prover_io: Option<DuplexStream>,
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) ctx: Context,
@@ -38,7 +36,6 @@ opaque_debug::implement!(CommitStart);
/// State after accepting the proposed TLS commitment protocol configuration and
/// performing preprocessing.
pub struct CommitAccepted {
pub(crate) prover_io: Option<DuplexStream>,
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) mpc_tls: MpcTlsFollower,
@@ -50,7 +47,6 @@ opaque_debug::implement!(CommitAccepted);
/// State after the TLS transcript has been committed.
pub struct Committed {
pub(crate) prover_io: Option<DuplexStream>,
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) ctx: Context,
@@ -63,7 +59,6 @@ opaque_debug::implement!(Committed);
/// State after receiving a proving request.
pub struct Verify {
pub(crate) prover_io: Option<DuplexStream>,
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) ctx: Context,

View File

@@ -110,11 +110,7 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
) -> (Transcript, ProverOutput) {
let (client_socket, server_socket) = tokio::io::duplex(2 << 16);
let client_socket = client_socket.compat();
let server_socket = server_socket.compat();
let mut verifier_socket = verifier_socket.compat();
let server_task = tokio::spawn(bind(server_socket));
let server_task = tokio::spawn(bind(server_socket.compat()));
let prover = Prover::new(ProverConfig::builder().build().unwrap())
.commit(
@@ -130,13 +126,13 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
)
.build()
.unwrap(),
&mut verifier_socket,
verifier_socket.compat(),
)
.await
.unwrap();
let (mut tls_connection, prover) = prover
.setup(
let (mut tls_connection, prover_fut) = prover
.connect(
TlsClientConfig::builder()
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into().unwrap()))
.root_store(RootCertStore {
@@ -144,23 +140,24 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
})
.build()
.unwrap(),
client_socket.compat(),
)
.await
.unwrap();
let prover_task = tokio::spawn(prover.run(client_socket, verifier_socket));
let prover_task = tokio::spawn(prover_fut);
tls_connection
.write_all(b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n")
.await
.unwrap();
tls_connection.close().await.unwrap();
let mut response = vec![0u8; 1024];
tls_connection.read_to_end(&mut response).await.unwrap();
tls_connection.close().await.unwrap();
let _ = server_task.await.unwrap();
let (mut prover, _, mut verifier_socket) = prover_task.await.unwrap().unwrap();
let mut prover = prover_task.await.unwrap().unwrap();
let sent_tx_len = prover.transcript().sent().len();
let recv_tx_len = prover.transcript().received().len();
@@ -199,8 +196,8 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
let config = builder.build().unwrap();
let transcript = prover.transcript().clone();
let output = prover.prove(&config, &mut verifier_socket).await.unwrap();
prover.close(&mut verifier_socket).await.unwrap();
let output = prover.prove(&config).await.unwrap();
prover.close().await.unwrap();
(transcript, output)
}
@@ -209,8 +206,6 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
socket: T,
) -> VerifierOutput {
let mut socket = socket.compat();
let verifier = Verifier::new(
VerifierConfig::builder()
.root_store(RootCertStore {
@@ -221,24 +216,18 @@ async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
);
let verifier = verifier
.commit(&mut socket)
.commit(socket.compat())
.await
.unwrap()
.accept(&mut socket)
.accept()
.await
.unwrap()
.run(&mut socket)
.run()
.await
.unwrap();
let (output, verifier) = verifier
.verify(&mut socket)
.await
.unwrap()
.accept(&mut socket)
.await
.unwrap();
verifier.close(&mut socket).await.unwrap();
let (output, verifier) = verifier.verify().await.unwrap().accept().await.unwrap();
verifier.close().await.unwrap();
output
}

View File

@@ -23,9 +23,9 @@ no-bundler = ["web-spawn/no-bundler"]
tlsn-core = { workspace = true }
tlsn = { workspace = true, features = ["web", "mozilla-certs"] }
tlsn-server-fixture-certs = { workspace = true }
tlsn-tls-client-async = { workspace = true }
tlsn-tls-core = { workspace = true }
async_io_stream = { version = "0.3" }
bincode = { workspace = true }
console_error_panic_hook = { version = "0.1" }
enum-try-as-inner = { workspace = true }

View File

@@ -2,11 +2,11 @@ mod config;
pub use config::ProverConfig;
use async_io_stream::IoStream;
use enum_try_as_inner::EnumTryAsInner;
use futures::TryFutureExt;
use http_body_util::{BodyExt, Full};
use hyper::body::Bytes;
use tls_client_async::TlsConnection;
use tlsn::{
config::{
prove::ProveConfig,
@@ -14,13 +14,13 @@ use tlsn::{
tls_commit::{mpc::MpcTlsConfig, TlsCommitConfig},
},
connection::ServerName,
prover::{state, Prover, TlsConnection},
prover::{state, Prover},
webpki::{CertificateDer, PrivateKeyDer, RootCertStore},
};
use tracing::info;
use wasm_bindgen::{prelude::*, JsError};
use wasm_bindgen_futures::spawn_local;
use ws_stream_wasm::{WsMeta, WsStreamIo};
use ws_stream_wasm::WsMeta;
use crate::{io::FuturesIo, types::*};
@@ -36,14 +36,8 @@ pub struct JsProver {
#[derive_err(Debug)]
enum State {
Initialized(Prover<state::Initialized>),
CommitAccepted {
prover: Prover<state::CommitAccepted>,
verifier_conn: IoStream<WsStreamIo, Vec<u8>>,
},
Committed {
prover: Prover<state::Committed>,
verifier_conn: IoStream<WsStreamIo, Vec<u8>>,
},
CommitAccepted(Prover<state::CommitAccepted>),
Committed(Prover<state::Committed>),
Complete,
Error,
}
@@ -102,16 +96,12 @@ impl JsProver {
info!("connecting to verifier");
let (_, verifier_conn) = WsMeta::connect(verifier_url, None).await?;
let mut verifier_conn = verifier_conn.into_io();
info!("connected to verifier");
let prover = prover.commit(config, &mut verifier_conn).await?;
let prover = prover.commit(config, verifier_conn.into_io()).await?;
self.state = State::CommitAccepted {
prover,
verifier_conn,
};
self.state = State::CommitAccepted(prover);
Ok(())
}
@@ -122,7 +112,7 @@ impl JsProver {
ws_proxy_url: &str,
request: HttpRequest,
) -> Result<HttpResponse> {
let (prover, mut verifier_conn) = self.state.take().try_into_commit_accepted()?;
let prover = self.state.take().try_into_commit_accepted()?;
let mut builder = TlsClientConfig::builder()
.server_name(ServerName::Dns(
@@ -155,41 +145,35 @@ impl JsProver {
info!("connecting to server");
let (_, server_conn) = WsMeta::connect(ws_proxy_url, None).await?;
let mut server_conn = server_conn.into_io();
info!("connected to server");
let (tls_conn, prover) = prover.setup(config)?;
let mut prover = prover.connect(&mut server_conn, &mut verifier_conn);
let (tls_conn, prover_fut) = prover.connect(config, server_conn.into_io()).await?;
info!("sending request");
let (response, _) = futures::try_join!(
let (response, prover) = futures::try_join!(
send_request(tls_conn, request),
(&mut prover).map_err(Into::into)
prover_fut.map_err(Into::into)
)?;
let prover = prover.finish()?;
info!("response received");
self.state = State::Committed {
prover,
verifier_conn,
};
self.state = State::Committed(prover);
Ok(response)
}
/// Returns the transcript.
pub fn transcript(&self) -> Result<Transcript> {
let (prover, _) = self.state.try_as_committed()?;
let prover = self.state.try_as_committed()?;
Ok(Transcript::from(prover.transcript()))
}
/// Reveals data to the verifier and finalizes the protocol.
pub async fn reveal(&mut self, reveal: Reveal) -> Result<()> {
let (mut prover, mut verifier_conn) = self.state.take().try_into_committed()?;
let mut prover = self.state.take().try_into_committed()?;
info!("revealing data");
@@ -209,8 +193,8 @@ impl JsProver {
let config = builder.build()?;
prover.prove(&config, &mut verifier_conn).await?;
prover.close(&mut verifier_conn).await?;
prover.prove(&config).await?;
prover.close().await?;
info!("Finalized");

View File

@@ -2,7 +2,6 @@ mod config;
pub use config::VerifierConfig;
use async_io_stream::IoStream;
use enum_try_as_inner::EnumTryAsInner;
use tlsn::{
config::tls_commit::TlsCommitProtocolConfig,
@@ -13,7 +12,7 @@ use tlsn::{
};
use tracing::info;
use wasm_bindgen::prelude::*;
use ws_stream_wasm::{WsMeta, WsStreamIo};
use ws_stream_wasm::{WsMeta, WsStream};
use crate::types::VerifierOutput;
@@ -27,13 +26,9 @@ pub struct JsVerifier {
#[derive(EnumTryAsInner)]
#[derive_err(Debug)]
#[allow(unused_assignments)]
enum State {
Initialized(Verifier<state::Initialized>),
Connected {
verifier: Verifier<state::Initialized>,
prover_conn: IoStream<WsStreamIo, Vec<u8>>,
},
Connected((Verifier<state::Initialized>, WsStream)),
Complete,
Error,
}
@@ -71,23 +66,19 @@ impl JsVerifier {
info!("Connecting to prover");
let (_, prover_conn) = WsMeta::connect(prover_url, None).await?;
let prover_conn = prover_conn.into_io();
info!("Connected to prover");
self.state = State::Connected {
verifier,
prover_conn,
};
self.state = State::Connected((verifier, prover_conn));
Ok(())
}
/// Verifies the connection and finalizes the protocol.
pub async fn verify(&mut self) -> Result<VerifierOutput> {
let (verifier, mut prover_conn) = self.state.take().try_into_connected()?;
let (verifier, prover_conn) = self.state.take().try_into_connected()?;
let verifier = verifier.commit(&mut prover_conn).await?;
let verifier = verifier.commit(prover_conn.into_io()).await?;
let request = verifier.request();
let TlsCommitProtocolConfig::Mpc(mpc_tls_config) = request.protocol() else {
@@ -107,15 +98,11 @@ impl JsVerifier {
};
if reject.is_some() {
verifier.reject(&mut prover_conn, reject).await?;
verifier.reject(reject).await?;
return Err(JsError::new("protocol configuration rejected"));
}
let verifier = verifier
.accept(&mut prover_conn)
.await?
.run(&mut prover_conn)
.await?;
let verifier = verifier.accept().await?.run().await?;
let sent = verifier
.tls_transcript()
@@ -142,12 +129,8 @@ impl JsVerifier {
},
};
let (output, verifier) = verifier
.verify(&mut prover_conn)
.await?
.accept(&mut prover_conn)
.await?;
verifier.close(&mut prover_conn).await?;
let (output, verifier) = verifier.verify().await?.accept().await?;
verifier.close().await?;
self.state = State::Complete;