Compare commits

..

9 Commits

Author SHA1 Message Date
Hendrik Eeckhaut
b76775fc7c correction + legend placement 2025-12-23 15:11:35 +01:00
Hendrik Eeckhaut
72041d1f07 export dark svg 2025-12-23 14:47:19 +01:00
Hendrik Eeckhaut
ac1df8fc75 Allow plotting multiple data runs 2025-12-23 14:31:54 +01:00
Hendrik Eeckhaut
3cb7c5c0b4 Working on benchmark plots 2025-12-23 14:07:39 +01:00
Hendrik Eeckhaut
b41d678829 build: update Rust to version 1.92.0 2025-12-16 09:36:11 +01:00
sinu.eth
1ebefa27d8 perf(core): fold instead of flatten (#1064) 2025-12-11 06:41:26 -08:00
dan
4fe5c1defd feat(harness): add reveal_all config (#1063) 2025-12-09 12:01:39 +00:00
dan
0e8e547300 chore: adapt for rangeset 0.4 (#1058) 2025-12-09 11:36:13 +00:00
dan
22cc88907a chore: bump mpz (#1057) 2025-12-04 10:27:43 +00:00
57 changed files with 6160 additions and 1755 deletions

View File

@@ -21,7 +21,7 @@ env:
# - https://github.com/privacy-ethereum/mpz/issues/178
# 32 seems to be big enough for the foreseeable future
RAYON_NUM_THREADS: 32
RUST_VERSION: 1.91.1
RUST_VERSION: 1.92.0
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
jobs:

2593
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -13,6 +13,7 @@ members = [
"crates/server-fixture/server",
"crates/tls/backend",
"crates/tls/client",
"crates/tls/client-async",
"crates/tls/core",
"crates/mpc-tls",
"crates/tls/server-fixture",
@@ -56,6 +57,7 @@ tlsn-server-fixture = { path = "crates/server-fixture/server" }
tlsn-server-fixture-certs = { path = "crates/server-fixture/certs" }
tlsn-tls-backend = { path = "crates/tls/backend" }
tlsn-tls-client = { path = "crates/tls/client" }
tlsn-tls-client-async = { path = "crates/tls/client-async" }
tlsn-tls-core = { path = "crates/tls/core" }
tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6168663" }
tlsn-harness-core = { path = "crates/harness/core" }
@@ -64,28 +66,27 @@ tlsn-harness-runner = { path = "crates/harness/runner" }
tlsn-wasm = { path = "crates/wasm" }
tlsn = { path = "crates/tlsn" }
mpz-circuits = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-circuits-data = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-memory-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-common = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-vm-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-garble = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-garble-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-ole = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-ot = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-share-conversion = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-fields = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-ideal-vm = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-circuits = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-circuits-data = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-memory-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-common = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-vm-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-garble = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-garble-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-ole = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-ot = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-share-conversion = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-fields = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-ideal-vm = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
futures-plex = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "0b46dc0" }
rangeset = { version = "0.2" }
rangeset = { version = "0.4" }
serio = { version = "0.2" }
spansy = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6168663" }
spansy = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6f1a934" }
uid-mux = { version = "0.2" }
websocket-relay = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6168663" }
websocket-relay = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6f1a934" }
aead = { version = "0.4" }
aes = { version = "0.8" }

View File

@@ -15,7 +15,7 @@ use mpz_vm_core::{
memory::{binary::Binary, DecodeFuture, Memory, Repr, Slice, View},
Call, Callable, Execute, Vm, VmError,
};
use rangeset::{Difference, RangeSet, UnionMut};
use rangeset::{ops::Set, set::RangeSet};
use tokio::sync::{Mutex, MutexGuard, OwnedMutexGuard};
type Error = DeapError;
@@ -210,10 +210,12 @@ where
}
fn commit_raw(&mut self, slice: Slice) -> Result<(), VmError> {
let slice_range = slice.to_range();
// Follower's private inputs are not committed in the ZK VM until finalization.
let input_minus_follower = slice.to_range().difference(&self.follower_input_ranges);
let input_minus_follower = slice_range.difference(&self.follower_input_ranges);
let mut zk = self.zk.try_lock().unwrap();
for input in input_minus_follower.iter_ranges() {
for input in input_minus_follower {
zk.commit_raw(
self.memory_map
.try_get(Slice::from_range_unchecked(input))?,
@@ -266,7 +268,7 @@ where
mpc.mark_private_raw(slice)?;
// Follower's private inputs will become public during finalization.
zk.mark_public_raw(self.memory_map.try_get(slice)?)?;
self.follower_input_ranges.union_mut(&slice.to_range());
self.follower_input_ranges.union_mut(slice.to_range());
self.follower_inputs.push(slice);
}
}
@@ -282,7 +284,7 @@ where
mpc.mark_blind_raw(slice)?;
// Follower's private inputs will become public during finalization.
zk.mark_public_raw(self.memory_map.try_get(slice)?)?;
self.follower_input_ranges.union_mut(&slice.to_range());
self.follower_input_ranges.union_mut(slice.to_range());
self.follower_inputs.push(slice);
}
Role::Follower => {

View File

@@ -1,7 +1,7 @@
use std::ops::Range;
use mpz_vm_core::{memory::Slice, VmError};
use rangeset::Subset;
use rangeset::ops::Set;
/// A mapping between the memories of the MPC and ZK VMs.
#[derive(Debug, Default)]

View File

@@ -59,5 +59,7 @@ generic-array = { workspace = true }
bincode = { workspace = true }
hex = { workspace = true }
rstest = { workspace = true }
tlsn-core = { workspace = true, features = ["fixtures"] }
tlsn-attestation = { workspace = true, features = ["fixtures"] }
tlsn-data-fixtures = { workspace = true }
webpki-root-certs = { workspace = true }

View File

@@ -1,6 +1,6 @@
//! Proving configuration.
use rangeset::{RangeSet, ToRangeSet, UnionMut};
use rangeset::set::{RangeSet, ToRangeSet};
use serde::{Deserialize, Serialize};
use crate::transcript::{Direction, Transcript, TranscriptCommitConfig, TranscriptCommitRequest};

View File

@@ -1,11 +1,11 @@
use rangeset::RangeSet;
use rangeset::set::RangeSet;
pub(crate) struct FmtRangeSet<'a>(pub &'a RangeSet<usize>);
impl<'a> std::fmt::Display for FmtRangeSet<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("{")?;
for range in self.0.iter_ranges() {
for range in self.0.iter() {
write!(f, "{}..{}", range.start, range.end)?;
if range.end < self.0.end().unwrap_or(0) {
f.write_str(", ")?;

View File

@@ -26,7 +26,11 @@ mod tls;
use std::{fmt, ops::Range};
use rangeset::{Difference, IndexRanges, RangeSet, Union};
use rangeset::{
iter::RangeIterator,
ops::{Index, Set},
set::RangeSet,
};
use serde::{Deserialize, Serialize};
use crate::connection::TranscriptLength;
@@ -106,8 +110,14 @@ impl Transcript {
}
Some(
Subsequence::new(idx.clone(), data.index_ranges(idx))
.expect("data is same length as index"),
Subsequence::new(
idx.clone(),
data.index(idx).fold(Vec::new(), |mut acc, s| {
acc.extend_from_slice(s);
acc
}),
)
.expect("data is same length as index"),
)
}
@@ -129,11 +139,11 @@ impl Transcript {
let mut sent = vec![0; self.sent.len()];
let mut received = vec![0; self.received.len()];
for range in sent_idx.iter_ranges() {
for range in sent_idx.iter() {
sent[range.clone()].copy_from_slice(&self.sent[range]);
}
for range in recv_idx.iter_ranges() {
for range in recv_idx.iter() {
received[range.clone()].copy_from_slice(&self.received[range]);
}
@@ -186,12 +196,20 @@ pub struct CompressedPartialTranscript {
impl From<PartialTranscript> for CompressedPartialTranscript {
fn from(uncompressed: PartialTranscript) -> Self {
Self {
sent_authed: uncompressed
.sent
.index_ranges(&uncompressed.sent_authed_idx),
sent_authed: uncompressed.sent.index(&uncompressed.sent_authed_idx).fold(
Vec::new(),
|mut acc, s| {
acc.extend_from_slice(s);
acc
},
),
received_authed: uncompressed
.received
.index_ranges(&uncompressed.received_authed_idx),
.index(&uncompressed.received_authed_idx)
.fold(Vec::new(), |mut acc, s| {
acc.extend_from_slice(s);
acc
}),
sent_idx: uncompressed.sent_authed_idx,
recv_idx: uncompressed.received_authed_idx,
sent_total: uncompressed.sent.len(),
@@ -207,7 +225,7 @@ impl From<CompressedPartialTranscript> for PartialTranscript {
let mut offset = 0;
for range in compressed.sent_idx.iter_ranges() {
for range in compressed.sent_idx.iter() {
sent[range.clone()]
.copy_from_slice(&compressed.sent_authed[offset..offset + range.len()]);
offset += range.len();
@@ -215,7 +233,7 @@ impl From<CompressedPartialTranscript> for PartialTranscript {
let mut offset = 0;
for range in compressed.recv_idx.iter_ranges() {
for range in compressed.recv_idx.iter() {
received[range.clone()]
.copy_from_slice(&compressed.received_authed[offset..offset + range.len()]);
offset += range.len();
@@ -304,12 +322,16 @@ impl PartialTranscript {
/// Returns the index of sent data which haven't been authenticated.
pub fn sent_unauthed(&self) -> RangeSet<usize> {
(0..self.sent.len()).difference(&self.sent_authed_idx)
(0..self.sent.len())
.difference(&self.sent_authed_idx)
.into_set()
}
/// Returns the index of received data which haven't been authenticated.
pub fn received_unauthed(&self) -> RangeSet<usize> {
(0..self.received.len()).difference(&self.received_authed_idx)
(0..self.received.len())
.difference(&self.received_authed_idx)
.into_set()
}
/// Returns an iterator over the authenticated data in the transcript.
@@ -319,7 +341,7 @@ impl PartialTranscript {
Direction::Received => (&self.received, &self.received_authed_idx),
};
authed.iter().map(|i| data[i])
authed.iter_values().map(move |i| data[i])
}
/// Unions the authenticated data of this transcript with another.
@@ -339,24 +361,20 @@ impl PartialTranscript {
"received data are not the same length"
);
for range in other
.sent_authed_idx
.difference(&self.sent_authed_idx)
.iter_ranges()
{
for range in other.sent_authed_idx.difference(&self.sent_authed_idx) {
self.sent[range.clone()].copy_from_slice(&other.sent[range]);
}
for range in other
.received_authed_idx
.difference(&self.received_authed_idx)
.iter_ranges()
{
self.received[range.clone()].copy_from_slice(&other.received[range]);
}
self.sent_authed_idx = self.sent_authed_idx.union(&other.sent_authed_idx);
self.received_authed_idx = self.received_authed_idx.union(&other.received_authed_idx);
self.sent_authed_idx.union_mut(&other.sent_authed_idx);
self.received_authed_idx
.union_mut(&other.received_authed_idx);
}
/// Unions an authenticated subsequence into this transcript.
@@ -368,11 +386,11 @@ impl PartialTranscript {
match direction {
Direction::Sent => {
seq.copy_to(&mut self.sent);
self.sent_authed_idx = self.sent_authed_idx.union(&seq.idx);
self.sent_authed_idx.union_mut(&seq.idx);
}
Direction::Received => {
seq.copy_to(&mut self.received);
self.received_authed_idx = self.received_authed_idx.union(&seq.idx);
self.received_authed_idx.union_mut(&seq.idx);
}
}
}
@@ -383,10 +401,10 @@ impl PartialTranscript {
///
/// * `value` - The value to set the unauthenticated bytes to
pub fn set_unauthed(&mut self, value: u8) {
for range in self.sent_unauthed().iter_ranges() {
for range in self.sent_unauthed().iter() {
self.sent[range].fill(value);
}
for range in self.received_unauthed().iter_ranges() {
for range in self.received_unauthed().iter() {
self.received[range].fill(value);
}
}
@@ -401,13 +419,13 @@ impl PartialTranscript {
pub fn set_unauthed_range(&mut self, value: u8, direction: Direction, range: Range<usize>) {
match direction {
Direction::Sent => {
for range in range.difference(&self.sent_authed_idx).iter_ranges() {
self.sent[range].fill(value);
for r in range.difference(&self.sent_authed_idx) {
self.sent[r].fill(value);
}
}
Direction::Received => {
for range in range.difference(&self.received_authed_idx).iter_ranges() {
self.received[range].fill(value);
for r in range.difference(&self.received_authed_idx) {
self.received[r].fill(value);
}
}
}
@@ -485,7 +503,7 @@ impl Subsequence {
/// Panics if the subsequence ranges are out of bounds.
pub(crate) fn copy_to(&self, dest: &mut [u8]) {
let mut offset = 0;
for range in self.idx.iter_ranges() {
for range in self.idx.iter() {
dest[range.clone()].copy_from_slice(&self.data[offset..offset + range.len()]);
offset += range.len();
}
@@ -610,12 +628,7 @@ mod validation {
mut partial_transcript: CompressedPartialTranscriptUnchecked,
) {
// Change the total to be less than the last range's end bound.
let end = partial_transcript
.sent_idx
.iter_ranges()
.next_back()
.unwrap()
.end;
let end = partial_transcript.sent_idx.iter().next_back().unwrap().end;
partial_transcript.sent_total = end - 1;

View File

@@ -2,7 +2,7 @@
use std::{collections::HashSet, fmt};
use rangeset::{ToRangeSet, UnionMut};
use rangeset::set::ToRangeSet;
use serde::{Deserialize, Serialize};
use crate::{

View File

@@ -1,6 +1,6 @@
use std::{collections::HashMap, fmt};
use rangeset::{RangeSet, UnionMut};
use rangeset::set::RangeSet;
use serde::{Deserialize, Serialize};
use crate::{
@@ -103,7 +103,7 @@ impl EncodingProof {
}
expected_leaf.clear();
for range in idx.iter_ranges() {
for range in idx.iter() {
encoder.encode_data(*direction, range.clone(), &data[range], &mut expected_leaf);
}
expected_leaf.extend_from_slice(blinder.as_bytes());

View File

@@ -1,7 +1,7 @@
use std::collections::HashMap;
use bimap::BiMap;
use rangeset::{RangeSet, UnionMut};
use rangeset::set::RangeSet;
use serde::{Deserialize, Serialize};
use crate::{
@@ -99,7 +99,7 @@ impl EncodingTree {
let blinder: Blinder = rand::random();
encoding.clear();
for range in idx.iter_ranges() {
for range in idx.iter() {
provider
.provide_encoding(direction, range, &mut encoding)
.map_err(|_| EncodingTreeError::MissingEncoding { index: idx.clone() })?;

View File

@@ -1,6 +1,10 @@
//! Transcript proofs.
use rangeset::{Cover, Difference, Subset, ToRangeSet, UnionMut};
use rangeset::{
iter::RangeIterator,
ops::{Cover, Set},
set::ToRangeSet,
};
use serde::{Deserialize, Serialize};
use std::{collections::HashSet, fmt};
@@ -144,7 +148,7 @@ impl TranscriptProof {
}
buffer.clear();
for range in idx.iter_ranges() {
for range in idx.iter() {
buffer.extend_from_slice(&plaintext[range]);
}
@@ -366,7 +370,7 @@ impl<'a> TranscriptProofBuilder<'a> {
if idx.is_subset(committed) {
self.query_idx.union(&direction, &idx);
} else {
let missing = idx.difference(committed);
let missing = idx.difference(committed).into_set();
return Err(TranscriptProofBuilderError::new(
BuilderErrorKind::MissingCommitment,
format!(
@@ -582,7 +586,7 @@ impl fmt::Display for TranscriptProofBuilderError {
#[cfg(test)]
mod tests {
use rand::{Rng, SeedableRng};
use rangeset::RangeSet;
use rangeset::prelude::*;
use rstest::rstest;
use tlsn_data_fixtures::http::{request::GET_WITH_HEADER, response::OK_JSON};

View File

@@ -9,6 +9,7 @@ pub const DEFAULT_UPLOAD_SIZE: usize = 1024;
pub const DEFAULT_DOWNLOAD_SIZE: usize = 4096;
pub const DEFAULT_DEFER_DECRYPTION: bool = true;
pub const DEFAULT_MEMORY_PROFILE: bool = false;
pub const DEFAULT_REVEAL_ALL: bool = false;
pub const WARM_UP_BENCH: Bench = Bench {
group: None,
@@ -20,6 +21,7 @@ pub const WARM_UP_BENCH: Bench = Bench {
download_size: 4096,
defer_decryption: true,
memory_profile: false,
reveal_all: true,
};
#[derive(Deserialize)]
@@ -79,6 +81,8 @@ pub struct BenchGroupItem {
pub defer_decryption: Option<bool>,
#[serde(rename = "memory-profile")]
pub memory_profile: Option<bool>,
#[serde(rename = "reveal-all")]
pub reveal_all: Option<bool>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -97,6 +101,8 @@ pub struct BenchItem {
pub defer_decryption: Option<bool>,
#[serde(rename = "memory-profile")]
pub memory_profile: Option<bool>,
#[serde(rename = "reveal-all")]
pub reveal_all: Option<bool>,
}
impl BenchItem {
@@ -132,6 +138,10 @@ impl BenchItem {
if self.memory_profile.is_none() {
self.memory_profile = group.memory_profile;
}
if self.reveal_all.is_none() {
self.reveal_all = group.reveal_all;
}
}
pub fn into_bench(&self) -> Bench {
@@ -145,6 +155,7 @@ impl BenchItem {
download_size: self.download_size.unwrap_or(DEFAULT_DOWNLOAD_SIZE),
defer_decryption: self.defer_decryption.unwrap_or(DEFAULT_DEFER_DECRYPTION),
memory_profile: self.memory_profile.unwrap_or(DEFAULT_MEMORY_PROFILE),
reveal_all: self.reveal_all.unwrap_or(DEFAULT_REVEAL_ALL),
}
}
}
@@ -164,6 +175,8 @@ pub struct Bench {
pub defer_decryption: bool,
#[serde(rename = "memory-profile")]
pub memory_profile: bool,
#[serde(rename = "reveal-all")]
pub reveal_all: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]

View File

@@ -22,7 +22,10 @@ pub enum CmdOutput {
GetTests(Vec<String>),
Test(TestOutput),
Bench(BenchOutput),
Fail { reason: Option<String> },
#[cfg(target_arch = "wasm32")]
Fail {
reason: Option<String>,
},
}
#[derive(Debug, Clone, Serialize, Deserialize)]

View File

@@ -98,14 +98,27 @@ pub async fn bench_prover(provider: &IoProvider, config: &Bench) -> Result<Prove
let mut builder = ProveConfig::builder(prover.transcript());
// When reveal_all is false (the default), we exclude 1 byte to avoid the
// reveal-all optimization and benchmark the realistic ZK authentication path.
let reveal_sent_range = if config.reveal_all {
0..sent_len
} else {
0..sent_len.saturating_sub(1)
};
let reveal_recv_range = if config.reveal_all {
0..recv_len
} else {
0..recv_len.saturating_sub(1)
};
builder
.server_identity()
.reveal_sent(&(0..sent_len))?
.reveal_recv(&(0..recv_len))?;
.reveal_sent(&reveal_sent_range)?
.reveal_recv(&reveal_recv_range)?;
let config = builder.build()?;
let prove_config = builder.build()?;
prover.prove(&config).await?;
prover.prove(&prove_config).await?;
prover.close().await?;
let time_total = time_start.elapsed().as_millis();

View File

@@ -7,10 +7,9 @@ publish = false
[dependencies]
tlsn-harness-core = { workspace = true }
# tlsn-server-fixture = { workspace = true }
charming = { version = "0.5.1", features = ["ssr"] }
csv = "1.3.0"
charming = { version = "0.6.0", features = ["ssr"] }
clap = { workspace = true, features = ["derive", "env"] }
itertools = "0.14.0"
polars = { version = "0.44", features = ["csv", "lazy"] }
toml = { workspace = true }

View File

@@ -0,0 +1,111 @@
# TLSNotary Benchmark Plot Tool
Generates interactive HTML and SVG plots from TLSNotary benchmark results. Supports comparing multiple benchmark runs (e.g., before/after optimization, native vs browser).
## Usage
```bash
tlsn-harness-plot <TOML> <CSV>... [OPTIONS]
```
### Arguments
- `<TOML>` - Path to Bench.toml file defining benchmark structure
- `<CSV>...` - One or more CSV files with benchmark results
### Options
- `-l, --labels <LABEL>...` - Labels for each dataset (optional)
- If omitted, datasets are labeled "Dataset 1", "Dataset 2", etc.
- Number of labels must match number of CSV files
- `--min-max-band` - Add min/max bands to plots showing variance
- `-h, --help` - Print help information
## Examples
### Single Dataset
```bash
tlsn-harness-plot bench.toml results.csv
```
Generates plots from a single benchmark run.
### Compare Two Runs
```bash
tlsn-harness-plot bench.toml before.csv after.csv \
--labels "Before Optimization" "After Optimization"
```
Overlays two datasets to compare performance improvements.
### Multiple Datasets
```bash
tlsn-harness-plot bench.toml native.csv browser.csv wasm.csv \
--labels "Native" "Browser" "WASM"
```
Compare three different runtime environments.
### With Min/Max Bands
```bash
tlsn-harness-plot bench.toml run1.csv run2.csv \
--labels "Config A" "Config B" \
--min-max-band
```
Shows variance ranges for each dataset.
## Output Files
The tool generates two files per benchmark group:
- `<output>.html` - Interactive HTML chart (zoomable, hoverable)
- `<output>.svg` - Static SVG image for documentation
Default output filenames:
- `runtime_vs_bandwidth.{html,svg}` - When `protocol_latency` is defined in group
- `runtime_vs_latency.{html,svg}` - When `bandwidth` is defined in group
## Plot Format
Each dataset displays:
- **Solid line** - Total runtime (preprocessing + online phase)
- **Dashed line** - Online phase only
- **Shaded area** (optional) - Min/max variance bands
Different datasets automatically use distinct colors for easy comparison.
## CSV Format
Expected columns in each CSV file:
- `group` - Benchmark group name (must match TOML)
- `bandwidth` - Network bandwidth in Kbps (for bandwidth plots)
- `latency` - Network latency in ms (for latency plots)
- `time_preprocess` - Preprocessing time in ms
- `time_online` - Online phase time in ms
- `time_total` - Total runtime in ms
## TOML Format
The benchmark TOML file defines groups with either:
```toml
[[group]]
name = "my_benchmark"
protocol_latency = 50 # Fixed latency for bandwidth plots
# OR
bandwidth = 10000 # Fixed bandwidth for latency plots
```
All datasets must use the same TOML file to ensure consistent benchmark structure.
## Tips
- Use descriptive labels to make plots self-documenting
- Keep CSV files from the same benchmark configuration for valid comparisons
- Min/max bands are useful for showing stability but can clutter plots with many datasets
- Interactive HTML plots support zooming and hovering for detailed values

View File

@@ -1,17 +1,18 @@
use std::f32;
use charming::{
Chart, HtmlRenderer,
Chart, HtmlRenderer, ImageRenderer,
component::{Axis, Legend, Title},
element::{AreaStyle, LineStyle, NameLocation, Orient, TextStyle, Tooltip, Trigger},
element::{
AreaStyle, ItemStyle, LineStyle, LineStyleType, NameLocation, Orient, TextStyle, Tooltip,
Trigger,
},
series::Line,
theme::Theme,
};
use clap::Parser;
use harness_core::bench::{BenchItems, Measurement};
use itertools::Itertools;
const THEME: Theme = Theme::Default;
use harness_core::bench::BenchItems;
use polars::prelude::*;
#[derive(Parser, Debug)]
#[command(author, version, about)]
@@ -19,72 +20,131 @@ struct Cli {
/// Path to the Bench.toml file with benchmark spec
toml: String,
/// Path to the CSV file with benchmark results
csv: String,
/// Paths to CSV files with benchmark results (one or more)
csv: Vec<String>,
/// Prover kind: native or browser
#[arg(short, long, value_enum, default_value = "native")]
prover_kind: ProverKind,
/// Labels for each dataset (optional, defaults to "Dataset 1", "Dataset 2", etc.)
#[arg(short, long, num_args = 0..)]
labels: Vec<String>,
/// Add min/max bands to plots
#[arg(long, default_value_t = false)]
min_max_band: bool,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum ProverKind {
Native,
Browser,
}
impl std::fmt::Display for ProverKind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ProverKind::Native => write!(f, "Native"),
ProverKind::Browser => write!(f, "Browser"),
}
}
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
let cli = Cli::parse();
let mut rdr = csv::Reader::from_path(&cli.csv)?;
if cli.csv.is_empty() {
return Err("At least one CSV file must be provided".into());
}
// Generate labels if not provided
let labels: Vec<String> = if cli.labels.is_empty() {
cli.csv
.iter()
.enumerate()
.map(|(i, _)| format!("Dataset {}", i + 1))
.collect()
} else if cli.labels.len() != cli.csv.len() {
return Err(format!(
"Number of labels ({}) must match number of CSV files ({})",
cli.labels.len(),
cli.csv.len()
)
.into());
} else {
cli.labels.clone()
};
// Load all CSVs and add dataset label
let mut dfs = Vec::new();
for (csv_path, label) in cli.csv.iter().zip(labels.iter()) {
let mut df = CsvReadOptions::default()
.try_into_reader_with_file_path(Some(csv_path.clone().into()))?
.finish()?;
let label_series = Series::new("dataset_label".into(), vec![label.as_str(); df.height()]);
df.with_column(label_series)?;
dfs.push(df);
}
// Combine all dataframes
let df = dfs
.into_iter()
.reduce(|acc, df| acc.vstack(&df).unwrap())
.unwrap();
let items: BenchItems = toml::from_str(&std::fs::read_to_string(&cli.toml)?)?;
let groups = items.group;
// Prepare data for plotting.
let all_data: Vec<Measurement> = rdr
.deserialize::<Measurement>()
.collect::<Result<Vec<_>, _>>()?;
for group in groups {
if group.protocol_latency.is_some() {
let latency = group.protocol_latency.unwrap();
plot_runtime_vs(
&all_data,
cli.min_max_band,
&group.name,
|r| r.bandwidth as f32 / 1000.0, // Kbps to Mbps
"Runtime vs Bandwidth",
format!("{} ms Latency, {} mode", latency, cli.prover_kind),
"runtime_vs_bandwidth.html",
"Bandwidth (Mbps)",
)?;
// Determine which field varies in benches for this group
let benches_in_group: Vec<_> = items
.bench
.iter()
.filter(|b| b.group.as_deref() == Some(&group.name))
.collect();
if benches_in_group.is_empty() {
continue;
}
if group.bandwidth.is_some() {
let bandwidth = group.bandwidth.unwrap();
// Check which field has varying values
let bandwidth_varies = benches_in_group
.windows(2)
.any(|w| w[0].bandwidth != w[1].bandwidth);
let latency_varies = benches_in_group
.windows(2)
.any(|w| w[0].protocol_latency != w[1].protocol_latency);
let download_size_varies = benches_in_group
.windows(2)
.any(|w| w[0].download_size != w[1].download_size);
if download_size_varies {
let upload_size = group.upload_size.unwrap_or(1024);
plot_runtime_vs(
&all_data,
&df,
&labels,
cli.min_max_band,
&group.name,
|r| r.latency as f32,
"download_size",
1.0 / 1024.0, // bytes to KB
"Runtime vs Response Size",
format!("{} bytes upload size", upload_size),
"runtime_vs_download_size",
"Response Size (KB)",
true, // legend on left
)?;
} else if bandwidth_varies {
let latency = group.protocol_latency.unwrap_or(50);
plot_runtime_vs(
&df,
&labels,
cli.min_max_band,
&group.name,
"bandwidth",
1.0 / 1000.0, // Kbps to Mbps
"Runtime vs Bandwidth",
format!("{} ms Latency", latency),
"runtime_vs_bandwidth",
"Bandwidth (Mbps)",
false, // legend on right
)?;
} else if latency_varies {
let bandwidth = group.bandwidth.unwrap_or(1000);
plot_runtime_vs(
&df,
&labels,
cli.min_max_band,
&group.name,
"latency",
1.0,
"Runtime vs Latency",
format!("{} bps bandwidth, {} mode", bandwidth, cli.prover_kind),
"runtime_vs_latency.html",
format!("{} bps bandwidth", bandwidth),
"runtime_vs_latency",
"Latency (ms)",
true, // legend on left
)?;
}
}
@@ -92,83 +152,51 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}
struct DataPoint {
min: f32,
mean: f32,
max: f32,
}
struct Points {
preprocess: DataPoint,
online: DataPoint,
total: DataPoint,
}
#[allow(clippy::too_many_arguments)]
fn plot_runtime_vs<Fx>(
all_data: &[Measurement],
fn plot_runtime_vs(
df: &DataFrame,
labels: &[String],
show_min_max: bool,
group: &str,
x_value: Fx,
x_col: &str,
x_scale: f32,
title: &str,
subtitle: String,
output_file: &str,
x_axis_label: &str,
) -> Result<Chart, Box<dyn std::error::Error>>
where
Fx: Fn(&Measurement) -> f32,
{
fn data_point(values: &[f32]) -> DataPoint {
let mean = values.iter().copied().sum::<f32>() / values.len() as f32;
let max = values.iter().copied().reduce(f32::max).unwrap_or_default();
let min = values.iter().copied().reduce(f32::min).unwrap_or_default();
DataPoint { min, mean, max }
}
legend_left: bool,
) -> Result<Chart, Box<dyn std::error::Error>> {
let stats_df = df
.clone()
.lazy()
.filter(col("group").eq(lit(group)))
.with_column((col(x_col).cast(DataType::Float32) * lit(x_scale)).alias("x"))
.with_columns([
(col("time_preprocess").cast(DataType::Float32) / lit(1000.0)).alias("preprocess"),
(col("time_online").cast(DataType::Float32) / lit(1000.0)).alias("online"),
(col("time_total").cast(DataType::Float32) / lit(1000.0)).alias("total"),
])
.group_by([col("x"), col("dataset_label")])
.agg([
col("preprocess").min().alias("preprocess_min"),
col("preprocess").mean().alias("preprocess_mean"),
col("preprocess").max().alias("preprocess_max"),
col("online").min().alias("online_min"),
col("online").mean().alias("online_mean"),
col("online").max().alias("online_max"),
col("total").min().alias("total_min"),
col("total").mean().alias("total_mean"),
col("total").max().alias("total_max"),
])
.sort(["dataset_label", "x"], Default::default())
.collect()?;
let stats: Vec<(f32, Points)> = all_data
.iter()
.filter(|r| r.group.as_deref() == Some(group))
.map(|r| {
(
x_value(r),
r.time_preprocess as f32 / 1000.0, // ms to s
r.time_online as f32 / 1000.0,
r.time_total as f32 / 1000.0,
)
})
.sorted_by(|a, b| a.0.partial_cmp(&b.0).unwrap())
.chunk_by(|entry| entry.0)
.into_iter()
.map(|(x, group)| {
let group_vec: Vec<_> = group.collect();
let preprocess = data_point(
&group_vec
.iter()
.map(|(_, t, _, _)| *t)
.collect::<Vec<f32>>(),
);
let online = data_point(
&group_vec
.iter()
.map(|(_, _, t, _)| *t)
.collect::<Vec<f32>>(),
);
let total = data_point(
&group_vec
.iter()
.map(|(_, _, _, t)| *t)
.collect::<Vec<f32>>(),
);
(
x,
Points {
preprocess,
online,
total,
},
)
})
.collect();
// Build legend entries
let mut legend_data = Vec::new();
for label in labels {
legend_data.push(format!("Total Mean ({})", label));
legend_data.push(format!("Online Mean ({})", label));
}
let mut chart = Chart::new()
.title(
@@ -179,14 +207,6 @@ where
.subtext_style(TextStyle::new().font_size(16)),
)
.tooltip(Tooltip::new().trigger(Trigger::Axis))
.legend(
Legend::new()
.data(vec!["Preprocess Mean", "Online Mean", "Total Mean"])
.top("80")
.right("110")
.orient(Orient::Vertical)
.item_gap(10),
)
.x_axis(
Axis::new()
.name(x_axis_label)
@@ -205,73 +225,156 @@ where
.name_text_style(TextStyle::new().font_size(21)),
);
chart = add_mean_series(chart, &stats, "Preprocess Mean", |p| p.preprocess.mean);
chart = add_mean_series(chart, &stats, "Online Mean", |p| p.online.mean);
chart = add_mean_series(chart, &stats, "Total Mean", |p| p.total.mean);
// Add legend with conditional positioning
let legend = Legend::new()
.data(legend_data)
.top("80")
.orient(Orient::Vertical)
.item_gap(10);
if show_min_max {
chart = add_min_max_band(
chart,
&stats,
"Preprocess Min/Max",
|p| &p.preprocess,
"#ccc",
);
chart = add_min_max_band(chart, &stats, "Online Min/Max", |p| &p.online, "#ccc");
chart = add_min_max_band(chart, &stats, "Total Min/Max", |p| &p.total, "#ccc");
let legend = if legend_left {
legend.left("110")
} else {
legend.right("110")
};
chart = chart.legend(legend);
// Define colors for each dataset
let colors = vec![
"#5470c6", "#91cc75", "#fac858", "#ee6666", "#73c0de", "#3ba272", "#fc8452", "#9a60b4",
];
for (idx, label) in labels.iter().enumerate() {
let color = colors.get(idx % colors.len()).unwrap();
// Total time - solid line
chart = add_dataset_series(
&chart,
&stats_df,
label,
&format!("Total Mean ({})", label),
"total_mean",
false,
color,
)?;
// Online time - dashed line (same color as total)
chart = add_dataset_series(
&chart,
&stats_df,
label,
&format!("Online Mean ({})", label),
"online_mean",
true,
color,
)?;
if show_min_max {
chart = add_dataset_min_max_band(
&chart,
&stats_df,
label,
&format!("Total Min/Max ({})", label),
"total",
color,
)?;
}
}
// Save the chart as HTML file.
// Save the chart as HTML file (no theme)
HtmlRenderer::new(title, 1000, 800)
.theme(THEME)
.save(&chart, output_file)
.save(&chart, &format!("{}.html", output_file))
.unwrap();
// Save SVG with default theme
ImageRenderer::new(1000, 800)
.theme(Theme::Default)
.save(&chart, &format!("{}.svg", output_file))
.unwrap();
// Save SVG with dark theme
ImageRenderer::new(1000, 800)
.theme(Theme::Dark)
.save(&chart, &format!("{}_dark.svg", output_file))
.unwrap();
Ok(chart)
}
fn add_mean_series(
chart: Chart,
stats: &[(f32, Points)],
name: &str,
extract: impl Fn(&Points) -> f32,
) -> Chart {
chart.series(
Line::new()
.name(name)
.data(
stats
.iter()
.map(|(x, points)| vec![*x, extract(points)])
.collect(),
)
.symbol_size(6),
)
fn add_dataset_series(
chart: &Chart,
df: &DataFrame,
dataset_label: &str,
series_name: &str,
col_name: &str,
dashed: bool,
color: &str,
) -> Result<Chart, Box<dyn std::error::Error>> {
// Filter for specific dataset
let mask = df.column("dataset_label")?.str()?.equal(dataset_label);
let filtered = df.filter(&mask)?;
let x = filtered.column("x")?.f32()?;
let y = filtered.column(col_name)?.f32()?;
let data: Vec<Vec<f32>> = x
.into_iter()
.zip(y.into_iter())
.filter_map(|(x, y)| Some(vec![x?, y?]))
.collect();
let mut line = Line::new()
.name(series_name)
.data(data)
.symbol_size(6)
.item_style(ItemStyle::new().color(color));
let mut line_style = LineStyle::new();
if dashed {
line_style = line_style.type_(LineStyleType::Dashed);
}
line = line.line_style(line_style.color(color));
Ok(chart.clone().series(line))
}
fn add_min_max_band(
chart: Chart,
stats: &[(f32, Points)],
fn add_dataset_min_max_band(
chart: &Chart,
df: &DataFrame,
dataset_label: &str,
name: &str,
extract: impl Fn(&Points) -> &DataPoint,
col_prefix: &str,
color: &str,
) -> Chart {
chart.series(
) -> Result<Chart, Box<dyn std::error::Error>> {
// Filter for specific dataset
let mask = df.column("dataset_label")?.str()?.equal(dataset_label);
let filtered = df.filter(&mask)?;
let x = filtered.column("x")?.f32()?;
let min_col = filtered.column(&format!("{}_min", col_prefix))?.f32()?;
let max_col = filtered.column(&format!("{}_max", col_prefix))?.f32()?;
let max_data: Vec<Vec<f32>> = x
.into_iter()
.zip(max_col.into_iter())
.filter_map(|(x, y)| Some(vec![x?, y?]))
.collect();
let min_data: Vec<Vec<f32>> = x
.into_iter()
.zip(min_col.into_iter())
.filter_map(|(x, y)| Some(vec![x?, y?]))
.rev()
.collect();
let data: Vec<Vec<f32>> = max_data.into_iter().chain(min_data).collect();
Ok(chart.clone().series(
Line::new()
.name(name)
.data(
stats
.iter()
.map(|(x, points)| vec![*x, extract(points).max])
.chain(
stats
.iter()
.rev()
.map(|(x, points)| vec![*x, extract(points).min]),
)
.collect(),
)
.data(data)
.show_symbol(false)
.line_style(LineStyle::new().opacity(0.0))
.area_style(AreaStyle::new().opacity(0.3).color(color)),
)
))
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,25 @@
#### Bandwidth ####
[[group]]
name = "bandwidth"
protocol_latency = 25
[[bench]]
group = "bandwidth"
bandwidth = 10
[[bench]]
group = "bandwidth"
bandwidth = 50
[[bench]]
group = "bandwidth"
bandwidth = 100
[[bench]]
group = "bandwidth"
bandwidth = 250
[[bench]]
group = "bandwidth"
bandwidth = 1000

View File

@@ -0,0 +1,37 @@
[[group]]
name = "download_size"
protocol_latency = 10
bandwidth = 200
upload-size = 2048
[[bench]]
group = "download_size"
download-size = 1024
[[bench]]
group = "download_size"
download-size = 2048
[[bench]]
group = "download_size"
download-size = 4096
[[bench]]
group = "download_size"
download-size = 8192
[[bench]]
group = "download_size"
download-size = 16384
[[bench]]
group = "download_size"
download-size = 32768
[[bench]]
group = "download_size"
download-size = 65536
[[bench]]
group = "download_size"
download-size = 131072

View File

@@ -0,0 +1,25 @@
#### Latency ####
[[group]]
name = "latency"
bandwidth = 1000
[[bench]]
group = "latency"
protocol_latency = 10
[[bench]]
group = "latency"
protocol_latency = 25
[[bench]]
group = "latency"
protocol_latency = 50
[[bench]]
group = "latency"
protocol_latency = 100
[[bench]]
group = "latency"
protocol_latency = 200

View File

@@ -34,6 +34,7 @@ mpz-share-conversion = { workspace = true }
mpz-vm-core = { workspace = true }
mpz-memory-core = { workspace = true }
ludi = { git = "https://github.com/sinui0/ludi", rev = "e511c3b", default-features = false }
serio = { workspace = true }
async-trait = { workspace = true }
@@ -65,6 +66,7 @@ rand_chacha = { workspace = true }
rstest = { workspace = true }
tls-server-fixture = { workspace = true }
tlsn-tls-client = { workspace = true }
tlsn-tls-client-async = { workspace = true }
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }
tokio-util = { workspace = true, features = ["compat"] }
tracing-subscriber = { workspace = true }

View File

@@ -15,6 +15,13 @@ impl MpcTlsError {
Self(ErrorRepr::Peer(err.into()))
}
pub(crate) fn actor<E>(err: E) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
Self(ErrorRepr::Actor(err.into()))
}
pub(crate) fn state<E>(err: E) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
@@ -65,6 +72,8 @@ enum ErrorRepr {
Peer(Box<dyn std::error::Error + Send + Sync>),
#[error("I/O error: {0}")]
Io(std::io::Error),
#[error("actor error: {0}")]
Actor(Box<dyn std::error::Error + Send + Sync>),
#[error("state error: {0}")]
State(Box<dyn std::error::Error + Send + Sync>),
#[error("allocation error: {0}")]

View File

@@ -1,3 +1,5 @@
mod actor;
use crate::{
error::MpcTlsError,
msg::{
@@ -12,6 +14,7 @@ use async_trait::async_trait;
use hmac_sha256::{MpcPrf, PrfOutput};
use ke::KeyExchange;
use key_exchange::{self as ke, MpcKeyExchange};
use ludi::Context as LudiContext;
use mpz_common::{Context, Flush};
use mpz_core::{bitvec::BitVec, Block};
use mpz_memory_core::DecodeFutureTyped;
@@ -47,9 +50,13 @@ use tlsn_core::{
};
use tracing::{debug, instrument, trace, warn};
/// Controller for MPC-TLS leader.
pub type LeaderCtrl = actor::MpcTlsLeaderCtrl;
/// MPC-TLS leader.
#[derive(Debug)]
pub struct MpcTlsLeader {
self_handle: Option<LeaderCtrl>,
config: Config,
state: State,
@@ -107,6 +114,7 @@ impl MpcTlsLeader {
let is_decrypting = !config.defer_decryption;
Self {
self_handle: None,
config,
state: State::Init {
ctx,
@@ -370,42 +378,18 @@ impl MpcTlsLeader {
Ok(())
}
/// Enables or disables the decryption of any incoming messages.
///
/// # Arguments
///
/// * `enable` - Whether to enable or disable decryption.
/// Defers decryption of any incoming messages.
#[instrument(level = "debug", skip_all, err)]
pub fn enable_decryption(&mut self, enable: bool) -> Result<(), MpcTlsError> {
self.is_decrypting = enable;
if enable {
self.notifier.set();
} else {
self.notifier.clear();
}
pub async fn defer_decryption(&mut self) -> Result<(), MpcTlsError> {
self.is_decrypting = false;
self.notifier.clear();
Ok(())
}
/// Returns if incoming messages are decrypted.
pub fn is_decrypting(&self) -> bool {
self.is_decrypting
}
/// Returns the context and transcript.
///
/// Should be called after a successful call to [`Backend::server_closed`].
pub fn finish(&mut self) -> Option<(Context, TlsTranscript)> {
match self.state.take() {
State::Closed {
ctx, transcript, ..
} => Some((ctx, transcript)),
state => {
self.state = state;
None
}
}
/// Stops the actor.
pub fn stop(&mut self, ctx: &mut LudiContext<Self>) {
ctx.stop();
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -16,7 +16,7 @@ pub(crate) mod utils;
pub use config::{Config, ConfigBuilder, ConfigBuilderError};
pub use error::MpcTlsError;
pub use follower::MpcTlsFollower;
pub use leader::MpcTlsLeader;
pub use leader::{LeaderCtrl, MpcTlsLeader};
use std::{future::Future, pin::Pin, sync::Arc};

View File

@@ -0,0 +1,160 @@
use std::sync::Arc;
use futures::{AsyncReadExt, AsyncWriteExt};
use mpc_tls::{Config, MpcTlsFollower, MpcTlsLeader};
use mpz_common::context::test_mt_context;
use mpz_core::Block;
use mpz_ideal_vm::IdealVm;
use mpz_memory_core::correlated::Delta;
use mpz_ot::{
ideal::rcot::ideal_rcot,
rcot::shared::{SharedRCOTReceiver, SharedRCOTSender},
};
use rand::{rngs::StdRng, SeedableRng};
use rustls_pki_types::CertificateDer;
use tls_client::RootCertStore;
use tls_client_async::bind_client;
use tls_server_fixture::{bind_test_server_hyper, CA_CERT_DER, SERVER_DOMAIN};
use tokio::sync::Mutex;
use tokio_util::compat::TokioAsyncReadCompatExt;
use webpki::anchor_from_trusted_cert;
const CA_CERT: CertificateDer = CertificateDer::from_slice(CA_CERT_DER);
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn mpc_tls_test() {
tracing_subscriber::fmt::init();
let config = Config::builder()
.defer_decryption(false)
.max_sent(1 << 13)
.max_recv_online(1 << 13)
.max_recv(1 << 13)
.build()
.unwrap();
let (leader, follower) = build_pair(config);
tokio::try_join!(
tokio::spawn(leader_task(leader)),
tokio::spawn(follower_task(follower))
)
.unwrap();
}
async fn leader_task(mut leader: MpcTlsLeader) {
leader.alloc().unwrap();
leader.preprocess().await.unwrap();
let (leader_ctrl, leader_fut) = leader.run();
tokio::spawn(async { leader_fut.await.unwrap() });
let config = tls_client::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(RootCertStore {
roots: vec![anchor_from_trusted_cert(&CA_CERT).unwrap().to_owned()],
})
.with_no_client_auth();
let server_name = SERVER_DOMAIN.try_into().unwrap();
let client = tls_client::ClientConnection::new(
Arc::new(config),
Box::new(leader_ctrl.clone()),
server_name,
)
.unwrap();
let (client_socket, server_socket) = tokio::io::duplex(1 << 16);
tokio::spawn(bind_test_server_hyper(server_socket.compat()));
let (mut conn, conn_fut) = bind_client(client_socket.compat(), client);
let handle = tokio::spawn(async { conn_fut.await.unwrap() });
let msg = concat!(
"POST /echo HTTP/1.1\r\n",
"Host: test-server.io\r\n",
"Connection: keep-alive\r\n",
"Accept-Encoding: identity\r\n",
"Content-Length: 5\r\n",
"\r\n",
"hello",
"\r\n"
);
conn.write_all(msg.as_bytes()).await.unwrap();
let mut buf = vec![0u8; 48];
conn.read_exact(&mut buf).await.unwrap();
leader_ctrl.defer_decryption().await.unwrap();
let msg = concat!(
"POST /echo HTTP/1.1\r\n",
"Host: test-server.io\r\n",
"Connection: close\r\n",
"Accept-Encoding: identity\r\n",
"Content-Length: 5\r\n",
"\r\n",
"hello",
"\r\n"
);
conn.write_all(msg.as_bytes()).await.unwrap();
conn.close().await.unwrap();
let mut buf = vec![0u8; 1024];
conn.read_to_end(&mut buf).await.unwrap();
leader_ctrl.stop().await.unwrap();
handle.await.unwrap();
}
async fn follower_task(mut follower: MpcTlsFollower) {
follower.alloc().unwrap();
follower.preprocess().await.unwrap();
follower.run().await.unwrap();
}
fn build_pair(config: Config) -> (MpcTlsLeader, MpcTlsFollower) {
let mut rng = StdRng::seed_from_u64(0);
let (mut mt_a, mut mt_b) = test_mt_context(8);
let ctx_a = futures::executor::block_on(mt_a.new_context()).unwrap();
let ctx_b = futures::executor::block_on(mt_b.new_context()).unwrap();
let delta_a = Delta::new(Block::random(&mut rng));
let delta_b = Delta::new(Block::random(&mut rng));
let (rcot_send_a, rcot_recv_b) = ideal_rcot(Block::random(&mut rng), delta_a.into_inner());
let (rcot_send_b, rcot_recv_a) = ideal_rcot(Block::random(&mut rng), delta_b.into_inner());
let rcot_send_a = SharedRCOTSender::new(rcot_send_a);
let rcot_send_b = SharedRCOTSender::new(rcot_send_b);
let rcot_recv_a = SharedRCOTReceiver::new(rcot_recv_a);
let rcot_recv_b = SharedRCOTReceiver::new(rcot_recv_b);
let mpc_a = Arc::new(Mutex::new(IdealVm::new()));
let mpc_b = Arc::new(Mutex::new(IdealVm::new()));
let leader = MpcTlsLeader::new(
config.clone(),
ctx_a,
mpc_a,
(rcot_send_a.clone(), rcot_send_a.clone(), rcot_send_a),
rcot_recv_a,
);
let follower = MpcTlsFollower::new(
config,
ctx_b,
mpc_b,
rcot_send_b,
(rcot_recv_b.clone(), rcot_recv_b.clone(), rcot_recv_b),
);
(leader, follower)
}

View File

@@ -0,0 +1,39 @@
[package]
name = "tlsn-tls-client-async"
authors = ["TLSNotary Team"]
description = "An async TLS client for TLSNotary"
keywords = ["tls", "mpc", "2pc", "client", "async"]
categories = ["cryptography"]
license = "MIT OR Apache-2.0"
version = "0.1.0-alpha.14-pre"
edition = "2021"
[lints]
workspace = true
[lib]
name = "tls_client_async"
[features]
default = ["tracing"]
tracing = ["dep:tracing"]
[dependencies]
tlsn-tls-client = { workspace = true }
bytes = { workspace = true }
futures = { workspace = true }
thiserror = { workspace = true }
tokio-util = { workspace = true, features = ["io", "compat"] }
tracing = { workspace = true, optional = true }
[dev-dependencies]
tls-server-fixture = { workspace = true }
http-body-util = { workspace = true }
hyper = { workspace = true, features = ["client", "http1"] }
hyper-util = { workspace = true, features = ["full"] }
rstest = { workspace = true }
tokio = { workspace = true, features = ["rt", "rt-multi-thread", "macros"] }
rustls-webpki = { workspace = true }
rustls-pki-types = { workspace = true }

View File

@@ -0,0 +1,89 @@
use bytes::Bytes;
use futures::{
channel::mpsc::{Receiver, SendError, Sender},
sink::SinkMapErr,
AsyncRead, AsyncWrite, SinkExt,
};
use std::{
io::{Error as IoError, ErrorKind as IoErrorKind},
pin::Pin,
task::{Context, Poll},
};
use tokio_util::{
compat::{Compat, TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt},
io::{CopyToBytes, SinkWriter, StreamReader},
};
type CompatSinkWriter =
Compat<SinkWriter<CopyToBytes<SinkMapErr<Sender<Bytes>, fn(SendError) -> IoError>>>>;
/// A TLS connection to a server.
///
/// This type implements `AsyncRead` and `AsyncWrite` and can be used to
/// communicate with a server using TLS.
///
/// # Note
///
/// This connection is closed on a best-effort basis if this is dropped. To
/// ensure a clean close, you should call
/// [`AsyncWriteExt::close`](futures::io::AsyncWriteExt::close) to close the
/// connection.
#[derive(Debug)]
pub struct TlsConnection {
/// The data to be transmitted to the server is sent to this sink.
tx_sender: CompatSinkWriter,
/// The data to be received from the server is received from this stream.
rx_receiver: Compat<StreamReader<Receiver<Result<Bytes, IoError>>, Bytes>>,
}
impl TlsConnection {
/// Creates a new TLS connection.
pub(crate) fn new(
tx_sender: Sender<Bytes>,
rx_receiver: Receiver<Result<Bytes, IoError>>,
) -> Self {
fn convert_error(err: SendError) -> IoError {
if err.is_disconnected() {
IoErrorKind::BrokenPipe.into()
} else {
IoErrorKind::WouldBlock.into()
}
}
Self {
tx_sender: SinkWriter::new(CopyToBytes::new(
tx_sender.sink_map_err(convert_error as fn(SendError) -> IoError),
))
.compat_write(),
rx_receiver: StreamReader::new(rx_receiver).compat(),
}
}
}
impl AsyncRead for TlsConnection {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<Result<usize, IoError>> {
Pin::new(&mut self.rx_receiver).poll_read(cx, buf)
}
}
impl AsyncWrite for TlsConnection {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, IoError>> {
Pin::new(&mut self.tx_sender).poll_write(cx, buf)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
Pin::new(&mut self.tx_sender).poll_flush(cx)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
Pin::new(&mut self.tx_sender).poll_close(cx)
}
}

View File

@@ -0,0 +1,269 @@
//! Provides a TLS client which exposes an async socket.
//!
//! This library provides the [bind_client] function which attaches a TLS client
//! to a socket connection and then exposes a [TlsConnection] object, which
//! provides an async socket API for reading and writing cleartext. The TLS
//! client will then automatically encrypt and decrypt traffic and forward that
//! to the provided socket.
#![deny(missing_docs, unreachable_pub, unused_must_use)]
#![deny(clippy::all)]
#![forbid(unsafe_code)]
mod conn;
use bytes::{Buf, Bytes};
use futures::{
channel::mpsc, future::Fuse, select_biased, stream::Next, AsyncRead, AsyncReadExt, AsyncWrite,
AsyncWriteExt, Future, FutureExt, SinkExt, StreamExt,
};
use std::{
pin::Pin,
task::{Context, Poll},
};
#[cfg(feature = "tracing")]
use tracing::{debug, debug_span, trace, warn, Instrument};
use tls_client::ClientConnection;
pub use conn::TlsConnection;
const RX_TLS_BUF_SIZE: usize = 1 << 13; // 8 KiB
const RX_BUF_SIZE: usize = 1 << 13; // 8 KiB
/// An error that can occur during a TLS connection.
#[allow(missing_docs)]
#[derive(Debug, thiserror::Error)]
pub enum ConnectionError {
#[error(transparent)]
TlsError(#[from] tls_client::Error),
#[error(transparent)]
IOError(#[from] std::io::Error),
}
/// Closed connection data.
#[derive(Debug)]
pub struct ClosedConnection {
/// The connection for the client
pub client: ClientConnection,
/// Sent plaintext bytes
pub sent: Vec<u8>,
/// Received plaintext bytes
pub recv: Vec<u8>,
}
/// A future which runs the TLS connection to completion.
///
/// This future must be polled in order for the connection to make progress.
#[must_use = "futures do nothing unless polled"]
pub struct ConnectionFuture {
fut: Pin<Box<dyn Future<Output = Result<ClosedConnection, ConnectionError>> + Send>>,
}
impl Future for ConnectionFuture {
type Output = Result<ClosedConnection, ConnectionError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.fut.poll_unpin(cx)
}
}
/// Binds a client connection to the provided socket.
///
/// Returns a connection handle and a future which runs the connection to
/// completion.
///
/// # Errors
///
/// Any connection errors that occur will be returned from the future, not
/// [`TlsConnection`].
pub fn bind_client<T: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
socket: T,
mut client: ClientConnection,
) -> (TlsConnection, ConnectionFuture) {
let (tx_sender, mut tx_receiver) = mpsc::channel(1 << 14);
let (mut rx_sender, rx_receiver) = mpsc::channel(1 << 14);
let conn = TlsConnection::new(tx_sender, rx_receiver);
let fut = async move {
client.start().await?;
let mut notify = client.get_notify().await?;
let (mut server_rx, mut server_tx) = socket.split();
let mut rx_tls_buf = [0u8; RX_TLS_BUF_SIZE];
let mut rx_buf = [0u8; RX_BUF_SIZE];
let mut handshake_done = false;
let mut client_closed = false;
let mut server_closed = false;
let mut sent = Vec::with_capacity(1024);
let mut recv = Vec::with_capacity(1024);
let mut rx_tls_fut = server_rx.read(&mut rx_tls_buf).fuse();
// We don't start writing application data until the handshake is complete.
let mut tx_recv_fut: Fuse<Next<'_, mpsc::Receiver<Bytes>>> = Fuse::terminated();
// Runs both the tx and rx halves of the connection to completion.
// This loop does not terminate until the *SERVER* closes the connection and
// we've processed all received data. If an error occurs, the `TlsConnection`
// channels will be closed and the error will be returned from this future.
'conn: loop {
// Write all pending TLS data to the server.
if client.wants_write() && !client_closed {
#[cfg(feature = "tracing")]
trace!("client wants to write");
while client.wants_write() {
let _sent = client.write_tls_async(&mut server_tx).await?;
#[cfg(feature = "tracing")]
trace!("sent {} tls bytes to server", _sent);
}
server_tx.flush().await?;
}
// Forward received plaintext to `TlsConnection`.
while !client.plaintext_is_empty() {
let read = client.read_plaintext(&mut rx_buf)?;
recv.extend(&rx_buf[..read]);
// Ignore if the receiver has hung up.
_ = rx_sender
.send(Ok(Bytes::copy_from_slice(&rx_buf[..read])))
.await;
#[cfg(feature = "tracing")]
trace!("forwarded {} plaintext bytes to conn", read);
}
if !client.is_handshaking() && !handshake_done {
#[cfg(feature = "tracing")]
debug!("handshake complete");
handshake_done = true;
// Start reading application data that needs to be transmitted from the
// `TlsConnection`.
tx_recv_fut = tx_receiver.next().fuse();
}
if server_closed && client.plaintext_is_empty() && client.is_empty().await? {
break 'conn;
}
select_biased! {
// Reads TLS data from the server and writes it into the client.
received = &mut rx_tls_fut => {
let received = received?;
#[cfg(feature = "tracing")]
trace!("received {} tls bytes from server", received);
// Loop until we've processed all the data we received in this read.
// Note that we must make one iteration even if `received == 0`.
let mut processed = 0;
let mut reader = rx_tls_buf[..received].reader();
loop {
processed += client.read_tls(&mut reader)?;
client.process_new_packets().await?;
debug_assert!(processed <= received);
if processed >= received {
break;
}
}
#[cfg(feature = "tracing")]
trace!("processed {} tls bytes from server", processed);
// By convention if `AsyncRead::read` returns 0, it means EOF, i.e. the peer
// has closed the socket.
if received == 0 {
#[cfg(feature = "tracing")]
debug!("server closed connection");
server_closed = true;
client.server_closed().await?;
// Do not read from the socket again.
rx_tls_fut = Fuse::terminated();
} else {
// Reset the read future so next iteration we can read again.
rx_tls_fut = server_rx.read(&mut rx_tls_buf).fuse();
}
}
// If we receive None from `TlsConnection`, it has closed, so we
// send a close_notify to the server.
data = &mut tx_recv_fut => {
if let Some(data) = data {
#[cfg(feature = "tracing")]
trace!("writing {} plaintext bytes to client", data.len());
sent.extend(&data);
client
.write_all_plaintext(&data)
.await?;
tx_recv_fut = tx_receiver.next().fuse();
} else {
if !server_closed {
if let Err(e) = send_close_notify(&mut client, &mut server_tx).await {
#[cfg(feature = "tracing")]
warn!("failed to send close_notify to server: {}", e);
}
}
client_closed = true;
tx_recv_fut = Fuse::terminated();
}
}
// Waits for a notification from the backend that it is ready to decrypt data.
_ = &mut notify => {
#[cfg(feature = "tracing")]
trace!("backend is ready to decrypt");
client.process_new_packets().await?;
}
}
}
#[cfg(feature = "tracing")]
debug!("client shutdown");
_ = server_tx.close().await;
tx_receiver.close();
rx_sender.close_channel();
#[cfg(feature = "tracing")]
trace!(
"server close notify: {}, sent: {}, recv: {}",
client.received_close_notify(),
sent.len(),
recv.len()
);
Ok(ClosedConnection { client, sent, recv })
};
#[cfg(feature = "tracing")]
let fut = fut.instrument(debug_span!("tls_connection"));
let fut = ConnectionFuture { fut: Box::pin(fut) };
(conn, fut)
}
async fn send_close_notify(
client: &mut ClientConnection,
server_tx: &mut (impl AsyncWrite + Unpin),
) -> Result<(), ConnectionError> {
#[cfg(feature = "tracing")]
trace!("sending close_notify to server");
client.send_close_notify().await?;
client.process_new_packets().await?;
// Flush all remaining plaintext
while client.wants_write() {
client.write_tls_async(server_tx).await?;
}
server_tx.flush().await?;
Ok(())
}

View File

@@ -0,0 +1,438 @@
use std::{str, sync::Arc};
use core::future::Future;
use futures::{AsyncReadExt, AsyncWriteExt};
use http_body_util::{BodyExt as _, Full};
use hyper::{body::Bytes, Request, StatusCode};
use hyper_util::rt::TokioIo;
use rstest::{fixture, rstest};
use rustls_pki_types::CertificateDer;
use tls_client::{ClientConfig, ClientConnection, RustCryptoBackend, ServerName};
use tls_client_async::{bind_client, ClosedConnection, ConnectionError, TlsConnection};
use tls_server_fixture::{
bind_test_server, bind_test_server_hyper, APP_RECORD_LENGTH, CA_CERT_DER, CLOSE_DELAY,
SERVER_DOMAIN,
};
use tokio::task::JoinHandle;
use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
use webpki::anchor_from_trusted_cert;
const CA_CERT: CertificateDer = CertificateDer::from_slice(CA_CERT_DER);
// An established client TLS connection
struct TlsFixture {
client_tls_conn: TlsConnection,
// a handle that must be `.await`ed to get the result of a TLS connection
closed_tls_task: JoinHandle<Result<ClosedConnection, ConnectionError>>,
}
// Sets up a TLS connection between client and server and sends a hello message
#[fixture]
async fn set_up_tls() -> TlsFixture {
let (client_socket, server_socket) = tokio::io::duplex(1 << 16);
let _server_task = tokio::spawn(bind_test_server(server_socket.compat()));
let mut root_store = tls_client::RootCertStore::empty();
root_store
.roots
.push(anchor_from_trusted_cert(&CA_CERT).unwrap().to_owned());
let config = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_no_client_auth();
let client = ClientConnection::new(
Arc::new(config),
Box::new(RustCryptoBackend::new()),
ServerName::try_from(SERVER_DOMAIN).unwrap(),
)
.unwrap();
let (mut client_tls_conn, tls_fut) = bind_client(client_socket.compat(), client);
let closed_tls_task = tokio::spawn(tls_fut);
client_tls_conn
.write_all(&pad("expecting you to send back hello".to_string()))
.await
.unwrap();
// give the server some time to respond
std::thread::sleep(std::time::Duration::from_millis(10));
let mut plaintext = vec![0u8; 320];
let n = client_tls_conn.read(&mut plaintext).await.unwrap();
let s = str::from_utf8(&plaintext[0..n]).unwrap();
assert_eq!(s, "hello");
TlsFixture {
client_tls_conn,
closed_tls_task,
}
}
// Expect the async tls client wrapped in `hyper::client` to make a successful
// request and receive the expected response
#[tokio::test]
async fn test_hyper_ok() {
let (client_socket, server_socket) = tokio::io::duplex(1 << 16);
let server_task = tokio::spawn(bind_test_server_hyper(server_socket.compat()));
let mut root_store = tls_client::RootCertStore::empty();
root_store
.roots
.push(anchor_from_trusted_cert(&CA_CERT).unwrap().to_owned());
let config = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_no_client_auth();
let client = ClientConnection::new(
Arc::new(config),
Box::new(RustCryptoBackend::new()),
ServerName::try_from(SERVER_DOMAIN).unwrap(),
)
.unwrap();
let (conn, tls_fut) = bind_client(client_socket.compat(), client);
let closed_tls_task = tokio::spawn(tls_fut);
let (mut request_sender, connection) =
hyper::client::conn::http1::handshake(TokioIo::new(conn.compat()))
.await
.unwrap();
tokio::spawn(connection);
let request = Request::builder()
.uri(format!("https://{SERVER_DOMAIN}/echo"))
.header("Host", SERVER_DOMAIN)
.header("Connection", "close")
.method("POST")
.body(Full::<Bytes>::new("hello".into()))
.unwrap();
let response = request_sender.send_request(request).await.unwrap();
assert!(response.status() == StatusCode::OK);
// Process the response body
response.into_body().collect().await.unwrap().to_bytes();
let _ = server_task.await.unwrap();
let closed_conn = closed_tls_task.await.unwrap().unwrap();
assert!(closed_conn.client.received_close_notify());
}
// Expect a clean TLS connection closure when server responds to the client's
// close_notify but doesn't close the socket
#[rstest]
#[tokio::test]
async fn test_ok_server_no_socket_close(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
// instruct the server to send close_notify back to us after 10 ms
client_tls_conn
.write_all(&pad("send_close_notify".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// closing `client_tls_conn` will cause close_notify to be sent by the client;
client_tls_conn.close().await.unwrap();
let closed_conn = closed_tls_task.await.unwrap().unwrap();
assert!(closed_conn.client.received_close_notify());
}
// Expect a clean TLS connection closure when server responds to the client's
// close_notify AND also closes the socket
#[rstest]
#[tokio::test]
async fn test_ok_server_socket_close(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
// instruct the server to send close_notify back to us AND close the socket
// after 10 ms
client_tls_conn
.write_all(&pad("send_close_notify_and_close_socket".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// closing `client_tls_conn` will cause close_notify to be sent by the client;
client_tls_conn.close().await.unwrap();
let closed_conn = closed_tls_task.await.unwrap().unwrap();
assert!(closed_conn.client.received_close_notify());
}
// Expect a clean TLS connection closure when server sends close_notify first
// but doesn't close the socket
#[rstest]
#[tokio::test]
async fn test_ok_server_close_notify(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
// instruct the server to send close_notify back to us after 10 ms
client_tls_conn
.write_all(&pad("send_close_notify".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// give enough time for server's close_notify to arrive
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
client_tls_conn.close().await.unwrap();
let closed_conn = closed_tls_task.await.unwrap().unwrap();
assert!(closed_conn.client.received_close_notify());
}
// Expect a clean TLS connection closure when server sends close_notify first
// AND also closes the socket
#[rstest]
#[tokio::test]
async fn test_ok_server_close_notify_and_socket_close(
set_up_tls: impl Future<Output = TlsFixture>,
) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
// instruct the server to send close_notify back to us after 10 ms
client_tls_conn
.write_all(&pad("send_close_notify_and_close_socket".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// give enough time for server's close_notify to arrive
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
client_tls_conn.close().await.unwrap();
let closed_conn = closed_tls_task.await.unwrap().unwrap();
assert!(closed_conn.client.received_close_notify());
}
// Expect to be able to read the data after server closes the socket abruptly
#[rstest]
#[tokio::test]
async fn test_ok_read_after_close(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
..
} = set_up_tls.await;
// instruct the server to send us a hello message
client_tls_conn
.write_all(&pad("send a hello message".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// instruct the server to close the socket
client_tls_conn
.write_all(&pad("close_socket".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// give enough time to close the socket
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
// try to read some more data
let mut buf = vec![0u8; 10];
let n = client_tls_conn.read(&mut buf).await.unwrap();
assert_eq!(std::str::from_utf8(&buf[0..n]).unwrap(), "hello");
}
// Expect there to be no error when server DOES NOT send close_notify but just
// closes the socket
#[rstest]
#[tokio::test]
async fn test_ok_server_no_close_notify(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
// instruct the server to close the socket
client_tls_conn
.write_all(&pad("close_socket".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// give enough time to close the socket
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
client_tls_conn.close().await.unwrap();
let closed_conn = closed_tls_task.await.unwrap().unwrap();
assert!(!closed_conn.client.received_close_notify());
}
// Expect to register a delay when the server delays closing the socket
#[rstest]
#[tokio::test]
async fn test_ok_delay_close(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
client_tls_conn
.write_all(&pad("must_delay_when_closing".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// closing `client_tls_conn` will cause close_notify to be sent by the client
client_tls_conn.close().await.unwrap();
use std::time::Instant;
let now = Instant::now();
// this will resolve when the server stops delaying closing the socket
let closed_conn = closed_tls_task.await.unwrap().unwrap();
let elapsed = now.elapsed();
// the elapsed time must be roughly equal to the server's delay
// (give or take timing variations)
assert!(elapsed.as_millis() as u64 > CLOSE_DELAY - 50);
assert!(!closed_conn.client.received_close_notify());
}
// Expect client to error when server sends a corrupted message
#[rstest]
#[tokio::test]
async fn test_err_corrupted(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
// instruct the server to send a corrupted message
client_tls_conn
.write_all(&pad("send_corrupted_message".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
client_tls_conn.close().await.unwrap();
assert_eq!(
closed_tls_task.await.unwrap().err().unwrap().to_string(),
"received corrupt message"
);
}
// Expect client to error when server sends a TLS record with a bad MAC
#[rstest]
#[tokio::test]
async fn test_err_bad_mac(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
// instruct the server to send us a TLS record with a bad MAC
client_tls_conn
.write_all(&pad("send_record_with_bad_mac".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
client_tls_conn.close().await.unwrap();
assert_eq!(
closed_tls_task.await.unwrap().err().unwrap().to_string(),
"backend error: Decryption error: \"aead::Error\""
);
}
// Expect client to error when server sends a fatal alert
#[rstest]
#[tokio::test]
async fn test_err_alert(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
// instruct the server to send us a TLS record with a bad MAC
client_tls_conn
.write_all(&pad("send_alert".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
client_tls_conn.close().await.unwrap();
assert_eq!(
closed_tls_task.await.unwrap().err().unwrap().to_string(),
"received fatal alert: BadRecordMac"
);
}
// Expect an error when trying to write data to a connection which server closed
// abruptly
#[rstest]
#[tokio::test]
async fn test_err_write_after_close(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
..
} = set_up_tls.await;
// instruct the server to close the socket
client_tls_conn
.write_all(&pad("close_socket".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// give enough time to close the socket
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
// try to send some more data
let res = client_tls_conn
.write_all(&pad("more data".to_string()))
.await;
assert_eq!(res.err().unwrap().kind(), std::io::ErrorKind::BrokenPipe);
}
// Converts a string into a slice zero-padded to APP_RECORD_LENGTH
fn pad(s: String) -> Vec<u8> {
assert!(s.len() <= APP_RECORD_LENGTH);
let mut buf = vec![0u8; APP_RECORD_LENGTH];
buf[..s.len()].copy_from_slice(s.as_bytes());
buf
}

View File

@@ -227,7 +227,6 @@ impl ConnectionCommon {
/// Signals that the server has closed the connection.
pub async fn server_closed(&mut self) -> Result<(), Error> {
self.common_state.has_seen_eof = true;
self.common_state.backend.server_closed().await?;
Ok(())
}
@@ -458,9 +457,6 @@ impl ConnectionCommon {
return Err(Error::CorruptMessage);
}
// Process outgoing plaintext buffer and encrypt messages.
self.flush_plaintext().await?;
// Process new messages.
while let Some(msg) = self.message_deframer.frames.pop_front() {
// If we're not decrypting yet, we process it immediately. Otherwise it will be
@@ -512,22 +508,25 @@ impl ConnectionCommon {
Ok(state)
}
/// Writes plaintext `buf` into an internal buffer. May not fully process the
/// whole buffer and returns the processed length.
pub fn write_plaintext(&mut self, buf: &[u8]) -> Result<usize, Error> {
if buf.is_empty() {
// Don't send empty fragments.
return Ok(0);
/// Write buffer into connection.
pub async fn write_plaintext(&mut self, buf: &[u8]) -> Result<usize, Error> {
if let Ok(st) = &mut self.state {
st.perhaps_write_key_update(&mut self.common_state).await;
}
let len = self.sendable_plaintext.append_limited_copy(buf);
Ok(len)
self.common_state.send_some_plaintext(buf).await
}
/// Writes the entire plaintext `buf` into an internal buffer.
pub fn write_all_plaintext(&mut self, buf: &[u8]) -> Result<(), Error> {
self.sendable_plaintext.append(buf.to_vec());
Ok(())
/// Write entire buffer into connection.
pub async fn write_all_plaintext(&mut self, buf: &[u8]) -> Result<usize, Error> {
let mut pos = 0;
while pos < buf.len() {
pos += self.write_plaintext(&buf[pos..]).await?;
}
self.backend.flush().await?;
while let Some(msg) = self.backend.next_outgoing().await? {
self.queue_tls_message(msg);
}
Ok(pos)
}
/// Read TLS content from `rd`. This method does internal
@@ -691,11 +690,6 @@ impl CommonState {
self.received_plaintext.is_empty()
}
/// Returns true if the buffer for sendable plaintext is full.
pub fn sendable_plaintext_is_full(&self) -> bool {
self.sendable_plaintext.is_full()
}
/// Returns true if the connection is currently performing the TLS
/// handshake.
///
@@ -788,6 +782,15 @@ impl CommonState {
}
}
/// Send plaintext application data, fragmenting and
/// encrypting it as it goes out.
///
/// If internal buffers are too small, this function will not accept
/// all the data.
pub(crate) async fn send_some_plaintext(&mut self, data: &[u8]) -> Result<usize, Error> {
self.send_plain(data, Limit::Yes).await
}
// Changing the keys must not span any fragmented handshake
// messages. Otherwise the defragmented messages will have
// been protected with two different record layer protections,
@@ -928,6 +931,32 @@ impl CommonState {
self.sendable_tls.write_to_async(wr).await
}
/// Encrypt and send some plaintext `data`. `limit` controls
/// whether the per-connection buffer limits apply.
///
/// Returns the number of bytes written from `data`: this might
/// be less than `data.len()` if buffer limits were exceeded.
async fn send_plain(&mut self, data: &[u8], limit: Limit) -> Result<usize, Error> {
if !self.may_send_application_data {
// If we haven't completed handshaking, buffer
// plaintext to send once we do.
let len = match limit {
Limit::Yes => self.sendable_plaintext.append_limited_copy(data),
Limit::No => self.sendable_plaintext.append(data.to_vec()),
};
return Ok(len);
}
debug_assert!(self.record_layer.is_encrypting());
if data.is_empty() {
// Don't send empty fragments.
return Ok(0);
}
self.send_appdata_encrypt(data, limit).await
}
pub(crate) async fn start_outgoing_traffic(&mut self) -> Result<(), Error> {
self.may_send_application_data = true;
self.flush_plaintext().await
@@ -983,14 +1012,15 @@ impl CommonState {
self.sendable_tls.set_limit(limit);
}
/// Send and encrypt any buffered plaintext. Does nothing during handshake.
pub async fn flush_plaintext(&mut self) -> Result<(), Error> {
/// Send any buffered plaintext. Plaintext is buffered if
/// written during handshake.
async fn flush_plaintext(&mut self) -> Result<(), Error> {
if !self.may_send_application_data {
return Ok(());
}
while let Some(buf) = self.sendable_plaintext.pop() {
self.send_appdata_encrypt(&buf, Limit::No).await?;
self.send_plain(&buf, Limit::No).await?;
}
Ok(())

View File

@@ -35,15 +35,6 @@ impl ChunkVecBuffer {
self.chunks.is_empty()
}
/// If the buffer has reached limit.
pub(crate) fn is_full(&self) -> bool {
if let Some(limit) = self.limit {
self.len() >= limit
} else {
false
}
}
/// How many bytes we're storing
pub(crate) fn len(&self) -> usize {
let mut len = 0;

View File

@@ -247,8 +247,7 @@ async fn servered_client_data_sent() {
let (mut client, mut server) =
make_pair_for_arc_configs(&Arc::new(client_config), &server_config).await;
assert_eq!(5, client.write_plaintext(b"hello").unwrap());
client.flush_plaintext().await.unwrap();
assert_eq!(5, client.write_plaintext(b"hello").await.unwrap());
do_handshake(&mut client, &mut server).await;
send(&mut client, &mut server);
@@ -287,7 +286,7 @@ async fn servered_both_data_sent() {
make_pair_for_arc_configs(&Arc::new(client_config), &server_config).await;
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
do_handshake(&mut client, &mut server).await;
@@ -433,7 +432,7 @@ async fn server_close_notify() {
// check that alerts don't overtake appdata
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
server.send_close_notify();
receive(&mut server, &mut client);
@@ -461,8 +460,7 @@ async fn client_close_notify() {
// check that alerts don't overtake appdata
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
client.flush_plaintext().await.unwrap();
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
client.send_close_notify().await.unwrap();
send(&mut client, &mut server);
@@ -489,7 +487,7 @@ async fn server_closes_uncleanly() {
// check that unclean EOF reporting does not overtake appdata
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
receive(&mut server, &mut client);
transfer_eof(&mut client);
@@ -520,7 +518,7 @@ async fn client_closes_uncleanly() {
// check that unclean EOF reporting does not overtake appdata
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
client.process_new_packets().await.unwrap();
send(&mut client, &mut server);
@@ -902,9 +900,20 @@ async fn client_respects_buffer_limit_pre_handshake() {
client.set_buffer_limit(Some(32));
assert_eq!(client.write_plaintext(b"01234567890123456789").unwrap(), 20);
assert_eq!(client.write_plaintext(b"01234567890123456789").unwrap(), 12);
client.flush_plaintext().await.unwrap();
assert_eq!(
client
.write_plaintext(b"01234567890123456789")
.await
.unwrap(),
20
);
assert_eq!(
client
.write_plaintext(b"01234567890123456789")
.await
.unwrap(),
12
);
do_handshake(&mut client, &mut server).await;
send(&mut client, &mut server);
@@ -944,9 +953,20 @@ async fn client_respects_buffer_limit_post_handshake() {
do_handshake(&mut client, &mut server).await;
client.set_buffer_limit(Some(48));
assert_eq!(client.write_plaintext(b"01234567890123456789").unwrap(), 20);
assert_eq!(client.write_plaintext(b"01234567890123456789").unwrap(), 6);
client.flush_plaintext().await.unwrap();
assert_eq!(
client
.write_plaintext(b"01234567890123456789")
.await
.unwrap(),
20
);
assert_eq!(
client
.write_plaintext(b"01234567890123456789")
.await
.unwrap(),
6
);
send(&mut client, &mut server);
server.process_new_packets().unwrap();
@@ -1191,8 +1211,14 @@ async fn client_complete_io_for_write() {
do_handshake(&mut client, &mut server).await;
client.write_plaintext(b"01234567890123456789").unwrap();
client.write_plaintext(b"01234567890123456789").unwrap();
client
.write_plaintext(b"01234567890123456789")
.await
.unwrap();
client
.write_plaintext(b"01234567890123456789")
.await
.unwrap();
{
let mut pipe = ServerSession::new(&mut server);
let (rdlen, wrlen) = client
@@ -1324,8 +1350,7 @@ async fn server_stream_read() {
for kt in ALL_KEY_TYPES.iter() {
let (mut client, mut server) = make_pair(*kt).await;
client.write_all_plaintext(b"world").unwrap();
client.process_new_packets().await.unwrap();
client.write_all_plaintext(b"world").await.unwrap();
{
let mut pipe = ClientSession::new(&mut client);
@@ -1341,8 +1366,7 @@ async fn server_streamowned_read() {
for kt in ALL_KEY_TYPES.iter() {
let (mut client, server) = make_pair(*kt).await;
client.write_all_plaintext(b"world").unwrap();
client.process_new_packets().await.unwrap();
client.write_all_plaintext(b"world").await.unwrap();
{
let pipe = ClientSession::new(&mut client);
@@ -1361,9 +1385,7 @@ async fn server_streamowned_read() {
// errkind: io::ErrorKind::ConnectionAborted,
// after: 0,
// };
// client.write_all_plaintext(b"hello").unwrap();
// client.process_new_packets().await.unwrap();
//
// client.write_all_plaintext(b"hello").await.unwrap();
// let mut client_stream = Stream::new(&mut client, &mut pipe);
// let rc = client_stream.write(b"world");
// assert!(rc.is_err());
@@ -1380,9 +1402,7 @@ async fn server_streamowned_read() {
// errkind: io::ErrorKind::ConnectionAborted,
// after: 1,
// };
// client.write_all_plaintext(b"hello").unwrap();
// client.process_new_packets().await.unwrap();
//
// client.write_all_plaintext(b"hello").await.unwrap();
// let mut client_stream = Stream::new(&mut client, &mut pipe);
// let rc = client_stream.write(b"world");
// assert_eq!(format!("{:?}", rc), "Ok(5)");
@@ -1880,9 +1900,14 @@ async fn servered_write_for_client_appdata() {
let (mut client, mut server) = make_pair(KeyType::Rsa).await;
do_handshake(&mut client, &mut server).await;
client.write_all_plaintext(b"01234567890123456789").unwrap();
client.write_all_plaintext(b"01234567890123456789").unwrap();
client.process_new_packets().await.unwrap();
client
.write_all_plaintext(b"01234567890123456789")
.await
.unwrap();
client
.write_all_plaintext(b"01234567890123456789")
.await
.unwrap();
{
let mut pipe = ServerSession::new(&mut server);
let wrlen = client.write_tls(&mut pipe).unwrap();
@@ -1994,10 +2019,11 @@ async fn servered_write_for_server_handshake_no_half_rtt_by_default() {
async fn servered_write_for_client_handshake() {
let (mut client, mut server) = make_pair(KeyType::Rsa).await;
client.write_all_plaintext(b"01234567890123456789").unwrap();
client.write_all_plaintext(b"0123456789").unwrap();
client.process_new_packets().await.unwrap();
client
.write_all_plaintext(b"01234567890123456789")
.await
.unwrap();
client.write_all_plaintext(b"0123456789").await.unwrap();
{
let mut pipe = ServerSession::new(&mut server);
let wrlen = client.write_tls(&mut pipe).unwrap();

View File

@@ -21,11 +21,11 @@ tlsn-attestation = { workspace = true }
tlsn-core = { workspace = true }
tlsn-deap = { workspace = true }
tlsn-tls-client = { workspace = true }
tlsn-tls-client-async = { workspace = true }
tlsn-tls-core = { workspace = true }
tlsn-mpc-tls = { workspace = true }
tlsn-cipher = { workspace = true }
futures-plex = { workspace = true }
serio = { workspace = true, features = ["compat"] }
uid-mux = { workspace = true, features = ["serio"] }
web-spawn = { workspace = true, optional = true }

View File

@@ -22,9 +22,6 @@ use std::sync::LazyLock;
use semver::Version;
// Size for internal buffers.
const BUF_CAP: usize = 8 * 1024;
// Package version.
pub(crate) static VERSION: LazyLock<Version> = LazyLock::new(|| {
Version::parse(env!("CARGO_PKG_VERSION")).expect("cargo pkg version should be a valid semver")

View File

@@ -1,7 +1,7 @@
use std::ops::Range;
use mpz_memory_core::{Vector, binary::U8};
use rangeset::RangeSet;
use rangeset::set::RangeSet;
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct RangeMap<T> {
@@ -77,7 +77,7 @@ where
pub(crate) fn index(&self, idx: &RangeSet<usize>) -> Option<Self> {
let mut map = Vec::new();
for idx in idx.iter_ranges() {
for idx in idx.iter() {
let pos = match self.map.binary_search_by(|(base, _)| base.cmp(&idx.start)) {
Ok(i) => i,
Err(0) => return None,

View File

@@ -1,31 +1,31 @@
//! Prover.
mod client;
mod error;
mod future;
mod prove;
pub mod state;
pub use error::ProverError;
pub use future::ProverFuture;
pub use tlsn_core::ProverOutput;
use crate::{
BUF_CAP, Role,
Role,
context::build_mt_context,
mpz::{ProverDeps, build_prover_deps, translate_keys},
msg::{ProveRequestMsg, Response, TlsCommitRequestMsg},
mux::attach_mux,
prover::client::{MpcTlsClient, TlsOutput},
tag::verify_tags,
};
use futures::{FutureExt, TryFutureExt};
use futures::{AsyncRead, AsyncWrite, TryFutureExt};
use mpc_tls::LeaderCtrl;
use mpz_vm_core::prelude::*;
use rustls_pki_types::CertificateDer;
use serio::{SinkExt, stream::IoStreamExt};
use std::{
io::{Read, Write},
sync::Arc,
task::{Context, Poll},
};
use std::sync::Arc;
use tls_client::{ClientConnection, ServerName as TlsServerName};
use tls_client_async::{TlsConnection, bind_client};
use tlsn_core::{
config::{
prove::ProveConfig,
@@ -36,9 +36,10 @@ use tlsn_core::{
connection::{HandshakeData, ServerName},
transcript::{TlsTranscript, Transcript},
};
use tracing::{Span, debug, info_span, instrument};
use webpki::anchor_from_trusted_cert;
use tracing::{Instrument, Span, debug, info, info_span, instrument};
/// A prover instance.
#[derive(Debug)]
pub struct Prover<T: state::ProverState = state::Initialized> {
@@ -70,16 +71,15 @@ impl Prover<state::Initialized> {
/// # Arguments
///
/// * `config` - The TLS commitment configuration.
/// * `socket` - The socket to the TLS verifier.
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
pub async fn commit(
pub async fn commit<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
self,
config: TlsCommitConfig,
socket: S,
) -> Result<Prover<state::CommitAccepted>, ProverError> {
let (duplex_a, duplex_b) = futures_plex::duplex(BUF_CAP);
let (mut mux_fut, mux_ctrl) = attach_mux(duplex_b, Role::Prover);
let (mut mux_fut, mux_ctrl) = attach_mux(socket, Role::Prover);
let mut mt = build_mt_context(mux_ctrl.clone());
let mut ctx = mux_fut.poll_with(mt.new_context()).await?;
// Sends protocol configuration to verifier for compatibility check.
@@ -118,48 +118,47 @@ impl Prover<state::Initialized> {
debug!("mpc-tls setup complete");
let prover = Prover {
Ok(Prover {
config: self.config,
span: self.span,
state: state::CommitAccepted {
mpc_duplex: duplex_a,
mux_ctrl,
mux_fut,
mpc_tls,
keys,
vm,
},
};
Ok(prover)
})
}
}
impl Prover<state::CommitAccepted> {
/// Connects the prover.
/// Connects to the server using the provided socket.
///
/// Returns a connected prover, which can be used to read and write from/to
/// the active TLS connection.
/// Returns a handle to the TLS connection, a future which returns the
/// prover once the connection is closed and the TLS transcript is
/// committed.
///
/// # Arguments
///
/// * `config` - The TLS client configuration.
/// * `socket` - The socket to the server.
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
pub async fn connect(
pub async fn connect<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
self,
config: TlsClientConfig,
) -> Result<Prover<state::Connected>, ProverError> {
socket: S,
) -> Result<(TlsConnection, ProverFuture), ProverError> {
let state::CommitAccepted {
mpc_duplex,
mux_ctrl,
mux_fut,
mut mux_fut,
mpc_tls,
keys,
vm,
..
} = self.state;
let decrypt = mpc_tls.is_decrypting();
let (mpc_ctrl, mpc_fut) = mpc_tls.run();
let ServerName::Dns(server_name) = config.server_name();
let server_name =
@@ -196,194 +195,102 @@ impl Prover<state::CommitAccepted> {
rustls_config.with_no_client_auth()
};
let client = ClientConnection::new(Arc::new(rustls_config), Box::new(mpc_tls), server_name)
.map_err(ProverError::config)?;
let client = ClientConnection::new(
Arc::new(rustls_config),
Box::new(mpc_ctrl.clone()),
server_name,
)
.map_err(ProverError::config)?;
let span = self.span.clone();
let (conn, conn_fut) = bind_client(socket, client);
let mpc_tls = MpcTlsClient::new(keys, vm, span, client, decrypt);
let fut = Box::pin({
let span = self.span.clone();
let mpc_ctrl = mpc_ctrl.clone();
async move {
let conn_fut = async {
mux_fut
.poll_with(conn_fut.map_err(ProverError::from))
.await?;
let prover = Prover {
config: self.config,
span: self.span,
state: state::Connected {
mpc_duplex,
mux_ctrl,
mux_fut,
server_name: config.server_name().clone(),
tls_client: Box::new(mpc_tls),
output: None,
},
};
Ok(prover)
}
mpc_ctrl.stop().await?;
/// Writes bytes for the verifier into a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn write_mpc(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.read(buf)
}
Ok::<_, ProverError>(())
};
/// Reads bytes for the prover from a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn read_mpc(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.write(buf)
}
}
info!("starting MPC-TLS");
impl Prover<state::Connected> {
/// Returns `true` if the prover wants to read TLS data from the server.
pub fn wants_read_tls(&self) -> bool {
self.state.tls_client.wants_read_tls()
}
let (_, (mut ctx, tls_transcript)) = futures::try_join!(
conn_fut,
mpc_fut.in_current_span().map_err(ProverError::from)
)?;
/// Returns `true` if the prover wants to write TLS data to the server.
pub fn wants_write_tls(&self) -> bool {
self.state.tls_client.wants_write_tls()
}
info!("finished MPC-TLS");
/// Reads TLS data from the server.
///
/// # Arguments
///
/// * `buf` - The buffer to read the TLS data from.
pub fn read_tls(&mut self, buf: &[u8]) -> Result<usize, ProverError> {
self.state.tls_client.read_tls(buf)
}
{
let mut vm = vm.try_lock().expect("VM should not be locked");
/// Writes TLS data for the server into the provided buffer.
///
/// # Arguments
///
/// * `buf` - The buffer to write the TLS data to.
pub fn write_tls(&mut self, buf: &mut [u8]) -> Result<usize, ProverError> {
self.state.tls_client.write_tls(buf)
}
debug!("finalizing mpc");
/// Returns `true` if the prover wants to read plaintext data.
pub fn wants_read(&self) -> bool {
self.state.tls_client.wants_read()
}
// Finalize DEAP.
mux_fut
.poll_with(vm.finalize(&mut ctx))
.await
.map_err(ProverError::mpc)?;
/// Returns `true` if the prover wants to write plaintext data.
pub fn wants_write(&self) -> bool {
self.state.tls_client.wants_write()
}
debug!("mpc finalized");
}
/// Reads plaintext data from the server into the provided buffer.
///
/// # Arguments
///
/// * `buf` - The buffer where the plaintext data gets written to.
pub fn read(&mut self, buf: &mut [u8]) -> Result<usize, ProverError> {
self.state.tls_client.read(buf)
}
// Pull out ZK VM.
let (_, mut vm) = Arc::into_inner(vm)
.expect("vm should have only 1 reference")
.into_inner()
.into_inner();
/// Writes plaintext data to be sent to the server.
///
/// # Arguments
///
/// * `buf` - The buffer to read the plaintext data from.
pub fn write(&mut self, buf: &[u8]) -> Result<usize, ProverError> {
self.state.tls_client.write(buf)
}
// Prove tag verification of received records.
// The prover drops the proof output.
let _ = verify_tags(
&mut vm,
(keys.server_write_key, keys.server_write_iv),
keys.server_write_mac_key,
*tls_transcript.version(),
tls_transcript.recv().to_vec(),
)
.map_err(ProverError::zk)?;
/// Writes bytes for the verifier into a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn write_mpc(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.read(buf)
}
mux_fut
.poll_with(vm.execute_all(&mut ctx).map_err(ProverError::zk))
.await?;
/// Reads bytes for the prover from a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn read_mpc(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.write(buf)
}
let transcript = tls_transcript
.to_transcript()
.expect("transcript is complete");
/// Closes the connection from the client side.
pub fn client_close(&mut self) -> Result<(), ProverError> {
self.state.tls_client.client_close()
}
/// Closes the connection from the server side.
pub fn server_close(&mut self) -> Result<(), ProverError> {
self.state.tls_client.server_close()
}
/// Enables or disables the decryption of data from the server until the
/// server has closed the connection.
///
/// # Arguments
///
/// * `enable` - Whether to enable or disable decryption.
pub fn enable_decryption(&mut self, enable: bool) -> Result<(), ProverError> {
self.state.tls_client.enable_decryption(enable)
}
/// Returns `true` if decryption of TLS traffic from the server is active.
pub fn is_decrypting(&self) -> bool {
self.state.tls_client.is_decrypting()
}
/// Polls the prover to make progress.
///
/// # Arguments
///
/// * `cx` - The async context.
pub fn poll(&mut self, cx: &mut Context) -> Poll<Result<(), ProverError>> {
let _ = self.state.mux_fut.poll_unpin(cx)?;
match self.state.tls_client.poll(cx)? {
Poll::Ready(output) => {
let _ = self.state.mux_fut.poll_unpin(cx)?;
self.state.output = Some(output);
Poll::Ready(Ok(()))
Ok(Prover {
config: self.config,
span: self.span,
state: state::Committed {
mux_ctrl,
mux_fut,
ctx,
vm,
server_name: config.server_name().clone(),
keys,
tls_transcript,
transcript,
},
})
}
Poll::Pending => Poll::Pending,
}
}
.instrument(span)
});
/// Returns a committed prover after the TLS session has completed.
pub fn finish(self) -> Result<Prover<state::Committed>, ProverError> {
let TlsOutput {
ctx,
vm,
keys,
tls_transcript,
transcript,
} = self.state.output.ok_or(ProverError::state(
"prover has not yet closed the connection",
))?;
let prover = Prover {
config: self.config,
span: self.span,
state: state::Committed {
mpc_duplex: self.state.mpc_duplex,
mux_ctrl: self.state.mux_ctrl,
mux_fut: self.state.mux_fut,
ctx,
vm,
server_name: self.state.server_name,
keys,
tls_transcript,
transcript,
Ok((
conn,
ProverFuture {
fut,
ctrl: ProverControl { mpc_ctrl },
},
};
Ok(prover)
))
}
}
@@ -398,24 +305,6 @@ impl Prover<state::Committed> {
&self.state.transcript
}
/// Writes bytes for the verifier into a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn write_mpc(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.read(buf)
}
/// Reads bytes for the prover from a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn read_mpc(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.write(buf)
}
/// Proves information to the verifier.
///
/// # Arguments
@@ -478,19 +367,41 @@ impl Prover<state::Committed> {
#[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn close(self) -> Result<(), ProverError> {
let state::Committed {
mut mpc_duplex,
mux_ctrl,
mux_fut,
..
mux_ctrl, mux_fut, ..
} = self.state;
// Wait for the verifier to correctly close the connection.
if !mux_fut.is_complete() {
mux_ctrl.close();
mux_fut.await?;
futures::AsyncWriteExt::close(&mut mpc_duplex).await?;
}
Ok(())
}
}
/// A controller for the prover.
#[derive(Clone)]
pub struct ProverControl {
mpc_ctrl: LeaderCtrl,
}
impl ProverControl {
/// Defers decryption of data from the server until the server has closed
/// the connection.
///
/// This is a performance optimization which will significantly reduce the
/// amount of upload bandwidth used by the prover.
///
/// # Notes
///
/// * The prover may need to close the connection to the server in order for
/// it to close the connection on its end. If neither the prover or server
/// close the connection this will cause a deadlock.
pub async fn defer_decryption(&self) -> Result<(), ProverError> {
self.mpc_ctrl
.defer_decryption()
.await
.map_err(ProverError::from)
}
}

View File

@@ -1,63 +0,0 @@
//! Provides a TLS client.
use crate::mpz::ProverZk;
use mpc_tls::SessionKeys;
use std::task::{Context, Poll};
use tlsn_core::transcript::{TlsTranscript, Transcript};
mod mpc;
pub(crate) use mpc::MpcTlsClient;
/// TLS client for MPC and proxy-based TLS implementations.
pub(crate) trait TlsClient {
type Error: std::error::Error + Send + Sync + Unpin + 'static;
/// Returns `true` if the client wants to read TLS data from the server.
fn wants_read_tls(&self) -> bool;
/// Returns `true` if the client wants to write TLS data to the server.
fn wants_write_tls(&self) -> bool;
/// Reads TLS data from the server.
fn read_tls(&mut self, buf: &[u8]) -> Result<usize, Self::Error>;
/// Writes TLS data for the server into the provided buffer.
fn write_tls(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error>;
/// Returns `true` if the client wants to read plaintext data.
fn wants_read(&self) -> bool;
/// Returns `true` if the client wants to write plaintext data.
fn wants_write(&self) -> bool;
/// Reads plaintext data from the server into the provided buffer.
fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error>;
/// Writes plaintext data to be sent to the server.
fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error>;
/// Client closes the connection.
fn client_close(&mut self) -> Result<(), Self::Error>;
/// Server closes the connection.
fn server_close(&mut self) -> Result<(), Self::Error>;
/// Enables or disables decryption of TLS traffic sent by the server.
fn enable_decryption(&mut self, enable: bool) -> Result<(), Self::Error>;
/// Returns `true` if decryption of TLS traffic from the server is active.
fn is_decrypting(&self) -> bool;
/// Polls the client to make progress.
fn poll(&mut self, cx: &mut Context) -> Poll<Result<TlsOutput, Self::Error>>;
}
/// Output of a TLS session.
pub(crate) struct TlsOutput {
pub(crate) ctx: mpz_common::Context,
pub(crate) vm: ProverZk,
pub(crate) keys: SessionKeys,
pub(crate) tls_transcript: TlsTranscript,
pub(crate) transcript: Transcript,
}

View File

@@ -1,391 +0,0 @@
//! Implementation of an MPC-TLS client.
use crate::{
mpz::{ProverMpc, ProverZk},
prover::{
ProverError,
client::{TlsClient, TlsOutput},
},
tag::verify_tags,
};
use futures::{Future, FutureExt};
use mpc_tls::{MpcTlsLeader, SessionKeys};
use mpz_common::Context;
use mpz_vm_core::Execute;
use std::{collections::VecDeque, pin::Pin, sync::Arc, task::Poll};
use tls_client::ClientConnection;
use tlsn_core::transcript::TlsTranscript;
use tlsn_deap::Deap;
use tokio::sync::Mutex;
use tracing::{Span, debug, instrument, trace, warn};
pub(crate) type MpcFuture =
Box<dyn Future<Output = Result<(Context, TlsTranscript), ProverError>> + Send>;
type FinalizeFuture =
Box<dyn Future<Output = Result<(InnerState, Context, TlsTranscript), ProverError>> + Send>;
pub(crate) struct MpcTlsClient {
state: State,
decrypt: bool,
cmds: VecDeque<Command>,
}
#[derive(Debug, Clone, Copy)]
pub(crate) enum Command {
ClientClose,
ServerClose,
Decrypt(bool),
}
enum State {
Start {
inner: Box<InnerState>,
},
Active {
inner: Box<InnerState>,
},
Busy {
fut: Pin<Box<dyn Future<Output = Result<Box<InnerState>, ProverError>> + Send>>,
},
CloseActive {
inner: Box<InnerState>,
},
CloseBusy {
fut: Pin<Box<dyn Future<Output = Result<Box<InnerState>, ProverError>> + Send>>,
},
Finalizing {
fut: Pin<FinalizeFuture>,
},
Finished,
Error,
}
impl MpcTlsClient {
pub(crate) fn new(
keys: SessionKeys,
vm: Arc<Mutex<Deap<ProverMpc, ProverZk>>>,
span: Span,
tls: ClientConnection,
) -> Self {
let inner = InnerState {
span,
tls,
vm,
keys,
mpc_stopped: false,
};
let decrypt = tls.backend().is_decrypting();
Self {
state: State::Start {
inner: Box::new(inner),
},
decrypt,
cmds: VecDeque::default(),
}
}
fn inner_client_mut(&mut self) -> Option<&mut ClientConnection> {
if let State::Active { inner, .. } | State::CloseActive { inner, .. } = &mut self.state {
Some(&mut inner.tls)
} else {
None
}
}
fn inner_client(&self) -> Option<&ClientConnection> {
if let State::Active { inner, .. } | State::CloseActive { inner, .. } = &self.state {
Some(&inner.tls)
} else {
None
}
}
}
impl TlsClient for MpcTlsClient {
type Error = ProverError;
fn wants_read_tls(&self) -> bool {
if let Some(client) = self.inner_client() {
client.wants_read()
} else {
false
}
}
fn wants_write_tls(&self) -> bool {
if let Some(client) = self.inner_client() {
client.wants_write()
} else {
false
}
}
fn read_tls(&mut self, mut buf: &[u8]) -> Result<usize, Self::Error> {
if let Some(client) = self.inner_client_mut()
&& client.wants_read()
{
client.read_tls(&mut buf).map_err(ProverError::from)
} else {
Ok(0)
}
}
fn write_tls(&mut self, mut buf: &mut [u8]) -> Result<usize, Self::Error> {
if let Some(client) = self.inner_client_mut()
&& client.wants_write()
{
client.write_tls(&mut buf).map_err(ProverError::from)
} else {
Ok(0)
}
}
fn wants_read(&self) -> bool {
if let Some(client) = self.inner_client() {
!client.plaintext_is_empty()
} else {
false
}
}
fn wants_write(&self) -> bool {
if let Some(client) = self.inner_client() {
!client.sendable_plaintext_is_full()
} else {
false
}
}
fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
if let Some(client) = self.inner_client_mut()
&& !client.plaintext_is_empty()
{
client.read_plaintext(buf).map_err(ProverError::from)
} else {
Ok(0)
}
}
fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
if let Some(client) = self.inner_client_mut()
&& !client.sendable_plaintext_is_full()
{
client.write_plaintext(buf).map_err(ProverError::from)
} else {
Ok(0)
}
}
fn client_close(&mut self) -> Result<(), Self::Error> {
self.cmds.push_back(Command::ClientClose);
Ok(())
}
fn server_close(&mut self) -> Result<(), Self::Error> {
self.cmds.push_back(Command::ServerClose);
Ok(())
}
fn enable_decryption(&mut self, enable: bool) -> Result<(), Self::Error> {
self.cmds.push_back(Command::Decrypt(enable));
Ok(())
}
fn is_decrypting(&self) -> bool {
self.decrypt
}
fn poll(&mut self, cx: &mut std::task::Context) -> Poll<Result<TlsOutput, Self::Error>> {
match std::mem::replace(&mut self.state, State::Error) {
State::Start { inner } => {
trace!("inner client is starting");
self.state = State::Busy {
fut: Box::pin(inner.start()),
};
self.poll(cx)
}
State::Active { mut inner } => {
trace!("inner client is active");
if !inner.tls.is_handshaking()
&& let Some(cmd) = self.cmds.pop_front()
{
match cmd {
Command::ClientClose => {
self.state = State::Busy {
fut: Box::pin(inner.client_close()),
};
}
Command::ServerClose => {
self.state = State::CloseBusy {
fut: Box::pin(inner.server_close()),
};
}
Command::Decrypt(enable) => {
inner.tls.backend_mut().enable_decryption(enable)?;
self.decrypt = enable;
self.state = State::Busy {
fut: Box::pin(inner.run()),
};
}
}
} else {
self.state = State::Busy {
fut: Box::pin(inner.run()),
};
}
self.poll(cx)
}
State::Busy { mut fut } => {
trace!("inner client is busy");
match fut.as_mut().poll(cx)? {
Poll::Ready(inner) => {
self.state = State::Active { inner };
}
Poll::Pending => self.state = State::Busy { fut },
}
Poll::Pending
}
State::CloseActive { mut inner } => {
trace!("inner client is close active");
if let Some((ctx, transcript)) = inner.tls.backend_mut().finish() {
self.state = State::Finalizing {
fut: Box::pin(inner.finalize(ctx, transcript)),
};
} else {
self.state = State::CloseBusy {
fut: Box::pin(inner.server_close()),
};
}
self.poll(cx)
}
State::CloseBusy { mut fut } => {
trace!("inner client is busy closing");
match fut.as_mut().poll(cx)? {
Poll::Ready(inner) => {
self.state = State::CloseActive { inner };
}
Poll::Pending => self.state = State::CloseBusy { fut },
}
Poll::Pending
}
State::Finalizing { mut fut } => match fut.poll_unpin(cx) {
Poll::Ready(output) => {
let (inner, ctx, tls_transcript) = output?;
let InnerState { vm, keys, .. } = inner;
let transcript = tls_transcript
.to_transcript()
.expect("transcript is complete");
let (_, vm) = Arc::into_inner(vm)
.expect("vm should have only 1 reference")
.into_inner()
.into_inner();
let output = TlsOutput {
ctx,
vm,
keys,
tls_transcript,
transcript,
};
self.state = State::Finished;
Poll::Ready(Ok(output))
}
Poll::Pending => {
self.state = State::Finalizing { fut };
Poll::Pending
}
},
State::Finished => Poll::Ready(Err(ProverError::state(
"mpc tls client polled again in finished state",
))),
State::Error => {
Poll::Ready(Err(ProverError::state("mpc tls client is in error state")))
}
}
}
}
struct InnerState {
span: Span,
tls: ClientConnection,
vm: Arc<Mutex<Deap<ProverMpc, ProverZk>>>,
keys: SessionKeys,
mpc_stopped: bool,
}
impl InnerState {
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
async fn start(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
self.tls.start().await?;
Ok(self)
}
#[instrument(parent = &self.span, level = "trace", skip_all, err)]
async fn run(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
self.tls.process_new_packets().await?;
Ok(self)
}
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
async fn client_close(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
debug!("sending close notify");
if let Err(e) = self.tls.send_close_notify().await {
warn!("failed to send close_notify to server: {}", e);
}
Ok(self)
}
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
async fn server_close(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
self.tls.process_new_packets().await?;
if !self.mpc_stopped && self.tls.plaintext_is_empty() && self.tls.is_empty().await? {
self.tls.server_closed().await?;
self.mpc_stopped = true;
debug!("closed connection serverside");
}
Ok(self)
}
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
async fn finalize(
self,
mut ctx: Context,
transcript: TlsTranscript,
) -> Result<(Self, Context, TlsTranscript), ProverError> {
{
let mut vm = self.vm.try_lock().expect("VM should not be locked");
// Finalize DEAP.
vm.finalize(&mut ctx).await.map_err(ProverError::mpc)?;
debug!("mpc finalized");
// Pull out ZK VM.
let mut zk = vm.zk();
// Prove tag verification of received records.
// The prover drops the proof output.
let _ = verify_tags(
&mut *zk,
(self.keys.server_write_key, self.keys.server_write_iv),
self.keys.server_write_mac_key,
*transcript.version(),
transcript.recv().to_vec(),
)
.map_err(ProverError::zk)?;
debug!("verified tags from server");
zk.execute_all(&mut ctx).await.map_err(ProverError::zk)?
}
debug!("MPC-TLS done");
Ok((self, ctx, transcript))
}
}

View File

@@ -49,13 +49,6 @@ impl ProverError {
{
Self::new(ErrorKind::Commit, source)
}
pub(crate) fn state<E>(source: E) -> Self
where
E: Into<Box<dyn Error + Send + Sync + 'static>>,
{
Self::new(ErrorKind::State, source)
}
}
#[derive(Debug)]
@@ -65,7 +58,6 @@ enum ErrorKind {
Zk,
Config,
Commit,
State,
}
impl fmt::Display for ProverError {
@@ -78,7 +70,6 @@ impl fmt::Display for ProverError {
ErrorKind::Zk => f.write_str("zk error")?,
ErrorKind::Config => f.write_str("config error")?,
ErrorKind::Commit => f.write_str("commit error")?,
ErrorKind::State => f.write_str("state error")?,
}
if let Some(source) = &self.source {
@@ -95,8 +86,8 @@ impl From<std::io::Error> for ProverError {
}
}
impl From<tls_client::Error> for ProverError {
fn from(e: tls_client::Error) -> Self {
impl From<tls_client_async::ConnectionError> for ProverError {
fn from(e: tls_client_async::ConnectionError) -> Self {
Self::new(ErrorKind::Io, e)
}
}
@@ -124,9 +115,3 @@ impl From<EncodingError> for ProverError {
Self::new(ErrorKind::Commit, e)
}
}
impl From<ProverError> for std::io::Error {
fn from(value: ProverError) -> Self {
Self::other(value)
}
}

View File

@@ -0,0 +1,32 @@
//! This module collects futures which are used by the [Prover].
use super::{Prover, ProverControl, ProverError, state};
use futures::Future;
use std::pin::Pin;
/// Prover future which must be polled for the TLS connection to make progress.
pub struct ProverFuture {
#[allow(clippy::type_complexity)]
pub(crate) fut: Pin<
Box<dyn Future<Output = Result<Prover<state::Committed>, ProverError>> + Send + 'static>,
>,
pub(crate) ctrl: ProverControl,
}
impl ProverFuture {
/// Returns a controller for the prover for advanced functionality.
pub fn control(&self) -> ProverControl {
self.ctrl.clone()
}
}
impl Future for ProverFuture {
type Output = Result<Prover<state::Committed>, ProverError>;
fn poll(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
self.fut.as_mut().poll(cx)
}
}

View File

@@ -2,7 +2,7 @@ use mpc_tls::SessionKeys;
use mpz_common::Context;
use mpz_memory_core::binary::Binary;
use mpz_vm_core::Vm;
use rangeset::{RangeSet, UnionMut};
use rangeset::set::RangeSet;
use tlsn_core::{
ProverOutput,
config::prove::ProveConfig,

View File

@@ -2,7 +2,6 @@
use std::sync::Arc;
use futures_plex::DuplexStream;
use mpc_tls::{MpcTlsLeader, SessionKeys};
use mpz_common::Context;
use tlsn_core::{
@@ -15,10 +14,6 @@ use tokio::sync::Mutex;
use crate::{
mpz::{ProverMpc, ProverZk},
mux::{MuxControl, MuxFuture},
prover::{
ProverError,
client::{TlsClient, TlsOutput},
},
};
/// Entry state
@@ -29,7 +24,6 @@ opaque_debug::implement!(Initialized);
/// State after the verifier has accepted the proposed TLS commitment protocol
/// configuration and preprocessing has completed.
pub struct CommitAccepted {
pub(crate) mpc_duplex: DuplexStream,
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) mpc_tls: MpcTlsLeader,
@@ -39,21 +33,8 @@ pub struct CommitAccepted {
opaque_debug::implement!(CommitAccepted);
/// State during the MPC-TLS connection.
pub struct Connected {
pub(crate) mpc_duplex: DuplexStream,
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) server_name: ServerName,
pub(crate) tls_client: Box<dyn TlsClient<Error = ProverError> + Send>,
pub(crate) output: Option<TlsOutput>,
}
opaque_debug::implement!(Connected);
/// State after the TLS transcript has been committed.
pub struct Committed {
pub(crate) mpc_duplex: DuplexStream,
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) ctx: Context,
@@ -71,13 +52,11 @@ pub trait ProverState: sealed::Sealed {}
impl ProverState for Initialized {}
impl ProverState for CommitAccepted {}
impl ProverState for Connected {}
impl ProverState for Committed {}
mod sealed {
pub trait Sealed {}
impl Sealed for super::Initialized {}
impl Sealed for super::CommitAccepted {}
impl Sealed for super::Connected {}
impl Sealed for super::Committed {}
}

View File

@@ -12,7 +12,7 @@ use mpz_memory_core::{
binary::{Binary, U8},
};
use mpz_vm_core::{Call, CallableExt, Vm};
use rangeset::{Difference, RangeSet, Union};
use rangeset::{iter::RangeIterator, ops::Set, set::RangeSet};
use tlsn_core::transcript::Record;
use crate::transcript_internal::ReferenceMap;
@@ -32,7 +32,7 @@ pub(crate) fn prove_plaintext<'a>(
commit.clone()
} else {
// The plaintext is only partially revealed, so we need to authenticate in ZK.
commit.union(reveal)
commit.union(reveal).into_set()
};
let plaintext_refs = alloc_plaintext(vm, &alloc_ranges)?;
@@ -49,7 +49,7 @@ pub(crate) fn prove_plaintext<'a>(
vm.commit(*slice).map_err(PlaintextAuthError::vm)?;
}
} else {
let private = commit.difference(reveal);
let private = commit.difference(reveal).into_set();
for (_, slice) in plaintext_refs
.index(&private)
.expect("all ranges are allocated")
@@ -98,7 +98,7 @@ pub(crate) fn verify_plaintext<'a>(
commit.clone()
} else {
// The plaintext is only partially revealed, so we need to authenticate in ZK.
commit.union(reveal)
commit.union(reveal).into_set()
};
let plaintext_refs = alloc_plaintext(vm, &alloc_ranges)?;
@@ -123,7 +123,7 @@ pub(crate) fn verify_plaintext<'a>(
ciphertext,
})
} else {
let private = commit.difference(reveal);
let private = commit.difference(reveal).into_set();
for (_, slice) in plaintext_refs
.index(&private)
.expect("all ranges are allocated")
@@ -175,15 +175,13 @@ fn alloc_plaintext(
let plaintext = vm.alloc_vec::<U8>(len).map_err(PlaintextAuthError::vm)?;
let mut pos = 0;
Ok(ReferenceMap::from_iter(ranges.iter_ranges().map(
move |range| {
let chunk = plaintext
.get(pos..pos + range.len())
.expect("length was checked");
pos += range.len();
(range.start, chunk)
},
)))
Ok(ReferenceMap::from_iter(ranges.iter().map(move |range| {
let chunk = plaintext
.get(pos..pos + range.len())
.expect("length was checked");
pos += range.len();
(range.start, chunk)
})))
}
fn alloc_ciphertext<'a>(
@@ -212,15 +210,13 @@ fn alloc_ciphertext<'a>(
let ciphertext: Vector<U8> = vm.call(call).map_err(PlaintextAuthError::vm)?;
let mut pos = 0;
Ok(ReferenceMap::from_iter(ranges.iter_ranges().map(
move |range| {
let chunk = ciphertext
.get(pos..pos + range.len())
.expect("length was checked");
pos += range.len();
(range.start, chunk)
},
)))
Ok(ReferenceMap::from_iter(ranges.iter().map(move |range| {
let chunk = ciphertext
.get(pos..pos + range.len())
.expect("length was checked");
pos += range.len();
(range.start, chunk)
})))
}
fn alloc_keystream<'a>(
@@ -233,7 +229,7 @@ fn alloc_keystream<'a>(
let mut keystream = Vec::new();
let mut pos = 0;
let mut range_iter = ranges.iter_ranges();
let mut range_iter = ranges.iter();
let mut current_range = range_iter.next();
for record in records {
let mut explicit_nonce = None;
@@ -508,7 +504,7 @@ mod tests {
for record in records {
let mut record_keystream = vec![0u8; record.len];
aes_ctr_apply_keystream(&key, &iv, &record.explicit_nonce, &mut record_keystream);
for mut range in ranges.iter_ranges() {
for mut range in ranges.iter() {
range.start = range.start.max(pos);
range.end = range.end.min(pos + record.len);
if range.start < range.end {

View File

@@ -9,7 +9,7 @@ use mpz_memory_core::{
correlated::{Delta, Key, Mac},
};
use rand::Rng;
use rangeset::RangeSet;
use rangeset::set::RangeSet;
use serde::{Deserialize, Serialize};
use serio::{SinkExt, stream::IoStreamExt};
use tlsn_core::{

View File

@@ -9,7 +9,7 @@ use mpz_memory_core::{
binary::{Binary, U8},
};
use mpz_vm_core::{Vm, VmError, prelude::*};
use rangeset::RangeSet;
use rangeset::set::RangeSet;
use tlsn_core::{
hash::{Blinder, Hash, HashAlgId, TypedHash},
transcript::{
@@ -155,7 +155,7 @@ fn hash_commit_inner(
Direction::Received => &refs.recv,
};
for range in idx.iter_ranges() {
for range in idx.iter() {
hasher.update(&refs.get(range).expect("plaintext refs are valid"));
}
@@ -176,7 +176,7 @@ fn hash_commit_inner(
Direction::Received => &refs.recv,
};
for range in idx.iter_ranges() {
for range in idx.iter() {
hasher
.update(vm, &refs.get(range).expect("plaintext refs are valid"))
.map_err(HashCommitError::hasher)?;
@@ -201,7 +201,7 @@ fn hash_commit_inner(
Direction::Received => &refs.recv,
};
for range in idx.iter_ranges() {
for range in idx.iter() {
hasher
.update(vm, &refs.get(range).expect("plaintext refs are valid"))
.map_err(HashCommitError::hasher)?;

View File

@@ -10,17 +10,16 @@ pub use error::VerifierError;
pub use tlsn_core::{VerifierOutput, webpki::ServerCertVerifier};
use crate::{
BUF_CAP, Role,
Role,
context::build_mt_context,
mpz::{VerifierDeps, build_verifier_deps, translate_keys},
msg::{ProveRequestMsg, Response, TlsCommitRequestMsg},
mux::attach_mux,
tag::verify_tags,
};
use futures::TryFutureExt;
use futures::{AsyncRead, AsyncWrite, TryFutureExt};
use mpz_vm_core::prelude::*;
use serio::{SinkExt, stream::IoStreamExt};
use std::io::{Read, Write};
use tlsn_core::{
config::{
prove::ProveRequest,
@@ -69,10 +68,11 @@ impl Verifier<state::Initialized> {
///
/// * `socket` - The socket to the prover.
#[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn commit(self) -> Result<Verifier<state::CommitStart>, VerifierError> {
let (duplex_a, duplex_b) = futures_plex::duplex(BUF_CAP);
let (mut mux_fut, mux_ctrl) = attach_mux(duplex_b, Role::Verifier);
pub async fn commit<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
self,
socket: S,
) -> Result<Verifier<state::CommitStart>, VerifierError> {
let (mut mux_fut, mux_ctrl) = attach_mux(socket, Role::Verifier);
let mut mt = build_mt_context(mux_ctrl.clone());
let mut ctx = mux_fut.poll_with(mt.new_context()).await?;
@@ -102,7 +102,6 @@ impl Verifier<state::Initialized> {
config: self.config,
span: self.span,
state: state::CommitStart {
mpc_duplex: duplex_a,
mux_ctrl,
mux_fut,
ctx,
@@ -122,7 +121,6 @@ impl Verifier<state::CommitStart> {
#[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn accept(self) -> Result<Verifier<state::CommitAccepted>, VerifierError> {
let state::CommitStart {
mpc_duplex,
mux_ctrl,
mut mux_fut,
mut ctx,
@@ -153,7 +151,6 @@ impl Verifier<state::CommitStart> {
config: self.config,
span: self.span,
state: state::CommitAccepted {
mpc_duplex,
mux_ctrl,
mux_fut,
mpc_tls,
@@ -185,24 +182,6 @@ impl Verifier<state::CommitStart> {
Ok(())
}
/// Writes bytes for the prover into a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn write_mpc(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.read(buf)
}
/// Reads bytes for the verifier from a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn read_mpc(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.write(buf)
}
}
impl Verifier<state::CommitAccepted> {
@@ -210,7 +189,6 @@ impl Verifier<state::CommitAccepted> {
#[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn run(self) -> Result<Verifier<state::Committed>, VerifierError> {
let state::CommitAccepted {
mpc_duplex,
mux_ctrl,
mut mux_fut,
mpc_tls,
@@ -267,7 +245,6 @@ impl Verifier<state::CommitAccepted> {
config: self.config,
span: self.span,
state: state::Committed {
mpc_duplex,
mux_ctrl,
mux_fut,
ctx,
@@ -277,24 +254,6 @@ impl Verifier<state::CommitAccepted> {
},
})
}
/// Writes bytes for the prover into a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn write_mpc(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.read(buf)
}
/// Reads bytes for the verifier from a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn read_mpc(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.write(buf)
}
}
impl Verifier<state::Committed> {
@@ -307,7 +266,6 @@ impl Verifier<state::Committed> {
#[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn verify(self) -> Result<Verifier<state::Verify>, VerifierError> {
let state::Committed {
mpc_duplex,
mux_ctrl,
mut mux_fut,
mut ctx,
@@ -328,7 +286,6 @@ impl Verifier<state::Committed> {
config: self.config,
span: self.span,
state: state::Verify {
mpc_duplex,
mux_ctrl,
mux_fut,
ctx,
@@ -342,39 +299,17 @@ impl Verifier<state::Committed> {
})
}
/// Writes bytes for the prover into a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn write_mpc(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.read(buf)
}
/// Reads bytes for the verifier from a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn read_mpc(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.write(buf)
}
/// Closes the connection with the prover.
#[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn close(self) -> Result<(), VerifierError> {
let state::Committed {
mut mpc_duplex,
mux_ctrl,
mux_fut,
..
mux_ctrl, mux_fut, ..
} = self.state;
// Wait for the prover to correctly close the connection.
if !mux_fut.is_complete() {
mux_ctrl.close();
mux_fut.await?;
futures::AsyncWriteExt::close(&mut mpc_duplex).await?;
}
Ok(())
@@ -392,7 +327,6 @@ impl Verifier<state::Verify> {
self,
) -> Result<(VerifierOutput, Verifier<state::Committed>), VerifierError> {
let state::Verify {
mpc_duplex,
mux_ctrl,
mut mux_fut,
mut ctx,
@@ -428,7 +362,6 @@ impl Verifier<state::Verify> {
config: self.config,
span: self.span,
state: state::Committed {
mpc_duplex,
mux_ctrl,
mux_fut,
ctx,
@@ -446,7 +379,6 @@ impl Verifier<state::Verify> {
msg: Option<&str>,
) -> Result<Verifier<state::Committed>, VerifierError> {
let state::Verify {
mpc_duplex,
mux_ctrl,
mut mux_fut,
mut ctx,
@@ -464,7 +396,6 @@ impl Verifier<state::Verify> {
config: self.config,
span: self.span,
state: state::Committed {
mpc_duplex,
mux_ctrl,
mux_fut,
ctx,
@@ -474,22 +405,4 @@ impl Verifier<state::Verify> {
},
})
}
/// Writes bytes for the prover into a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn write_mpc(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.read(buf)
}
/// Reads bytes for the verifier from a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn read_mpc(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.write(buf)
}
}

View File

@@ -3,7 +3,6 @@
use std::sync::Arc;
use crate::mux::{MuxControl, MuxFuture};
use futures_plex::DuplexStream;
use mpc_tls::{MpcTlsFollower, SessionKeys};
use mpz_common::Context;
use tlsn_core::{
@@ -26,7 +25,6 @@ opaque_debug::implement!(Initialized);
/// State after receiving protocol configuration from the prover.
pub struct CommitStart {
pub(crate) mpc_duplex: DuplexStream,
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) ctx: Context,
@@ -38,7 +36,6 @@ opaque_debug::implement!(CommitStart);
/// State after accepting the proposed TLS commitment protocol configuration and
/// performing preprocessing.
pub struct CommitAccepted {
pub(crate) mpc_duplex: DuplexStream,
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) mpc_tls: MpcTlsFollower,
@@ -50,7 +47,6 @@ opaque_debug::implement!(CommitAccepted);
/// State after the TLS transcript has been committed.
pub struct Committed {
pub(crate) mpc_duplex: DuplexStream,
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) ctx: Context,
@@ -63,7 +59,6 @@ opaque_debug::implement!(Committed);
/// State after receiving a proving request.
pub struct Verify {
pub(crate) mpc_duplex: DuplexStream,
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) ctx: Context,

View File

@@ -2,7 +2,7 @@ use mpc_tls::SessionKeys;
use mpz_common::Context;
use mpz_memory_core::binary::Binary;
use mpz_vm_core::Vm;
use rangeset::{RangeSet, UnionMut};
use rangeset::set::RangeSet;
use tlsn_core::{
VerifierOutput,
config::prove::ProveRequest,

View File

@@ -1,5 +1,5 @@
use futures::{AsyncReadExt, AsyncWriteExt};
use rangeset::RangeSet;
use rangeset::set::RangeSet;
use tlsn::{
config::{
prove::ProveConfig,
@@ -51,19 +51,11 @@ async fn test() {
assert_eq!(server_name.as_str(), SERVER_DOMAIN);
assert!(!partial_transcript.is_complete());
assert_eq!(
partial_transcript
.sent_authed()
.iter_ranges()
.next()
.unwrap(),
partial_transcript.sent_authed().iter().next().unwrap(),
0..10
);
assert_eq!(
partial_transcript
.received_authed()
.iter_ranges()
.next()
.unwrap(),
partial_transcript.received_authed().iter().next().unwrap(),
0..10
);

View File

@@ -23,6 +23,7 @@ no-bundler = ["web-spawn/no-bundler"]
tlsn-core = { workspace = true }
tlsn = { workspace = true, features = ["web", "mozilla-certs"] }
tlsn-server-fixture-certs = { workspace = true }
tlsn-tls-client-async = { workspace = true }
tlsn-tls-core = { workspace = true }
bincode = { workspace = true }

View File

@@ -151,9 +151,9 @@ impl From<tlsn::transcript::PartialTranscript> for PartialTranscript {
fn from(value: tlsn::transcript::PartialTranscript) -> Self {
Self {
sent: value.sent_unsafe().to_vec(),
sent_authed: value.sent_authed().iter_ranges().collect(),
sent_authed: value.sent_authed().iter().collect(),
recv: value.received_unsafe().to_vec(),
recv_authed: value.received_authed().iter_ranges().collect(),
recv_authed: value.received_authed().iter().collect(),
}
}
}