mirror of
https://github.com/tlsnotary/tlsn.git
synced 2026-01-11 14:28:03 -05:00
Compare commits
9 Commits
refactor/r
...
plot_py
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b76775fc7c | ||
|
|
72041d1f07 | ||
|
|
ac1df8fc75 | ||
|
|
3cb7c5c0b4 | ||
|
|
b41d678829 | ||
|
|
1ebefa27d8 | ||
|
|
4fe5c1defd | ||
|
|
0e8e547300 | ||
|
|
22cc88907a |
2
.github/workflows/ci.yml
vendored
2
.github/workflows/ci.yml
vendored
@@ -21,7 +21,7 @@ env:
|
||||
# - https://github.com/privacy-ethereum/mpz/issues/178
|
||||
# 32 seems to be big enough for the foreseeable future
|
||||
RAYON_NUM_THREADS: 32
|
||||
RUST_VERSION: 1.91.1
|
||||
RUST_VERSION: 1.92.0
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
jobs:
|
||||
|
||||
2593
Cargo.lock
generated
2593
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
39
Cargo.toml
39
Cargo.toml
@@ -13,6 +13,7 @@ members = [
|
||||
"crates/server-fixture/server",
|
||||
"crates/tls/backend",
|
||||
"crates/tls/client",
|
||||
"crates/tls/client-async",
|
||||
"crates/tls/core",
|
||||
"crates/mpc-tls",
|
||||
"crates/tls/server-fixture",
|
||||
@@ -56,6 +57,7 @@ tlsn-server-fixture = { path = "crates/server-fixture/server" }
|
||||
tlsn-server-fixture-certs = { path = "crates/server-fixture/certs" }
|
||||
tlsn-tls-backend = { path = "crates/tls/backend" }
|
||||
tlsn-tls-client = { path = "crates/tls/client" }
|
||||
tlsn-tls-client-async = { path = "crates/tls/client-async" }
|
||||
tlsn-tls-core = { path = "crates/tls/core" }
|
||||
tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6168663" }
|
||||
tlsn-harness-core = { path = "crates/harness/core" }
|
||||
@@ -64,28 +66,27 @@ tlsn-harness-runner = { path = "crates/harness/runner" }
|
||||
tlsn-wasm = { path = "crates/wasm" }
|
||||
tlsn = { path = "crates/tlsn" }
|
||||
|
||||
mpz-circuits = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
|
||||
mpz-circuits-data = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
|
||||
mpz-memory-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
|
||||
mpz-common = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
|
||||
mpz-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
|
||||
mpz-vm-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
|
||||
mpz-garble = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
|
||||
mpz-garble-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
|
||||
mpz-ole = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
|
||||
mpz-ot = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
|
||||
mpz-share-conversion = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
|
||||
mpz-fields = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
|
||||
mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
|
||||
mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
|
||||
mpz-ideal-vm = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
|
||||
mpz-circuits = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-circuits-data = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-memory-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-common = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-vm-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-garble = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-garble-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-ole = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-ot = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-share-conversion = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-fields = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-ideal-vm = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
|
||||
futures-plex = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "0b46dc0" }
|
||||
rangeset = { version = "0.2" }
|
||||
rangeset = { version = "0.4" }
|
||||
serio = { version = "0.2" }
|
||||
spansy = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6168663" }
|
||||
spansy = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6f1a934" }
|
||||
uid-mux = { version = "0.2" }
|
||||
websocket-relay = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6168663" }
|
||||
websocket-relay = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6f1a934" }
|
||||
|
||||
aead = { version = "0.4" }
|
||||
aes = { version = "0.8" }
|
||||
|
||||
@@ -15,7 +15,7 @@ use mpz_vm_core::{
|
||||
memory::{binary::Binary, DecodeFuture, Memory, Repr, Slice, View},
|
||||
Call, Callable, Execute, Vm, VmError,
|
||||
};
|
||||
use rangeset::{Difference, RangeSet, UnionMut};
|
||||
use rangeset::{ops::Set, set::RangeSet};
|
||||
use tokio::sync::{Mutex, MutexGuard, OwnedMutexGuard};
|
||||
|
||||
type Error = DeapError;
|
||||
@@ -210,10 +210,12 @@ where
|
||||
}
|
||||
|
||||
fn commit_raw(&mut self, slice: Slice) -> Result<(), VmError> {
|
||||
let slice_range = slice.to_range();
|
||||
|
||||
// Follower's private inputs are not committed in the ZK VM until finalization.
|
||||
let input_minus_follower = slice.to_range().difference(&self.follower_input_ranges);
|
||||
let input_minus_follower = slice_range.difference(&self.follower_input_ranges);
|
||||
let mut zk = self.zk.try_lock().unwrap();
|
||||
for input in input_minus_follower.iter_ranges() {
|
||||
for input in input_minus_follower {
|
||||
zk.commit_raw(
|
||||
self.memory_map
|
||||
.try_get(Slice::from_range_unchecked(input))?,
|
||||
@@ -266,7 +268,7 @@ where
|
||||
mpc.mark_private_raw(slice)?;
|
||||
// Follower's private inputs will become public during finalization.
|
||||
zk.mark_public_raw(self.memory_map.try_get(slice)?)?;
|
||||
self.follower_input_ranges.union_mut(&slice.to_range());
|
||||
self.follower_input_ranges.union_mut(slice.to_range());
|
||||
self.follower_inputs.push(slice);
|
||||
}
|
||||
}
|
||||
@@ -282,7 +284,7 @@ where
|
||||
mpc.mark_blind_raw(slice)?;
|
||||
// Follower's private inputs will become public during finalization.
|
||||
zk.mark_public_raw(self.memory_map.try_get(slice)?)?;
|
||||
self.follower_input_ranges.union_mut(&slice.to_range());
|
||||
self.follower_input_ranges.union_mut(slice.to_range());
|
||||
self.follower_inputs.push(slice);
|
||||
}
|
||||
Role::Follower => {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::ops::Range;
|
||||
|
||||
use mpz_vm_core::{memory::Slice, VmError};
|
||||
use rangeset::Subset;
|
||||
use rangeset::ops::Set;
|
||||
|
||||
/// A mapping between the memories of the MPC and ZK VMs.
|
||||
#[derive(Debug, Default)]
|
||||
|
||||
@@ -59,5 +59,7 @@ generic-array = { workspace = true }
|
||||
bincode = { workspace = true }
|
||||
hex = { workspace = true }
|
||||
rstest = { workspace = true }
|
||||
tlsn-core = { workspace = true, features = ["fixtures"] }
|
||||
tlsn-attestation = { workspace = true, features = ["fixtures"] }
|
||||
tlsn-data-fixtures = { workspace = true }
|
||||
webpki-root-certs = { workspace = true }
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
//! Proving configuration.
|
||||
|
||||
use rangeset::{RangeSet, ToRangeSet, UnionMut};
|
||||
use rangeset::set::{RangeSet, ToRangeSet};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::transcript::{Direction, Transcript, TranscriptCommitConfig, TranscriptCommitRequest};
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
use rangeset::RangeSet;
|
||||
use rangeset::set::RangeSet;
|
||||
|
||||
pub(crate) struct FmtRangeSet<'a>(pub &'a RangeSet<usize>);
|
||||
|
||||
impl<'a> std::fmt::Display for FmtRangeSet<'a> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str("{")?;
|
||||
for range in self.0.iter_ranges() {
|
||||
for range in self.0.iter() {
|
||||
write!(f, "{}..{}", range.start, range.end)?;
|
||||
if range.end < self.0.end().unwrap_or(0) {
|
||||
f.write_str(", ")?;
|
||||
|
||||
@@ -26,7 +26,11 @@ mod tls;
|
||||
|
||||
use std::{fmt, ops::Range};
|
||||
|
||||
use rangeset::{Difference, IndexRanges, RangeSet, Union};
|
||||
use rangeset::{
|
||||
iter::RangeIterator,
|
||||
ops::{Index, Set},
|
||||
set::RangeSet,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::connection::TranscriptLength;
|
||||
@@ -106,8 +110,14 @@ impl Transcript {
|
||||
}
|
||||
|
||||
Some(
|
||||
Subsequence::new(idx.clone(), data.index_ranges(idx))
|
||||
.expect("data is same length as index"),
|
||||
Subsequence::new(
|
||||
idx.clone(),
|
||||
data.index(idx).fold(Vec::new(), |mut acc, s| {
|
||||
acc.extend_from_slice(s);
|
||||
acc
|
||||
}),
|
||||
)
|
||||
.expect("data is same length as index"),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -129,11 +139,11 @@ impl Transcript {
|
||||
let mut sent = vec![0; self.sent.len()];
|
||||
let mut received = vec![0; self.received.len()];
|
||||
|
||||
for range in sent_idx.iter_ranges() {
|
||||
for range in sent_idx.iter() {
|
||||
sent[range.clone()].copy_from_slice(&self.sent[range]);
|
||||
}
|
||||
|
||||
for range in recv_idx.iter_ranges() {
|
||||
for range in recv_idx.iter() {
|
||||
received[range.clone()].copy_from_slice(&self.received[range]);
|
||||
}
|
||||
|
||||
@@ -186,12 +196,20 @@ pub struct CompressedPartialTranscript {
|
||||
impl From<PartialTranscript> for CompressedPartialTranscript {
|
||||
fn from(uncompressed: PartialTranscript) -> Self {
|
||||
Self {
|
||||
sent_authed: uncompressed
|
||||
.sent
|
||||
.index_ranges(&uncompressed.sent_authed_idx),
|
||||
sent_authed: uncompressed.sent.index(&uncompressed.sent_authed_idx).fold(
|
||||
Vec::new(),
|
||||
|mut acc, s| {
|
||||
acc.extend_from_slice(s);
|
||||
acc
|
||||
},
|
||||
),
|
||||
received_authed: uncompressed
|
||||
.received
|
||||
.index_ranges(&uncompressed.received_authed_idx),
|
||||
.index(&uncompressed.received_authed_idx)
|
||||
.fold(Vec::new(), |mut acc, s| {
|
||||
acc.extend_from_slice(s);
|
||||
acc
|
||||
}),
|
||||
sent_idx: uncompressed.sent_authed_idx,
|
||||
recv_idx: uncompressed.received_authed_idx,
|
||||
sent_total: uncompressed.sent.len(),
|
||||
@@ -207,7 +225,7 @@ impl From<CompressedPartialTranscript> for PartialTranscript {
|
||||
|
||||
let mut offset = 0;
|
||||
|
||||
for range in compressed.sent_idx.iter_ranges() {
|
||||
for range in compressed.sent_idx.iter() {
|
||||
sent[range.clone()]
|
||||
.copy_from_slice(&compressed.sent_authed[offset..offset + range.len()]);
|
||||
offset += range.len();
|
||||
@@ -215,7 +233,7 @@ impl From<CompressedPartialTranscript> for PartialTranscript {
|
||||
|
||||
let mut offset = 0;
|
||||
|
||||
for range in compressed.recv_idx.iter_ranges() {
|
||||
for range in compressed.recv_idx.iter() {
|
||||
received[range.clone()]
|
||||
.copy_from_slice(&compressed.received_authed[offset..offset + range.len()]);
|
||||
offset += range.len();
|
||||
@@ -304,12 +322,16 @@ impl PartialTranscript {
|
||||
|
||||
/// Returns the index of sent data which haven't been authenticated.
|
||||
pub fn sent_unauthed(&self) -> RangeSet<usize> {
|
||||
(0..self.sent.len()).difference(&self.sent_authed_idx)
|
||||
(0..self.sent.len())
|
||||
.difference(&self.sent_authed_idx)
|
||||
.into_set()
|
||||
}
|
||||
|
||||
/// Returns the index of received data which haven't been authenticated.
|
||||
pub fn received_unauthed(&self) -> RangeSet<usize> {
|
||||
(0..self.received.len()).difference(&self.received_authed_idx)
|
||||
(0..self.received.len())
|
||||
.difference(&self.received_authed_idx)
|
||||
.into_set()
|
||||
}
|
||||
|
||||
/// Returns an iterator over the authenticated data in the transcript.
|
||||
@@ -319,7 +341,7 @@ impl PartialTranscript {
|
||||
Direction::Received => (&self.received, &self.received_authed_idx),
|
||||
};
|
||||
|
||||
authed.iter().map(|i| data[i])
|
||||
authed.iter_values().map(move |i| data[i])
|
||||
}
|
||||
|
||||
/// Unions the authenticated data of this transcript with another.
|
||||
@@ -339,24 +361,20 @@ impl PartialTranscript {
|
||||
"received data are not the same length"
|
||||
);
|
||||
|
||||
for range in other
|
||||
.sent_authed_idx
|
||||
.difference(&self.sent_authed_idx)
|
||||
.iter_ranges()
|
||||
{
|
||||
for range in other.sent_authed_idx.difference(&self.sent_authed_idx) {
|
||||
self.sent[range.clone()].copy_from_slice(&other.sent[range]);
|
||||
}
|
||||
|
||||
for range in other
|
||||
.received_authed_idx
|
||||
.difference(&self.received_authed_idx)
|
||||
.iter_ranges()
|
||||
{
|
||||
self.received[range.clone()].copy_from_slice(&other.received[range]);
|
||||
}
|
||||
|
||||
self.sent_authed_idx = self.sent_authed_idx.union(&other.sent_authed_idx);
|
||||
self.received_authed_idx = self.received_authed_idx.union(&other.received_authed_idx);
|
||||
self.sent_authed_idx.union_mut(&other.sent_authed_idx);
|
||||
self.received_authed_idx
|
||||
.union_mut(&other.received_authed_idx);
|
||||
}
|
||||
|
||||
/// Unions an authenticated subsequence into this transcript.
|
||||
@@ -368,11 +386,11 @@ impl PartialTranscript {
|
||||
match direction {
|
||||
Direction::Sent => {
|
||||
seq.copy_to(&mut self.sent);
|
||||
self.sent_authed_idx = self.sent_authed_idx.union(&seq.idx);
|
||||
self.sent_authed_idx.union_mut(&seq.idx);
|
||||
}
|
||||
Direction::Received => {
|
||||
seq.copy_to(&mut self.received);
|
||||
self.received_authed_idx = self.received_authed_idx.union(&seq.idx);
|
||||
self.received_authed_idx.union_mut(&seq.idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -383,10 +401,10 @@ impl PartialTranscript {
|
||||
///
|
||||
/// * `value` - The value to set the unauthenticated bytes to
|
||||
pub fn set_unauthed(&mut self, value: u8) {
|
||||
for range in self.sent_unauthed().iter_ranges() {
|
||||
for range in self.sent_unauthed().iter() {
|
||||
self.sent[range].fill(value);
|
||||
}
|
||||
for range in self.received_unauthed().iter_ranges() {
|
||||
for range in self.received_unauthed().iter() {
|
||||
self.received[range].fill(value);
|
||||
}
|
||||
}
|
||||
@@ -401,13 +419,13 @@ impl PartialTranscript {
|
||||
pub fn set_unauthed_range(&mut self, value: u8, direction: Direction, range: Range<usize>) {
|
||||
match direction {
|
||||
Direction::Sent => {
|
||||
for range in range.difference(&self.sent_authed_idx).iter_ranges() {
|
||||
self.sent[range].fill(value);
|
||||
for r in range.difference(&self.sent_authed_idx) {
|
||||
self.sent[r].fill(value);
|
||||
}
|
||||
}
|
||||
Direction::Received => {
|
||||
for range in range.difference(&self.received_authed_idx).iter_ranges() {
|
||||
self.received[range].fill(value);
|
||||
for r in range.difference(&self.received_authed_idx) {
|
||||
self.received[r].fill(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -485,7 +503,7 @@ impl Subsequence {
|
||||
/// Panics if the subsequence ranges are out of bounds.
|
||||
pub(crate) fn copy_to(&self, dest: &mut [u8]) {
|
||||
let mut offset = 0;
|
||||
for range in self.idx.iter_ranges() {
|
||||
for range in self.idx.iter() {
|
||||
dest[range.clone()].copy_from_slice(&self.data[offset..offset + range.len()]);
|
||||
offset += range.len();
|
||||
}
|
||||
@@ -610,12 +628,7 @@ mod validation {
|
||||
mut partial_transcript: CompressedPartialTranscriptUnchecked,
|
||||
) {
|
||||
// Change the total to be less than the last range's end bound.
|
||||
let end = partial_transcript
|
||||
.sent_idx
|
||||
.iter_ranges()
|
||||
.next_back()
|
||||
.unwrap()
|
||||
.end;
|
||||
let end = partial_transcript.sent_idx.iter().next_back().unwrap().end;
|
||||
|
||||
partial_transcript.sent_total = end - 1;
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
use std::{collections::HashSet, fmt};
|
||||
|
||||
use rangeset::{ToRangeSet, UnionMut};
|
||||
use rangeset::set::ToRangeSet;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use std::{collections::HashMap, fmt};
|
||||
|
||||
use rangeset::{RangeSet, UnionMut};
|
||||
use rangeset::set::RangeSet;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
@@ -103,7 +103,7 @@ impl EncodingProof {
|
||||
}
|
||||
|
||||
expected_leaf.clear();
|
||||
for range in idx.iter_ranges() {
|
||||
for range in idx.iter() {
|
||||
encoder.encode_data(*direction, range.clone(), &data[range], &mut expected_leaf);
|
||||
}
|
||||
expected_leaf.extend_from_slice(blinder.as_bytes());
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use bimap::BiMap;
|
||||
use rangeset::{RangeSet, UnionMut};
|
||||
use rangeset::set::RangeSet;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
@@ -99,7 +99,7 @@ impl EncodingTree {
|
||||
let blinder: Blinder = rand::random();
|
||||
|
||||
encoding.clear();
|
||||
for range in idx.iter_ranges() {
|
||||
for range in idx.iter() {
|
||||
provider
|
||||
.provide_encoding(direction, range, &mut encoding)
|
||||
.map_err(|_| EncodingTreeError::MissingEncoding { index: idx.clone() })?;
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
//! Transcript proofs.
|
||||
|
||||
use rangeset::{Cover, Difference, Subset, ToRangeSet, UnionMut};
|
||||
use rangeset::{
|
||||
iter::RangeIterator,
|
||||
ops::{Cover, Set},
|
||||
set::ToRangeSet,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{collections::HashSet, fmt};
|
||||
|
||||
@@ -144,7 +148,7 @@ impl TranscriptProof {
|
||||
}
|
||||
|
||||
buffer.clear();
|
||||
for range in idx.iter_ranges() {
|
||||
for range in idx.iter() {
|
||||
buffer.extend_from_slice(&plaintext[range]);
|
||||
}
|
||||
|
||||
@@ -366,7 +370,7 @@ impl<'a> TranscriptProofBuilder<'a> {
|
||||
if idx.is_subset(committed) {
|
||||
self.query_idx.union(&direction, &idx);
|
||||
} else {
|
||||
let missing = idx.difference(committed);
|
||||
let missing = idx.difference(committed).into_set();
|
||||
return Err(TranscriptProofBuilderError::new(
|
||||
BuilderErrorKind::MissingCommitment,
|
||||
format!(
|
||||
@@ -582,7 +586,7 @@ impl fmt::Display for TranscriptProofBuilderError {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use rand::{Rng, SeedableRng};
|
||||
use rangeset::RangeSet;
|
||||
use rangeset::prelude::*;
|
||||
use rstest::rstest;
|
||||
use tlsn_data_fixtures::http::{request::GET_WITH_HEADER, response::OK_JSON};
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ pub const DEFAULT_UPLOAD_SIZE: usize = 1024;
|
||||
pub const DEFAULT_DOWNLOAD_SIZE: usize = 4096;
|
||||
pub const DEFAULT_DEFER_DECRYPTION: bool = true;
|
||||
pub const DEFAULT_MEMORY_PROFILE: bool = false;
|
||||
pub const DEFAULT_REVEAL_ALL: bool = false;
|
||||
|
||||
pub const WARM_UP_BENCH: Bench = Bench {
|
||||
group: None,
|
||||
@@ -20,6 +21,7 @@ pub const WARM_UP_BENCH: Bench = Bench {
|
||||
download_size: 4096,
|
||||
defer_decryption: true,
|
||||
memory_profile: false,
|
||||
reveal_all: true,
|
||||
};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@@ -79,6 +81,8 @@ pub struct BenchGroupItem {
|
||||
pub defer_decryption: Option<bool>,
|
||||
#[serde(rename = "memory-profile")]
|
||||
pub memory_profile: Option<bool>,
|
||||
#[serde(rename = "reveal-all")]
|
||||
pub reveal_all: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -97,6 +101,8 @@ pub struct BenchItem {
|
||||
pub defer_decryption: Option<bool>,
|
||||
#[serde(rename = "memory-profile")]
|
||||
pub memory_profile: Option<bool>,
|
||||
#[serde(rename = "reveal-all")]
|
||||
pub reveal_all: Option<bool>,
|
||||
}
|
||||
|
||||
impl BenchItem {
|
||||
@@ -132,6 +138,10 @@ impl BenchItem {
|
||||
if self.memory_profile.is_none() {
|
||||
self.memory_profile = group.memory_profile;
|
||||
}
|
||||
|
||||
if self.reveal_all.is_none() {
|
||||
self.reveal_all = group.reveal_all;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_bench(&self) -> Bench {
|
||||
@@ -145,6 +155,7 @@ impl BenchItem {
|
||||
download_size: self.download_size.unwrap_or(DEFAULT_DOWNLOAD_SIZE),
|
||||
defer_decryption: self.defer_decryption.unwrap_or(DEFAULT_DEFER_DECRYPTION),
|
||||
memory_profile: self.memory_profile.unwrap_or(DEFAULT_MEMORY_PROFILE),
|
||||
reveal_all: self.reveal_all.unwrap_or(DEFAULT_REVEAL_ALL),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -164,6 +175,8 @@ pub struct Bench {
|
||||
pub defer_decryption: bool,
|
||||
#[serde(rename = "memory-profile")]
|
||||
pub memory_profile: bool,
|
||||
#[serde(rename = "reveal-all")]
|
||||
pub reveal_all: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
||||
@@ -22,7 +22,10 @@ pub enum CmdOutput {
|
||||
GetTests(Vec<String>),
|
||||
Test(TestOutput),
|
||||
Bench(BenchOutput),
|
||||
Fail { reason: Option<String> },
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
Fail {
|
||||
reason: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
||||
@@ -98,14 +98,27 @@ pub async fn bench_prover(provider: &IoProvider, config: &Bench) -> Result<Prove
|
||||
|
||||
let mut builder = ProveConfig::builder(prover.transcript());
|
||||
|
||||
// When reveal_all is false (the default), we exclude 1 byte to avoid the
|
||||
// reveal-all optimization and benchmark the realistic ZK authentication path.
|
||||
let reveal_sent_range = if config.reveal_all {
|
||||
0..sent_len
|
||||
} else {
|
||||
0..sent_len.saturating_sub(1)
|
||||
};
|
||||
let reveal_recv_range = if config.reveal_all {
|
||||
0..recv_len
|
||||
} else {
|
||||
0..recv_len.saturating_sub(1)
|
||||
};
|
||||
|
||||
builder
|
||||
.server_identity()
|
||||
.reveal_sent(&(0..sent_len))?
|
||||
.reveal_recv(&(0..recv_len))?;
|
||||
.reveal_sent(&reveal_sent_range)?
|
||||
.reveal_recv(&reveal_recv_range)?;
|
||||
|
||||
let config = builder.build()?;
|
||||
let prove_config = builder.build()?;
|
||||
|
||||
prover.prove(&config).await?;
|
||||
prover.prove(&prove_config).await?;
|
||||
prover.close().await?;
|
||||
|
||||
let time_total = time_start.elapsed().as_millis();
|
||||
|
||||
@@ -7,10 +7,9 @@ publish = false
|
||||
[dependencies]
|
||||
tlsn-harness-core = { workspace = true }
|
||||
# tlsn-server-fixture = { workspace = true }
|
||||
charming = { version = "0.5.1", features = ["ssr"] }
|
||||
csv = "1.3.0"
|
||||
charming = { version = "0.6.0", features = ["ssr"] }
|
||||
clap = { workspace = true, features = ["derive", "env"] }
|
||||
itertools = "0.14.0"
|
||||
polars = { version = "0.44", features = ["csv", "lazy"] }
|
||||
toml = { workspace = true }
|
||||
|
||||
|
||||
|
||||
111
crates/harness/plot/README.md
Normal file
111
crates/harness/plot/README.md
Normal file
@@ -0,0 +1,111 @@
|
||||
# TLSNotary Benchmark Plot Tool
|
||||
|
||||
Generates interactive HTML and SVG plots from TLSNotary benchmark results. Supports comparing multiple benchmark runs (e.g., before/after optimization, native vs browser).
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
tlsn-harness-plot <TOML> <CSV>... [OPTIONS]
|
||||
```
|
||||
|
||||
### Arguments
|
||||
|
||||
- `<TOML>` - Path to Bench.toml file defining benchmark structure
|
||||
- `<CSV>...` - One or more CSV files with benchmark results
|
||||
|
||||
### Options
|
||||
|
||||
- `-l, --labels <LABEL>...` - Labels for each dataset (optional)
|
||||
- If omitted, datasets are labeled "Dataset 1", "Dataset 2", etc.
|
||||
- Number of labels must match number of CSV files
|
||||
- `--min-max-band` - Add min/max bands to plots showing variance
|
||||
- `-h, --help` - Print help information
|
||||
|
||||
## Examples
|
||||
|
||||
### Single Dataset
|
||||
|
||||
```bash
|
||||
tlsn-harness-plot bench.toml results.csv
|
||||
```
|
||||
|
||||
Generates plots from a single benchmark run.
|
||||
|
||||
### Compare Two Runs
|
||||
|
||||
```bash
|
||||
tlsn-harness-plot bench.toml before.csv after.csv \
|
||||
--labels "Before Optimization" "After Optimization"
|
||||
```
|
||||
|
||||
Overlays two datasets to compare performance improvements.
|
||||
|
||||
### Multiple Datasets
|
||||
|
||||
```bash
|
||||
tlsn-harness-plot bench.toml native.csv browser.csv wasm.csv \
|
||||
--labels "Native" "Browser" "WASM"
|
||||
```
|
||||
|
||||
Compare three different runtime environments.
|
||||
|
||||
### With Min/Max Bands
|
||||
|
||||
```bash
|
||||
tlsn-harness-plot bench.toml run1.csv run2.csv \
|
||||
--labels "Config A" "Config B" \
|
||||
--min-max-band
|
||||
```
|
||||
|
||||
Shows variance ranges for each dataset.
|
||||
|
||||
## Output Files
|
||||
|
||||
The tool generates two files per benchmark group:
|
||||
|
||||
- `<output>.html` - Interactive HTML chart (zoomable, hoverable)
|
||||
- `<output>.svg` - Static SVG image for documentation
|
||||
|
||||
Default output filenames:
|
||||
- `runtime_vs_bandwidth.{html,svg}` - When `protocol_latency` is defined in group
|
||||
- `runtime_vs_latency.{html,svg}` - When `bandwidth` is defined in group
|
||||
|
||||
## Plot Format
|
||||
|
||||
Each dataset displays:
|
||||
- **Solid line** - Total runtime (preprocessing + online phase)
|
||||
- **Dashed line** - Online phase only
|
||||
- **Shaded area** (optional) - Min/max variance bands
|
||||
|
||||
Different datasets automatically use distinct colors for easy comparison.
|
||||
|
||||
## CSV Format
|
||||
|
||||
Expected columns in each CSV file:
|
||||
- `group` - Benchmark group name (must match TOML)
|
||||
- `bandwidth` - Network bandwidth in Kbps (for bandwidth plots)
|
||||
- `latency` - Network latency in ms (for latency plots)
|
||||
- `time_preprocess` - Preprocessing time in ms
|
||||
- `time_online` - Online phase time in ms
|
||||
- `time_total` - Total runtime in ms
|
||||
|
||||
## TOML Format
|
||||
|
||||
The benchmark TOML file defines groups with either:
|
||||
|
||||
```toml
|
||||
[[group]]
|
||||
name = "my_benchmark"
|
||||
protocol_latency = 50 # Fixed latency for bandwidth plots
|
||||
# OR
|
||||
bandwidth = 10000 # Fixed bandwidth for latency plots
|
||||
```
|
||||
|
||||
All datasets must use the same TOML file to ensure consistent benchmark structure.
|
||||
|
||||
## Tips
|
||||
|
||||
- Use descriptive labels to make plots self-documenting
|
||||
- Keep CSV files from the same benchmark configuration for valid comparisons
|
||||
- Min/max bands are useful for showing stability but can clutter plots with many datasets
|
||||
- Interactive HTML plots support zooming and hovering for detailed values
|
||||
@@ -1,17 +1,18 @@
|
||||
use std::f32;
|
||||
|
||||
use charming::{
|
||||
Chart, HtmlRenderer,
|
||||
Chart, HtmlRenderer, ImageRenderer,
|
||||
component::{Axis, Legend, Title},
|
||||
element::{AreaStyle, LineStyle, NameLocation, Orient, TextStyle, Tooltip, Trigger},
|
||||
element::{
|
||||
AreaStyle, ItemStyle, LineStyle, LineStyleType, NameLocation, Orient, TextStyle, Tooltip,
|
||||
Trigger,
|
||||
},
|
||||
series::Line,
|
||||
theme::Theme,
|
||||
};
|
||||
use clap::Parser;
|
||||
use harness_core::bench::{BenchItems, Measurement};
|
||||
use itertools::Itertools;
|
||||
|
||||
const THEME: Theme = Theme::Default;
|
||||
use harness_core::bench::BenchItems;
|
||||
use polars::prelude::*;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about)]
|
||||
@@ -19,72 +20,131 @@ struct Cli {
|
||||
/// Path to the Bench.toml file with benchmark spec
|
||||
toml: String,
|
||||
|
||||
/// Path to the CSV file with benchmark results
|
||||
csv: String,
|
||||
/// Paths to CSV files with benchmark results (one or more)
|
||||
csv: Vec<String>,
|
||||
|
||||
/// Prover kind: native or browser
|
||||
#[arg(short, long, value_enum, default_value = "native")]
|
||||
prover_kind: ProverKind,
|
||||
/// Labels for each dataset (optional, defaults to "Dataset 1", "Dataset 2", etc.)
|
||||
#[arg(short, long, num_args = 0..)]
|
||||
labels: Vec<String>,
|
||||
|
||||
/// Add min/max bands to plots
|
||||
#[arg(long, default_value_t = false)]
|
||||
min_max_band: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
|
||||
enum ProverKind {
|
||||
Native,
|
||||
Browser,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ProverKind {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
ProverKind::Native => write!(f, "Native"),
|
||||
ProverKind::Browser => write!(f, "Browser"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let cli = Cli::parse();
|
||||
|
||||
let mut rdr = csv::Reader::from_path(&cli.csv)?;
|
||||
if cli.csv.is_empty() {
|
||||
return Err("At least one CSV file must be provided".into());
|
||||
}
|
||||
|
||||
// Generate labels if not provided
|
||||
let labels: Vec<String> = if cli.labels.is_empty() {
|
||||
cli.csv
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, _)| format!("Dataset {}", i + 1))
|
||||
.collect()
|
||||
} else if cli.labels.len() != cli.csv.len() {
|
||||
return Err(format!(
|
||||
"Number of labels ({}) must match number of CSV files ({})",
|
||||
cli.labels.len(),
|
||||
cli.csv.len()
|
||||
)
|
||||
.into());
|
||||
} else {
|
||||
cli.labels.clone()
|
||||
};
|
||||
|
||||
// Load all CSVs and add dataset label
|
||||
let mut dfs = Vec::new();
|
||||
for (csv_path, label) in cli.csv.iter().zip(labels.iter()) {
|
||||
let mut df = CsvReadOptions::default()
|
||||
.try_into_reader_with_file_path(Some(csv_path.clone().into()))?
|
||||
.finish()?;
|
||||
|
||||
let label_series = Series::new("dataset_label".into(), vec![label.as_str(); df.height()]);
|
||||
df.with_column(label_series)?;
|
||||
dfs.push(df);
|
||||
}
|
||||
|
||||
// Combine all dataframes
|
||||
let df = dfs
|
||||
.into_iter()
|
||||
.reduce(|acc, df| acc.vstack(&df).unwrap())
|
||||
.unwrap();
|
||||
|
||||
let items: BenchItems = toml::from_str(&std::fs::read_to_string(&cli.toml)?)?;
|
||||
let groups = items.group;
|
||||
|
||||
// Prepare data for plotting.
|
||||
let all_data: Vec<Measurement> = rdr
|
||||
.deserialize::<Measurement>()
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
for group in groups {
|
||||
if group.protocol_latency.is_some() {
|
||||
let latency = group.protocol_latency.unwrap();
|
||||
plot_runtime_vs(
|
||||
&all_data,
|
||||
cli.min_max_band,
|
||||
&group.name,
|
||||
|r| r.bandwidth as f32 / 1000.0, // Kbps to Mbps
|
||||
"Runtime vs Bandwidth",
|
||||
format!("{} ms Latency, {} mode", latency, cli.prover_kind),
|
||||
"runtime_vs_bandwidth.html",
|
||||
"Bandwidth (Mbps)",
|
||||
)?;
|
||||
// Determine which field varies in benches for this group
|
||||
let benches_in_group: Vec<_> = items
|
||||
.bench
|
||||
.iter()
|
||||
.filter(|b| b.group.as_deref() == Some(&group.name))
|
||||
.collect();
|
||||
|
||||
if benches_in_group.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if group.bandwidth.is_some() {
|
||||
let bandwidth = group.bandwidth.unwrap();
|
||||
// Check which field has varying values
|
||||
let bandwidth_varies = benches_in_group
|
||||
.windows(2)
|
||||
.any(|w| w[0].bandwidth != w[1].bandwidth);
|
||||
let latency_varies = benches_in_group
|
||||
.windows(2)
|
||||
.any(|w| w[0].protocol_latency != w[1].protocol_latency);
|
||||
let download_size_varies = benches_in_group
|
||||
.windows(2)
|
||||
.any(|w| w[0].download_size != w[1].download_size);
|
||||
|
||||
if download_size_varies {
|
||||
let upload_size = group.upload_size.unwrap_or(1024);
|
||||
plot_runtime_vs(
|
||||
&all_data,
|
||||
&df,
|
||||
&labels,
|
||||
cli.min_max_band,
|
||||
&group.name,
|
||||
|r| r.latency as f32,
|
||||
"download_size",
|
||||
1.0 / 1024.0, // bytes to KB
|
||||
"Runtime vs Response Size",
|
||||
format!("{} bytes upload size", upload_size),
|
||||
"runtime_vs_download_size",
|
||||
"Response Size (KB)",
|
||||
true, // legend on left
|
||||
)?;
|
||||
} else if bandwidth_varies {
|
||||
let latency = group.protocol_latency.unwrap_or(50);
|
||||
plot_runtime_vs(
|
||||
&df,
|
||||
&labels,
|
||||
cli.min_max_band,
|
||||
&group.name,
|
||||
"bandwidth",
|
||||
1.0 / 1000.0, // Kbps to Mbps
|
||||
"Runtime vs Bandwidth",
|
||||
format!("{} ms Latency", latency),
|
||||
"runtime_vs_bandwidth",
|
||||
"Bandwidth (Mbps)",
|
||||
false, // legend on right
|
||||
)?;
|
||||
} else if latency_varies {
|
||||
let bandwidth = group.bandwidth.unwrap_or(1000);
|
||||
plot_runtime_vs(
|
||||
&df,
|
||||
&labels,
|
||||
cli.min_max_band,
|
||||
&group.name,
|
||||
"latency",
|
||||
1.0,
|
||||
"Runtime vs Latency",
|
||||
format!("{} bps bandwidth, {} mode", bandwidth, cli.prover_kind),
|
||||
"runtime_vs_latency.html",
|
||||
format!("{} bps bandwidth", bandwidth),
|
||||
"runtime_vs_latency",
|
||||
"Latency (ms)",
|
||||
true, // legend on left
|
||||
)?;
|
||||
}
|
||||
}
|
||||
@@ -92,83 +152,51 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct DataPoint {
|
||||
min: f32,
|
||||
mean: f32,
|
||||
max: f32,
|
||||
}
|
||||
|
||||
struct Points {
|
||||
preprocess: DataPoint,
|
||||
online: DataPoint,
|
||||
total: DataPoint,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn plot_runtime_vs<Fx>(
|
||||
all_data: &[Measurement],
|
||||
fn plot_runtime_vs(
|
||||
df: &DataFrame,
|
||||
labels: &[String],
|
||||
show_min_max: bool,
|
||||
group: &str,
|
||||
x_value: Fx,
|
||||
x_col: &str,
|
||||
x_scale: f32,
|
||||
title: &str,
|
||||
subtitle: String,
|
||||
output_file: &str,
|
||||
x_axis_label: &str,
|
||||
) -> Result<Chart, Box<dyn std::error::Error>>
|
||||
where
|
||||
Fx: Fn(&Measurement) -> f32,
|
||||
{
|
||||
fn data_point(values: &[f32]) -> DataPoint {
|
||||
let mean = values.iter().copied().sum::<f32>() / values.len() as f32;
|
||||
let max = values.iter().copied().reduce(f32::max).unwrap_or_default();
|
||||
let min = values.iter().copied().reduce(f32::min).unwrap_or_default();
|
||||
DataPoint { min, mean, max }
|
||||
}
|
||||
legend_left: bool,
|
||||
) -> Result<Chart, Box<dyn std::error::Error>> {
|
||||
let stats_df = df
|
||||
.clone()
|
||||
.lazy()
|
||||
.filter(col("group").eq(lit(group)))
|
||||
.with_column((col(x_col).cast(DataType::Float32) * lit(x_scale)).alias("x"))
|
||||
.with_columns([
|
||||
(col("time_preprocess").cast(DataType::Float32) / lit(1000.0)).alias("preprocess"),
|
||||
(col("time_online").cast(DataType::Float32) / lit(1000.0)).alias("online"),
|
||||
(col("time_total").cast(DataType::Float32) / lit(1000.0)).alias("total"),
|
||||
])
|
||||
.group_by([col("x"), col("dataset_label")])
|
||||
.agg([
|
||||
col("preprocess").min().alias("preprocess_min"),
|
||||
col("preprocess").mean().alias("preprocess_mean"),
|
||||
col("preprocess").max().alias("preprocess_max"),
|
||||
col("online").min().alias("online_min"),
|
||||
col("online").mean().alias("online_mean"),
|
||||
col("online").max().alias("online_max"),
|
||||
col("total").min().alias("total_min"),
|
||||
col("total").mean().alias("total_mean"),
|
||||
col("total").max().alias("total_max"),
|
||||
])
|
||||
.sort(["dataset_label", "x"], Default::default())
|
||||
.collect()?;
|
||||
|
||||
let stats: Vec<(f32, Points)> = all_data
|
||||
.iter()
|
||||
.filter(|r| r.group.as_deref() == Some(group))
|
||||
.map(|r| {
|
||||
(
|
||||
x_value(r),
|
||||
r.time_preprocess as f32 / 1000.0, // ms to s
|
||||
r.time_online as f32 / 1000.0,
|
||||
r.time_total as f32 / 1000.0,
|
||||
)
|
||||
})
|
||||
.sorted_by(|a, b| a.0.partial_cmp(&b.0).unwrap())
|
||||
.chunk_by(|entry| entry.0)
|
||||
.into_iter()
|
||||
.map(|(x, group)| {
|
||||
let group_vec: Vec<_> = group.collect();
|
||||
let preprocess = data_point(
|
||||
&group_vec
|
||||
.iter()
|
||||
.map(|(_, t, _, _)| *t)
|
||||
.collect::<Vec<f32>>(),
|
||||
);
|
||||
let online = data_point(
|
||||
&group_vec
|
||||
.iter()
|
||||
.map(|(_, _, t, _)| *t)
|
||||
.collect::<Vec<f32>>(),
|
||||
);
|
||||
let total = data_point(
|
||||
&group_vec
|
||||
.iter()
|
||||
.map(|(_, _, _, t)| *t)
|
||||
.collect::<Vec<f32>>(),
|
||||
);
|
||||
(
|
||||
x,
|
||||
Points {
|
||||
preprocess,
|
||||
online,
|
||||
total,
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
// Build legend entries
|
||||
let mut legend_data = Vec::new();
|
||||
for label in labels {
|
||||
legend_data.push(format!("Total Mean ({})", label));
|
||||
legend_data.push(format!("Online Mean ({})", label));
|
||||
}
|
||||
|
||||
let mut chart = Chart::new()
|
||||
.title(
|
||||
@@ -179,14 +207,6 @@ where
|
||||
.subtext_style(TextStyle::new().font_size(16)),
|
||||
)
|
||||
.tooltip(Tooltip::new().trigger(Trigger::Axis))
|
||||
.legend(
|
||||
Legend::new()
|
||||
.data(vec!["Preprocess Mean", "Online Mean", "Total Mean"])
|
||||
.top("80")
|
||||
.right("110")
|
||||
.orient(Orient::Vertical)
|
||||
.item_gap(10),
|
||||
)
|
||||
.x_axis(
|
||||
Axis::new()
|
||||
.name(x_axis_label)
|
||||
@@ -205,73 +225,156 @@ where
|
||||
.name_text_style(TextStyle::new().font_size(21)),
|
||||
);
|
||||
|
||||
chart = add_mean_series(chart, &stats, "Preprocess Mean", |p| p.preprocess.mean);
|
||||
chart = add_mean_series(chart, &stats, "Online Mean", |p| p.online.mean);
|
||||
chart = add_mean_series(chart, &stats, "Total Mean", |p| p.total.mean);
|
||||
// Add legend with conditional positioning
|
||||
let legend = Legend::new()
|
||||
.data(legend_data)
|
||||
.top("80")
|
||||
.orient(Orient::Vertical)
|
||||
.item_gap(10);
|
||||
|
||||
if show_min_max {
|
||||
chart = add_min_max_band(
|
||||
chart,
|
||||
&stats,
|
||||
"Preprocess Min/Max",
|
||||
|p| &p.preprocess,
|
||||
"#ccc",
|
||||
);
|
||||
chart = add_min_max_band(chart, &stats, "Online Min/Max", |p| &p.online, "#ccc");
|
||||
chart = add_min_max_band(chart, &stats, "Total Min/Max", |p| &p.total, "#ccc");
|
||||
let legend = if legend_left {
|
||||
legend.left("110")
|
||||
} else {
|
||||
legend.right("110")
|
||||
};
|
||||
|
||||
chart = chart.legend(legend);
|
||||
|
||||
// Define colors for each dataset
|
||||
let colors = vec![
|
||||
"#5470c6", "#91cc75", "#fac858", "#ee6666", "#73c0de", "#3ba272", "#fc8452", "#9a60b4",
|
||||
];
|
||||
|
||||
for (idx, label) in labels.iter().enumerate() {
|
||||
let color = colors.get(idx % colors.len()).unwrap();
|
||||
|
||||
// Total time - solid line
|
||||
chart = add_dataset_series(
|
||||
&chart,
|
||||
&stats_df,
|
||||
label,
|
||||
&format!("Total Mean ({})", label),
|
||||
"total_mean",
|
||||
false,
|
||||
color,
|
||||
)?;
|
||||
|
||||
// Online time - dashed line (same color as total)
|
||||
chart = add_dataset_series(
|
||||
&chart,
|
||||
&stats_df,
|
||||
label,
|
||||
&format!("Online Mean ({})", label),
|
||||
"online_mean",
|
||||
true,
|
||||
color,
|
||||
)?;
|
||||
|
||||
if show_min_max {
|
||||
chart = add_dataset_min_max_band(
|
||||
&chart,
|
||||
&stats_df,
|
||||
label,
|
||||
&format!("Total Min/Max ({})", label),
|
||||
"total",
|
||||
color,
|
||||
)?;
|
||||
}
|
||||
}
|
||||
// Save the chart as HTML file.
|
||||
// Save the chart as HTML file (no theme)
|
||||
HtmlRenderer::new(title, 1000, 800)
|
||||
.theme(THEME)
|
||||
.save(&chart, output_file)
|
||||
.save(&chart, &format!("{}.html", output_file))
|
||||
.unwrap();
|
||||
|
||||
// Save SVG with default theme
|
||||
ImageRenderer::new(1000, 800)
|
||||
.theme(Theme::Default)
|
||||
.save(&chart, &format!("{}.svg", output_file))
|
||||
.unwrap();
|
||||
|
||||
// Save SVG with dark theme
|
||||
ImageRenderer::new(1000, 800)
|
||||
.theme(Theme::Dark)
|
||||
.save(&chart, &format!("{}_dark.svg", output_file))
|
||||
.unwrap();
|
||||
|
||||
Ok(chart)
|
||||
}
|
||||
|
||||
fn add_mean_series(
|
||||
chart: Chart,
|
||||
stats: &[(f32, Points)],
|
||||
name: &str,
|
||||
extract: impl Fn(&Points) -> f32,
|
||||
) -> Chart {
|
||||
chart.series(
|
||||
Line::new()
|
||||
.name(name)
|
||||
.data(
|
||||
stats
|
||||
.iter()
|
||||
.map(|(x, points)| vec![*x, extract(points)])
|
||||
.collect(),
|
||||
)
|
||||
.symbol_size(6),
|
||||
)
|
||||
fn add_dataset_series(
|
||||
chart: &Chart,
|
||||
df: &DataFrame,
|
||||
dataset_label: &str,
|
||||
series_name: &str,
|
||||
col_name: &str,
|
||||
dashed: bool,
|
||||
color: &str,
|
||||
) -> Result<Chart, Box<dyn std::error::Error>> {
|
||||
// Filter for specific dataset
|
||||
let mask = df.column("dataset_label")?.str()?.equal(dataset_label);
|
||||
let filtered = df.filter(&mask)?;
|
||||
|
||||
let x = filtered.column("x")?.f32()?;
|
||||
let y = filtered.column(col_name)?.f32()?;
|
||||
|
||||
let data: Vec<Vec<f32>> = x
|
||||
.into_iter()
|
||||
.zip(y.into_iter())
|
||||
.filter_map(|(x, y)| Some(vec![x?, y?]))
|
||||
.collect();
|
||||
|
||||
let mut line = Line::new()
|
||||
.name(series_name)
|
||||
.data(data)
|
||||
.symbol_size(6)
|
||||
.item_style(ItemStyle::new().color(color));
|
||||
|
||||
let mut line_style = LineStyle::new();
|
||||
if dashed {
|
||||
line_style = line_style.type_(LineStyleType::Dashed);
|
||||
}
|
||||
line = line.line_style(line_style.color(color));
|
||||
|
||||
Ok(chart.clone().series(line))
|
||||
}
|
||||
|
||||
fn add_min_max_band(
|
||||
chart: Chart,
|
||||
stats: &[(f32, Points)],
|
||||
fn add_dataset_min_max_band(
|
||||
chart: &Chart,
|
||||
df: &DataFrame,
|
||||
dataset_label: &str,
|
||||
name: &str,
|
||||
extract: impl Fn(&Points) -> &DataPoint,
|
||||
col_prefix: &str,
|
||||
color: &str,
|
||||
) -> Chart {
|
||||
chart.series(
|
||||
) -> Result<Chart, Box<dyn std::error::Error>> {
|
||||
// Filter for specific dataset
|
||||
let mask = df.column("dataset_label")?.str()?.equal(dataset_label);
|
||||
let filtered = df.filter(&mask)?;
|
||||
|
||||
let x = filtered.column("x")?.f32()?;
|
||||
let min_col = filtered.column(&format!("{}_min", col_prefix))?.f32()?;
|
||||
let max_col = filtered.column(&format!("{}_max", col_prefix))?.f32()?;
|
||||
|
||||
let max_data: Vec<Vec<f32>> = x
|
||||
.into_iter()
|
||||
.zip(max_col.into_iter())
|
||||
.filter_map(|(x, y)| Some(vec![x?, y?]))
|
||||
.collect();
|
||||
|
||||
let min_data: Vec<Vec<f32>> = x
|
||||
.into_iter()
|
||||
.zip(min_col.into_iter())
|
||||
.filter_map(|(x, y)| Some(vec![x?, y?]))
|
||||
.rev()
|
||||
.collect();
|
||||
|
||||
let data: Vec<Vec<f32>> = max_data.into_iter().chain(min_data).collect();
|
||||
|
||||
Ok(chart.clone().series(
|
||||
Line::new()
|
||||
.name(name)
|
||||
.data(
|
||||
stats
|
||||
.iter()
|
||||
.map(|(x, points)| vec![*x, extract(points).max])
|
||||
.chain(
|
||||
stats
|
||||
.iter()
|
||||
.rev()
|
||||
.map(|(x, points)| vec![*x, extract(points).min]),
|
||||
)
|
||||
.collect(),
|
||||
)
|
||||
.data(data)
|
||||
.show_symbol(false)
|
||||
.line_style(LineStyle::new().opacity(0.0))
|
||||
.area_style(AreaStyle::new().opacity(0.3).color(color)),
|
||||
)
|
||||
))
|
||||
}
|
||||
|
||||
105
crates/harness/plot/data/bandwidth.ipynb
Normal file
105
crates/harness/plot/data/bandwidth.ipynb
Normal file
File diff suppressed because one or more lines are too long
163
crates/harness/plot/data/download.ipynb
Normal file
163
crates/harness/plot/data/download.ipynb
Normal file
File diff suppressed because one or more lines are too long
92
crates/harness/plot/data/latency.ipynb
Normal file
92
crates/harness/plot/data/latency.ipynb
Normal file
File diff suppressed because one or more lines are too long
25
crates/harness/toml/bandwidth.toml
Normal file
25
crates/harness/toml/bandwidth.toml
Normal file
@@ -0,0 +1,25 @@
|
||||
#### Bandwidth ####
|
||||
|
||||
[[group]]
|
||||
name = "bandwidth"
|
||||
protocol_latency = 25
|
||||
|
||||
[[bench]]
|
||||
group = "bandwidth"
|
||||
bandwidth = 10
|
||||
|
||||
[[bench]]
|
||||
group = "bandwidth"
|
||||
bandwidth = 50
|
||||
|
||||
[[bench]]
|
||||
group = "bandwidth"
|
||||
bandwidth = 100
|
||||
|
||||
[[bench]]
|
||||
group = "bandwidth"
|
||||
bandwidth = 250
|
||||
|
||||
[[bench]]
|
||||
group = "bandwidth"
|
||||
bandwidth = 1000
|
||||
37
crates/harness/toml/download.toml
Normal file
37
crates/harness/toml/download.toml
Normal file
@@ -0,0 +1,37 @@
|
||||
[[group]]
|
||||
name = "download_size"
|
||||
protocol_latency = 10
|
||||
bandwidth = 200
|
||||
upload-size = 2048
|
||||
|
||||
[[bench]]
|
||||
group = "download_size"
|
||||
download-size = 1024
|
||||
|
||||
[[bench]]
|
||||
group = "download_size"
|
||||
download-size = 2048
|
||||
|
||||
[[bench]]
|
||||
group = "download_size"
|
||||
download-size = 4096
|
||||
|
||||
[[bench]]
|
||||
group = "download_size"
|
||||
download-size = 8192
|
||||
|
||||
[[bench]]
|
||||
group = "download_size"
|
||||
download-size = 16384
|
||||
|
||||
[[bench]]
|
||||
group = "download_size"
|
||||
download-size = 32768
|
||||
|
||||
[[bench]]
|
||||
group = "download_size"
|
||||
download-size = 65536
|
||||
|
||||
[[bench]]
|
||||
group = "download_size"
|
||||
download-size = 131072
|
||||
25
crates/harness/toml/latency.toml
Normal file
25
crates/harness/toml/latency.toml
Normal file
@@ -0,0 +1,25 @@
|
||||
#### Latency ####
|
||||
|
||||
[[group]]
|
||||
name = "latency"
|
||||
bandwidth = 1000
|
||||
|
||||
[[bench]]
|
||||
group = "latency"
|
||||
protocol_latency = 10
|
||||
|
||||
[[bench]]
|
||||
group = "latency"
|
||||
protocol_latency = 25
|
||||
|
||||
[[bench]]
|
||||
group = "latency"
|
||||
protocol_latency = 50
|
||||
|
||||
[[bench]]
|
||||
group = "latency"
|
||||
protocol_latency = 100
|
||||
|
||||
[[bench]]
|
||||
group = "latency"
|
||||
protocol_latency = 200
|
||||
@@ -34,6 +34,7 @@ mpz-share-conversion = { workspace = true }
|
||||
mpz-vm-core = { workspace = true }
|
||||
mpz-memory-core = { workspace = true }
|
||||
|
||||
ludi = { git = "https://github.com/sinui0/ludi", rev = "e511c3b", default-features = false }
|
||||
serio = { workspace = true }
|
||||
|
||||
async-trait = { workspace = true }
|
||||
@@ -65,6 +66,7 @@ rand_chacha = { workspace = true }
|
||||
rstest = { workspace = true }
|
||||
tls-server-fixture = { workspace = true }
|
||||
tlsn-tls-client = { workspace = true }
|
||||
tlsn-tls-client-async = { workspace = true }
|
||||
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }
|
||||
tokio-util = { workspace = true, features = ["compat"] }
|
||||
tracing-subscriber = { workspace = true }
|
||||
|
||||
@@ -15,6 +15,13 @@ impl MpcTlsError {
|
||||
Self(ErrorRepr::Peer(err.into()))
|
||||
}
|
||||
|
||||
pub(crate) fn actor<E>(err: E) -> Self
|
||||
where
|
||||
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
|
||||
{
|
||||
Self(ErrorRepr::Actor(err.into()))
|
||||
}
|
||||
|
||||
pub(crate) fn state<E>(err: E) -> Self
|
||||
where
|
||||
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
|
||||
@@ -65,6 +72,8 @@ enum ErrorRepr {
|
||||
Peer(Box<dyn std::error::Error + Send + Sync>),
|
||||
#[error("I/O error: {0}")]
|
||||
Io(std::io::Error),
|
||||
#[error("actor error: {0}")]
|
||||
Actor(Box<dyn std::error::Error + Send + Sync>),
|
||||
#[error("state error: {0}")]
|
||||
State(Box<dyn std::error::Error + Send + Sync>),
|
||||
#[error("allocation error: {0}")]
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
mod actor;
|
||||
|
||||
use crate::{
|
||||
error::MpcTlsError,
|
||||
msg::{
|
||||
@@ -12,6 +14,7 @@ use async_trait::async_trait;
|
||||
use hmac_sha256::{MpcPrf, PrfOutput};
|
||||
use ke::KeyExchange;
|
||||
use key_exchange::{self as ke, MpcKeyExchange};
|
||||
use ludi::Context as LudiContext;
|
||||
use mpz_common::{Context, Flush};
|
||||
use mpz_core::{bitvec::BitVec, Block};
|
||||
use mpz_memory_core::DecodeFutureTyped;
|
||||
@@ -47,9 +50,13 @@ use tlsn_core::{
|
||||
};
|
||||
use tracing::{debug, instrument, trace, warn};
|
||||
|
||||
/// Controller for MPC-TLS leader.
|
||||
pub type LeaderCtrl = actor::MpcTlsLeaderCtrl;
|
||||
|
||||
/// MPC-TLS leader.
|
||||
#[derive(Debug)]
|
||||
pub struct MpcTlsLeader {
|
||||
self_handle: Option<LeaderCtrl>,
|
||||
config: Config,
|
||||
state: State,
|
||||
|
||||
@@ -107,6 +114,7 @@ impl MpcTlsLeader {
|
||||
|
||||
let is_decrypting = !config.defer_decryption;
|
||||
Self {
|
||||
self_handle: None,
|
||||
config,
|
||||
state: State::Init {
|
||||
ctx,
|
||||
@@ -370,42 +378,18 @@ impl MpcTlsLeader {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Enables or disables the decryption of any incoming messages.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `enable` - Whether to enable or disable decryption.
|
||||
/// Defers decryption of any incoming messages.
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
pub fn enable_decryption(&mut self, enable: bool) -> Result<(), MpcTlsError> {
|
||||
self.is_decrypting = enable;
|
||||
|
||||
if enable {
|
||||
self.notifier.set();
|
||||
} else {
|
||||
self.notifier.clear();
|
||||
}
|
||||
pub async fn defer_decryption(&mut self) -> Result<(), MpcTlsError> {
|
||||
self.is_decrypting = false;
|
||||
self.notifier.clear();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns if incoming messages are decrypted.
|
||||
pub fn is_decrypting(&self) -> bool {
|
||||
self.is_decrypting
|
||||
}
|
||||
|
||||
/// Returns the context and transcript.
|
||||
///
|
||||
/// Should be called after a successful call to [`Backend::server_closed`].
|
||||
pub fn finish(&mut self) -> Option<(Context, TlsTranscript)> {
|
||||
match self.state.take() {
|
||||
State::Closed {
|
||||
ctx, transcript, ..
|
||||
} => Some((ctx, transcript)),
|
||||
state => {
|
||||
self.state = state;
|
||||
None
|
||||
}
|
||||
}
|
||||
/// Stops the actor.
|
||||
pub fn stop(&mut self, ctx: &mut LudiContext<Self>) {
|
||||
ctx.stop();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
1779
crates/mpc-tls/src/leader/actor.rs
Normal file
1779
crates/mpc-tls/src/leader/actor.rs
Normal file
File diff suppressed because it is too large
Load Diff
@@ -16,7 +16,7 @@ pub(crate) mod utils;
|
||||
pub use config::{Config, ConfigBuilder, ConfigBuilderError};
|
||||
pub use error::MpcTlsError;
|
||||
pub use follower::MpcTlsFollower;
|
||||
pub use leader::MpcTlsLeader;
|
||||
pub use leader::{LeaderCtrl, MpcTlsLeader};
|
||||
|
||||
use std::{future::Future, pin::Pin, sync::Arc};
|
||||
|
||||
|
||||
160
crates/mpc-tls/tests/test.rs
Normal file
160
crates/mpc-tls/tests/test.rs
Normal file
@@ -0,0 +1,160 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::{AsyncReadExt, AsyncWriteExt};
|
||||
use mpc_tls::{Config, MpcTlsFollower, MpcTlsLeader};
|
||||
use mpz_common::context::test_mt_context;
|
||||
use mpz_core::Block;
|
||||
use mpz_ideal_vm::IdealVm;
|
||||
use mpz_memory_core::correlated::Delta;
|
||||
use mpz_ot::{
|
||||
ideal::rcot::ideal_rcot,
|
||||
rcot::shared::{SharedRCOTReceiver, SharedRCOTSender},
|
||||
};
|
||||
use rand::{rngs::StdRng, SeedableRng};
|
||||
use rustls_pki_types::CertificateDer;
|
||||
use tls_client::RootCertStore;
|
||||
use tls_client_async::bind_client;
|
||||
use tls_server_fixture::{bind_test_server_hyper, CA_CERT_DER, SERVER_DOMAIN};
|
||||
use tokio::sync::Mutex;
|
||||
use tokio_util::compat::TokioAsyncReadCompatExt;
|
||||
use webpki::anchor_from_trusted_cert;
|
||||
|
||||
const CA_CERT: CertificateDer = CertificateDer::from_slice(CA_CERT_DER);
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn mpc_tls_test() {
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
let config = Config::builder()
|
||||
.defer_decryption(false)
|
||||
.max_sent(1 << 13)
|
||||
.max_recv_online(1 << 13)
|
||||
.max_recv(1 << 13)
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let (leader, follower) = build_pair(config);
|
||||
|
||||
tokio::try_join!(
|
||||
tokio::spawn(leader_task(leader)),
|
||||
tokio::spawn(follower_task(follower))
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
async fn leader_task(mut leader: MpcTlsLeader) {
|
||||
leader.alloc().unwrap();
|
||||
|
||||
leader.preprocess().await.unwrap();
|
||||
|
||||
let (leader_ctrl, leader_fut) = leader.run();
|
||||
tokio::spawn(async { leader_fut.await.unwrap() });
|
||||
|
||||
let config = tls_client::ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_root_certificates(RootCertStore {
|
||||
roots: vec![anchor_from_trusted_cert(&CA_CERT).unwrap().to_owned()],
|
||||
})
|
||||
.with_no_client_auth();
|
||||
|
||||
let server_name = SERVER_DOMAIN.try_into().unwrap();
|
||||
|
||||
let client = tls_client::ClientConnection::new(
|
||||
Arc::new(config),
|
||||
Box::new(leader_ctrl.clone()),
|
||||
server_name,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let (client_socket, server_socket) = tokio::io::duplex(1 << 16);
|
||||
tokio::spawn(bind_test_server_hyper(server_socket.compat()));
|
||||
|
||||
let (mut conn, conn_fut) = bind_client(client_socket.compat(), client);
|
||||
let handle = tokio::spawn(async { conn_fut.await.unwrap() });
|
||||
|
||||
let msg = concat!(
|
||||
"POST /echo HTTP/1.1\r\n",
|
||||
"Host: test-server.io\r\n",
|
||||
"Connection: keep-alive\r\n",
|
||||
"Accept-Encoding: identity\r\n",
|
||||
"Content-Length: 5\r\n",
|
||||
"\r\n",
|
||||
"hello",
|
||||
"\r\n"
|
||||
);
|
||||
|
||||
conn.write_all(msg.as_bytes()).await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 48];
|
||||
conn.read_exact(&mut buf).await.unwrap();
|
||||
|
||||
leader_ctrl.defer_decryption().await.unwrap();
|
||||
|
||||
let msg = concat!(
|
||||
"POST /echo HTTP/1.1\r\n",
|
||||
"Host: test-server.io\r\n",
|
||||
"Connection: close\r\n",
|
||||
"Accept-Encoding: identity\r\n",
|
||||
"Content-Length: 5\r\n",
|
||||
"\r\n",
|
||||
"hello",
|
||||
"\r\n"
|
||||
);
|
||||
|
||||
conn.write_all(msg.as_bytes()).await.unwrap();
|
||||
conn.close().await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 1024];
|
||||
conn.read_to_end(&mut buf).await.unwrap();
|
||||
|
||||
leader_ctrl.stop().await.unwrap();
|
||||
|
||||
handle.await.unwrap();
|
||||
}
|
||||
|
||||
async fn follower_task(mut follower: MpcTlsFollower) {
|
||||
follower.alloc().unwrap();
|
||||
follower.preprocess().await.unwrap();
|
||||
follower.run().await.unwrap();
|
||||
}
|
||||
|
||||
fn build_pair(config: Config) -> (MpcTlsLeader, MpcTlsFollower) {
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
|
||||
let (mut mt_a, mut mt_b) = test_mt_context(8);
|
||||
|
||||
let ctx_a = futures::executor::block_on(mt_a.new_context()).unwrap();
|
||||
let ctx_b = futures::executor::block_on(mt_b.new_context()).unwrap();
|
||||
|
||||
let delta_a = Delta::new(Block::random(&mut rng));
|
||||
let delta_b = Delta::new(Block::random(&mut rng));
|
||||
|
||||
let (rcot_send_a, rcot_recv_b) = ideal_rcot(Block::random(&mut rng), delta_a.into_inner());
|
||||
let (rcot_send_b, rcot_recv_a) = ideal_rcot(Block::random(&mut rng), delta_b.into_inner());
|
||||
|
||||
let rcot_send_a = SharedRCOTSender::new(rcot_send_a);
|
||||
let rcot_send_b = SharedRCOTSender::new(rcot_send_b);
|
||||
let rcot_recv_a = SharedRCOTReceiver::new(rcot_recv_a);
|
||||
let rcot_recv_b = SharedRCOTReceiver::new(rcot_recv_b);
|
||||
|
||||
let mpc_a = Arc::new(Mutex::new(IdealVm::new()));
|
||||
let mpc_b = Arc::new(Mutex::new(IdealVm::new()));
|
||||
|
||||
let leader = MpcTlsLeader::new(
|
||||
config.clone(),
|
||||
ctx_a,
|
||||
mpc_a,
|
||||
(rcot_send_a.clone(), rcot_send_a.clone(), rcot_send_a),
|
||||
rcot_recv_a,
|
||||
);
|
||||
|
||||
let follower = MpcTlsFollower::new(
|
||||
config,
|
||||
ctx_b,
|
||||
mpc_b,
|
||||
rcot_send_b,
|
||||
(rcot_recv_b.clone(), rcot_recv_b.clone(), rcot_recv_b),
|
||||
);
|
||||
|
||||
(leader, follower)
|
||||
}
|
||||
39
crates/tls/client-async/Cargo.toml
Normal file
39
crates/tls/client-async/Cargo.toml
Normal file
@@ -0,0 +1,39 @@
|
||||
[package]
|
||||
name = "tlsn-tls-client-async"
|
||||
authors = ["TLSNotary Team"]
|
||||
description = "An async TLS client for TLSNotary"
|
||||
keywords = ["tls", "mpc", "2pc", "client", "async"]
|
||||
categories = ["cryptography"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
version = "0.1.0-alpha.14-pre"
|
||||
edition = "2021"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
name = "tls_client_async"
|
||||
|
||||
[features]
|
||||
default = ["tracing"]
|
||||
tracing = ["dep:tracing"]
|
||||
|
||||
[dependencies]
|
||||
tlsn-tls-client = { workspace = true }
|
||||
|
||||
bytes = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tokio-util = { workspace = true, features = ["io", "compat"] }
|
||||
tracing = { workspace = true, optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tls-server-fixture = { workspace = true }
|
||||
|
||||
http-body-util = { workspace = true }
|
||||
hyper = { workspace = true, features = ["client", "http1"] }
|
||||
hyper-util = { workspace = true, features = ["full"] }
|
||||
rstest = { workspace = true }
|
||||
tokio = { workspace = true, features = ["rt", "rt-multi-thread", "macros"] }
|
||||
rustls-webpki = { workspace = true }
|
||||
rustls-pki-types = { workspace = true }
|
||||
89
crates/tls/client-async/src/conn.rs
Normal file
89
crates/tls/client-async/src/conn.rs
Normal file
@@ -0,0 +1,89 @@
|
||||
use bytes::Bytes;
|
||||
use futures::{
|
||||
channel::mpsc::{Receiver, SendError, Sender},
|
||||
sink::SinkMapErr,
|
||||
AsyncRead, AsyncWrite, SinkExt,
|
||||
};
|
||||
use std::{
|
||||
io::{Error as IoError, ErrorKind as IoErrorKind},
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use tokio_util::{
|
||||
compat::{Compat, TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt},
|
||||
io::{CopyToBytes, SinkWriter, StreamReader},
|
||||
};
|
||||
|
||||
type CompatSinkWriter =
|
||||
Compat<SinkWriter<CopyToBytes<SinkMapErr<Sender<Bytes>, fn(SendError) -> IoError>>>>;
|
||||
|
||||
/// A TLS connection to a server.
|
||||
///
|
||||
/// This type implements `AsyncRead` and `AsyncWrite` and can be used to
|
||||
/// communicate with a server using TLS.
|
||||
///
|
||||
/// # Note
|
||||
///
|
||||
/// This connection is closed on a best-effort basis if this is dropped. To
|
||||
/// ensure a clean close, you should call
|
||||
/// [`AsyncWriteExt::close`](futures::io::AsyncWriteExt::close) to close the
|
||||
/// connection.
|
||||
#[derive(Debug)]
|
||||
pub struct TlsConnection {
|
||||
/// The data to be transmitted to the server is sent to this sink.
|
||||
tx_sender: CompatSinkWriter,
|
||||
/// The data to be received from the server is received from this stream.
|
||||
rx_receiver: Compat<StreamReader<Receiver<Result<Bytes, IoError>>, Bytes>>,
|
||||
}
|
||||
|
||||
impl TlsConnection {
|
||||
/// Creates a new TLS connection.
|
||||
pub(crate) fn new(
|
||||
tx_sender: Sender<Bytes>,
|
||||
rx_receiver: Receiver<Result<Bytes, IoError>>,
|
||||
) -> Self {
|
||||
fn convert_error(err: SendError) -> IoError {
|
||||
if err.is_disconnected() {
|
||||
IoErrorKind::BrokenPipe.into()
|
||||
} else {
|
||||
IoErrorKind::WouldBlock.into()
|
||||
}
|
||||
}
|
||||
|
||||
Self {
|
||||
tx_sender: SinkWriter::new(CopyToBytes::new(
|
||||
tx_sender.sink_map_err(convert_error as fn(SendError) -> IoError),
|
||||
))
|
||||
.compat_write(),
|
||||
rx_receiver: StreamReader::new(rx_receiver).compat(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for TlsConnection {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut [u8],
|
||||
) -> Poll<Result<usize, IoError>> {
|
||||
Pin::new(&mut self.rx_receiver).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for TlsConnection {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, IoError>> {
|
||||
Pin::new(&mut self.tx_sender).poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
|
||||
Pin::new(&mut self.tx_sender).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
|
||||
Pin::new(&mut self.tx_sender).poll_close(cx)
|
||||
}
|
||||
}
|
||||
269
crates/tls/client-async/src/lib.rs
Normal file
269
crates/tls/client-async/src/lib.rs
Normal file
@@ -0,0 +1,269 @@
|
||||
//! Provides a TLS client which exposes an async socket.
|
||||
//!
|
||||
//! This library provides the [bind_client] function which attaches a TLS client
|
||||
//! to a socket connection and then exposes a [TlsConnection] object, which
|
||||
//! provides an async socket API for reading and writing cleartext. The TLS
|
||||
//! client will then automatically encrypt and decrypt traffic and forward that
|
||||
//! to the provided socket.
|
||||
|
||||
#![deny(missing_docs, unreachable_pub, unused_must_use)]
|
||||
#![deny(clippy::all)]
|
||||
#![forbid(unsafe_code)]
|
||||
|
||||
mod conn;
|
||||
|
||||
use bytes::{Buf, Bytes};
|
||||
use futures::{
|
||||
channel::mpsc, future::Fuse, select_biased, stream::Next, AsyncRead, AsyncReadExt, AsyncWrite,
|
||||
AsyncWriteExt, Future, FutureExt, SinkExt, StreamExt,
|
||||
};
|
||||
|
||||
use std::{
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
#[cfg(feature = "tracing")]
|
||||
use tracing::{debug, debug_span, trace, warn, Instrument};
|
||||
|
||||
use tls_client::ClientConnection;
|
||||
|
||||
pub use conn::TlsConnection;
|
||||
|
||||
const RX_TLS_BUF_SIZE: usize = 1 << 13; // 8 KiB
|
||||
const RX_BUF_SIZE: usize = 1 << 13; // 8 KiB
|
||||
|
||||
/// An error that can occur during a TLS connection.
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ConnectionError {
|
||||
#[error(transparent)]
|
||||
TlsError(#[from] tls_client::Error),
|
||||
#[error(transparent)]
|
||||
IOError(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
/// Closed connection data.
|
||||
#[derive(Debug)]
|
||||
pub struct ClosedConnection {
|
||||
/// The connection for the client
|
||||
pub client: ClientConnection,
|
||||
/// Sent plaintext bytes
|
||||
pub sent: Vec<u8>,
|
||||
/// Received plaintext bytes
|
||||
pub recv: Vec<u8>,
|
||||
}
|
||||
|
||||
/// A future which runs the TLS connection to completion.
|
||||
///
|
||||
/// This future must be polled in order for the connection to make progress.
|
||||
#[must_use = "futures do nothing unless polled"]
|
||||
pub struct ConnectionFuture {
|
||||
fut: Pin<Box<dyn Future<Output = Result<ClosedConnection, ConnectionError>> + Send>>,
|
||||
}
|
||||
|
||||
impl Future for ConnectionFuture {
|
||||
type Output = Result<ClosedConnection, ConnectionError>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
self.fut.poll_unpin(cx)
|
||||
}
|
||||
}
|
||||
|
||||
/// Binds a client connection to the provided socket.
|
||||
///
|
||||
/// Returns a connection handle and a future which runs the connection to
|
||||
/// completion.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Any connection errors that occur will be returned from the future, not
|
||||
/// [`TlsConnection`].
|
||||
pub fn bind_client<T: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
|
||||
socket: T,
|
||||
mut client: ClientConnection,
|
||||
) -> (TlsConnection, ConnectionFuture) {
|
||||
let (tx_sender, mut tx_receiver) = mpsc::channel(1 << 14);
|
||||
let (mut rx_sender, rx_receiver) = mpsc::channel(1 << 14);
|
||||
|
||||
let conn = TlsConnection::new(tx_sender, rx_receiver);
|
||||
|
||||
let fut = async move {
|
||||
client.start().await?;
|
||||
let mut notify = client.get_notify().await?;
|
||||
|
||||
let (mut server_rx, mut server_tx) = socket.split();
|
||||
|
||||
let mut rx_tls_buf = [0u8; RX_TLS_BUF_SIZE];
|
||||
let mut rx_buf = [0u8; RX_BUF_SIZE];
|
||||
|
||||
let mut handshake_done = false;
|
||||
let mut client_closed = false;
|
||||
let mut server_closed = false;
|
||||
|
||||
let mut sent = Vec::with_capacity(1024);
|
||||
let mut recv = Vec::with_capacity(1024);
|
||||
|
||||
let mut rx_tls_fut = server_rx.read(&mut rx_tls_buf).fuse();
|
||||
// We don't start writing application data until the handshake is complete.
|
||||
let mut tx_recv_fut: Fuse<Next<'_, mpsc::Receiver<Bytes>>> = Fuse::terminated();
|
||||
|
||||
// Runs both the tx and rx halves of the connection to completion.
|
||||
// This loop does not terminate until the *SERVER* closes the connection and
|
||||
// we've processed all received data. If an error occurs, the `TlsConnection`
|
||||
// channels will be closed and the error will be returned from this future.
|
||||
'conn: loop {
|
||||
// Write all pending TLS data to the server.
|
||||
if client.wants_write() && !client_closed {
|
||||
#[cfg(feature = "tracing")]
|
||||
trace!("client wants to write");
|
||||
while client.wants_write() {
|
||||
let _sent = client.write_tls_async(&mut server_tx).await?;
|
||||
#[cfg(feature = "tracing")]
|
||||
trace!("sent {} tls bytes to server", _sent);
|
||||
}
|
||||
server_tx.flush().await?;
|
||||
}
|
||||
|
||||
// Forward received plaintext to `TlsConnection`.
|
||||
while !client.plaintext_is_empty() {
|
||||
let read = client.read_plaintext(&mut rx_buf)?;
|
||||
recv.extend(&rx_buf[..read]);
|
||||
// Ignore if the receiver has hung up.
|
||||
_ = rx_sender
|
||||
.send(Ok(Bytes::copy_from_slice(&rx_buf[..read])))
|
||||
.await;
|
||||
#[cfg(feature = "tracing")]
|
||||
trace!("forwarded {} plaintext bytes to conn", read);
|
||||
}
|
||||
|
||||
if !client.is_handshaking() && !handshake_done {
|
||||
#[cfg(feature = "tracing")]
|
||||
debug!("handshake complete");
|
||||
handshake_done = true;
|
||||
// Start reading application data that needs to be transmitted from the
|
||||
// `TlsConnection`.
|
||||
tx_recv_fut = tx_receiver.next().fuse();
|
||||
}
|
||||
|
||||
if server_closed && client.plaintext_is_empty() && client.is_empty().await? {
|
||||
break 'conn;
|
||||
}
|
||||
|
||||
select_biased! {
|
||||
// Reads TLS data from the server and writes it into the client.
|
||||
received = &mut rx_tls_fut => {
|
||||
let received = received?;
|
||||
#[cfg(feature = "tracing")]
|
||||
trace!("received {} tls bytes from server", received);
|
||||
|
||||
// Loop until we've processed all the data we received in this read.
|
||||
// Note that we must make one iteration even if `received == 0`.
|
||||
let mut processed = 0;
|
||||
let mut reader = rx_tls_buf[..received].reader();
|
||||
loop {
|
||||
processed += client.read_tls(&mut reader)?;
|
||||
client.process_new_packets().await?;
|
||||
|
||||
debug_assert!(processed <= received);
|
||||
if processed >= received {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tracing")]
|
||||
trace!("processed {} tls bytes from server", processed);
|
||||
|
||||
// By convention if `AsyncRead::read` returns 0, it means EOF, i.e. the peer
|
||||
// has closed the socket.
|
||||
if received == 0 {
|
||||
#[cfg(feature = "tracing")]
|
||||
debug!("server closed connection");
|
||||
server_closed = true;
|
||||
client.server_closed().await?;
|
||||
// Do not read from the socket again.
|
||||
rx_tls_fut = Fuse::terminated();
|
||||
} else {
|
||||
// Reset the read future so next iteration we can read again.
|
||||
rx_tls_fut = server_rx.read(&mut rx_tls_buf).fuse();
|
||||
}
|
||||
}
|
||||
// If we receive None from `TlsConnection`, it has closed, so we
|
||||
// send a close_notify to the server.
|
||||
data = &mut tx_recv_fut => {
|
||||
if let Some(data) = data {
|
||||
#[cfg(feature = "tracing")]
|
||||
trace!("writing {} plaintext bytes to client", data.len());
|
||||
|
||||
sent.extend(&data);
|
||||
client
|
||||
.write_all_plaintext(&data)
|
||||
.await?;
|
||||
|
||||
tx_recv_fut = tx_receiver.next().fuse();
|
||||
} else {
|
||||
if !server_closed {
|
||||
if let Err(e) = send_close_notify(&mut client, &mut server_tx).await {
|
||||
#[cfg(feature = "tracing")]
|
||||
warn!("failed to send close_notify to server: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
client_closed = true;
|
||||
|
||||
tx_recv_fut = Fuse::terminated();
|
||||
}
|
||||
}
|
||||
// Waits for a notification from the backend that it is ready to decrypt data.
|
||||
_ = &mut notify => {
|
||||
#[cfg(feature = "tracing")]
|
||||
trace!("backend is ready to decrypt");
|
||||
|
||||
client.process_new_packets().await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tracing")]
|
||||
debug!("client shutdown");
|
||||
|
||||
_ = server_tx.close().await;
|
||||
tx_receiver.close();
|
||||
rx_sender.close_channel();
|
||||
|
||||
#[cfg(feature = "tracing")]
|
||||
trace!(
|
||||
"server close notify: {}, sent: {}, recv: {}",
|
||||
client.received_close_notify(),
|
||||
sent.len(),
|
||||
recv.len()
|
||||
);
|
||||
|
||||
Ok(ClosedConnection { client, sent, recv })
|
||||
};
|
||||
|
||||
#[cfg(feature = "tracing")]
|
||||
let fut = fut.instrument(debug_span!("tls_connection"));
|
||||
|
||||
let fut = ConnectionFuture { fut: Box::pin(fut) };
|
||||
|
||||
(conn, fut)
|
||||
}
|
||||
|
||||
async fn send_close_notify(
|
||||
client: &mut ClientConnection,
|
||||
server_tx: &mut (impl AsyncWrite + Unpin),
|
||||
) -> Result<(), ConnectionError> {
|
||||
#[cfg(feature = "tracing")]
|
||||
trace!("sending close_notify to server");
|
||||
client.send_close_notify().await?;
|
||||
client.process_new_packets().await?;
|
||||
|
||||
// Flush all remaining plaintext
|
||||
while client.wants_write() {
|
||||
client.write_tls_async(server_tx).await?;
|
||||
}
|
||||
server_tx.flush().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
438
crates/tls/client-async/tests/test.rs
Normal file
438
crates/tls/client-async/tests/test.rs
Normal file
@@ -0,0 +1,438 @@
|
||||
use std::{str, sync::Arc};
|
||||
|
||||
use core::future::Future;
|
||||
use futures::{AsyncReadExt, AsyncWriteExt};
|
||||
use http_body_util::{BodyExt as _, Full};
|
||||
use hyper::{body::Bytes, Request, StatusCode};
|
||||
use hyper_util::rt::TokioIo;
|
||||
use rstest::{fixture, rstest};
|
||||
use rustls_pki_types::CertificateDer;
|
||||
use tls_client::{ClientConfig, ClientConnection, RustCryptoBackend, ServerName};
|
||||
use tls_client_async::{bind_client, ClosedConnection, ConnectionError, TlsConnection};
|
||||
use tls_server_fixture::{
|
||||
bind_test_server, bind_test_server_hyper, APP_RECORD_LENGTH, CA_CERT_DER, CLOSE_DELAY,
|
||||
SERVER_DOMAIN,
|
||||
};
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
|
||||
use webpki::anchor_from_trusted_cert;
|
||||
|
||||
const CA_CERT: CertificateDer = CertificateDer::from_slice(CA_CERT_DER);
|
||||
|
||||
// An established client TLS connection
|
||||
struct TlsFixture {
|
||||
client_tls_conn: TlsConnection,
|
||||
// a handle that must be `.await`ed to get the result of a TLS connection
|
||||
closed_tls_task: JoinHandle<Result<ClosedConnection, ConnectionError>>,
|
||||
}
|
||||
|
||||
// Sets up a TLS connection between client and server and sends a hello message
|
||||
#[fixture]
|
||||
async fn set_up_tls() -> TlsFixture {
|
||||
let (client_socket, server_socket) = tokio::io::duplex(1 << 16);
|
||||
|
||||
let _server_task = tokio::spawn(bind_test_server(server_socket.compat()));
|
||||
|
||||
let mut root_store = tls_client::RootCertStore::empty();
|
||||
root_store
|
||||
.roots
|
||||
.push(anchor_from_trusted_cert(&CA_CERT).unwrap().to_owned());
|
||||
let config = ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_root_certificates(root_store)
|
||||
.with_no_client_auth();
|
||||
let client = ClientConnection::new(
|
||||
Arc::new(config),
|
||||
Box::new(RustCryptoBackend::new()),
|
||||
ServerName::try_from(SERVER_DOMAIN).unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let (mut client_tls_conn, tls_fut) = bind_client(client_socket.compat(), client);
|
||||
|
||||
let closed_tls_task = tokio::spawn(tls_fut);
|
||||
|
||||
client_tls_conn
|
||||
.write_all(&pad("expecting you to send back hello".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// give the server some time to respond
|
||||
std::thread::sleep(std::time::Duration::from_millis(10));
|
||||
|
||||
let mut plaintext = vec![0u8; 320];
|
||||
let n = client_tls_conn.read(&mut plaintext).await.unwrap();
|
||||
let s = str::from_utf8(&plaintext[0..n]).unwrap();
|
||||
|
||||
assert_eq!(s, "hello");
|
||||
|
||||
TlsFixture {
|
||||
client_tls_conn,
|
||||
closed_tls_task,
|
||||
}
|
||||
}
|
||||
|
||||
// Expect the async tls client wrapped in `hyper::client` to make a successful
|
||||
// request and receive the expected response
|
||||
#[tokio::test]
|
||||
async fn test_hyper_ok() {
|
||||
let (client_socket, server_socket) = tokio::io::duplex(1 << 16);
|
||||
|
||||
let server_task = tokio::spawn(bind_test_server_hyper(server_socket.compat()));
|
||||
|
||||
let mut root_store = tls_client::RootCertStore::empty();
|
||||
root_store
|
||||
.roots
|
||||
.push(anchor_from_trusted_cert(&CA_CERT).unwrap().to_owned());
|
||||
let config = ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_root_certificates(root_store)
|
||||
.with_no_client_auth();
|
||||
let client = ClientConnection::new(
|
||||
Arc::new(config),
|
||||
Box::new(RustCryptoBackend::new()),
|
||||
ServerName::try_from(SERVER_DOMAIN).unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let (conn, tls_fut) = bind_client(client_socket.compat(), client);
|
||||
|
||||
let closed_tls_task = tokio::spawn(tls_fut);
|
||||
|
||||
let (mut request_sender, connection) =
|
||||
hyper::client::conn::http1::handshake(TokioIo::new(conn.compat()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
tokio::spawn(connection);
|
||||
|
||||
let request = Request::builder()
|
||||
.uri(format!("https://{SERVER_DOMAIN}/echo"))
|
||||
.header("Host", SERVER_DOMAIN)
|
||||
.header("Connection", "close")
|
||||
.method("POST")
|
||||
.body(Full::<Bytes>::new("hello".into()))
|
||||
.unwrap();
|
||||
|
||||
let response = request_sender.send_request(request).await.unwrap();
|
||||
|
||||
assert!(response.status() == StatusCode::OK);
|
||||
|
||||
// Process the response body
|
||||
response.into_body().collect().await.unwrap().to_bytes();
|
||||
|
||||
let _ = server_task.await.unwrap();
|
||||
|
||||
let closed_conn = closed_tls_task.await.unwrap().unwrap();
|
||||
|
||||
assert!(closed_conn.client.received_close_notify());
|
||||
}
|
||||
|
||||
// Expect a clean TLS connection closure when server responds to the client's
|
||||
// close_notify but doesn't close the socket
|
||||
#[rstest]
|
||||
#[tokio::test]
|
||||
async fn test_ok_server_no_socket_close(set_up_tls: impl Future<Output = TlsFixture>) {
|
||||
let TlsFixture {
|
||||
mut client_tls_conn,
|
||||
closed_tls_task,
|
||||
} = set_up_tls.await;
|
||||
|
||||
// instruct the server to send close_notify back to us after 10 ms
|
||||
client_tls_conn
|
||||
.write_all(&pad("send_close_notify".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
client_tls_conn.flush().await.unwrap();
|
||||
|
||||
// closing `client_tls_conn` will cause close_notify to be sent by the client;
|
||||
client_tls_conn.close().await.unwrap();
|
||||
|
||||
let closed_conn = closed_tls_task.await.unwrap().unwrap();
|
||||
|
||||
assert!(closed_conn.client.received_close_notify());
|
||||
}
|
||||
|
||||
// Expect a clean TLS connection closure when server responds to the client's
|
||||
// close_notify AND also closes the socket
|
||||
#[rstest]
|
||||
#[tokio::test]
|
||||
async fn test_ok_server_socket_close(set_up_tls: impl Future<Output = TlsFixture>) {
|
||||
let TlsFixture {
|
||||
mut client_tls_conn,
|
||||
closed_tls_task,
|
||||
} = set_up_tls.await;
|
||||
|
||||
// instruct the server to send close_notify back to us AND close the socket
|
||||
// after 10 ms
|
||||
client_tls_conn
|
||||
.write_all(&pad("send_close_notify_and_close_socket".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
client_tls_conn.flush().await.unwrap();
|
||||
|
||||
// closing `client_tls_conn` will cause close_notify to be sent by the client;
|
||||
client_tls_conn.close().await.unwrap();
|
||||
|
||||
let closed_conn = closed_tls_task.await.unwrap().unwrap();
|
||||
|
||||
assert!(closed_conn.client.received_close_notify());
|
||||
}
|
||||
|
||||
// Expect a clean TLS connection closure when server sends close_notify first
|
||||
// but doesn't close the socket
|
||||
#[rstest]
|
||||
#[tokio::test]
|
||||
async fn test_ok_server_close_notify(set_up_tls: impl Future<Output = TlsFixture>) {
|
||||
let TlsFixture {
|
||||
mut client_tls_conn,
|
||||
closed_tls_task,
|
||||
} = set_up_tls.await;
|
||||
|
||||
// instruct the server to send close_notify back to us after 10 ms
|
||||
client_tls_conn
|
||||
.write_all(&pad("send_close_notify".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
client_tls_conn.flush().await.unwrap();
|
||||
|
||||
// give enough time for server's close_notify to arrive
|
||||
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
|
||||
|
||||
client_tls_conn.close().await.unwrap();
|
||||
|
||||
let closed_conn = closed_tls_task.await.unwrap().unwrap();
|
||||
|
||||
assert!(closed_conn.client.received_close_notify());
|
||||
}
|
||||
|
||||
// Expect a clean TLS connection closure when server sends close_notify first
|
||||
// AND also closes the socket
|
||||
#[rstest]
|
||||
#[tokio::test]
|
||||
async fn test_ok_server_close_notify_and_socket_close(
|
||||
set_up_tls: impl Future<Output = TlsFixture>,
|
||||
) {
|
||||
let TlsFixture {
|
||||
mut client_tls_conn,
|
||||
closed_tls_task,
|
||||
} = set_up_tls.await;
|
||||
|
||||
// instruct the server to send close_notify back to us after 10 ms
|
||||
client_tls_conn
|
||||
.write_all(&pad("send_close_notify_and_close_socket".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
client_tls_conn.flush().await.unwrap();
|
||||
|
||||
// give enough time for server's close_notify to arrive
|
||||
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
|
||||
|
||||
client_tls_conn.close().await.unwrap();
|
||||
|
||||
let closed_conn = closed_tls_task.await.unwrap().unwrap();
|
||||
|
||||
assert!(closed_conn.client.received_close_notify());
|
||||
}
|
||||
|
||||
// Expect to be able to read the data after server closes the socket abruptly
|
||||
#[rstest]
|
||||
#[tokio::test]
|
||||
async fn test_ok_read_after_close(set_up_tls: impl Future<Output = TlsFixture>) {
|
||||
let TlsFixture {
|
||||
mut client_tls_conn,
|
||||
..
|
||||
} = set_up_tls.await;
|
||||
|
||||
// instruct the server to send us a hello message
|
||||
client_tls_conn
|
||||
.write_all(&pad("send a hello message".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
client_tls_conn.flush().await.unwrap();
|
||||
|
||||
// instruct the server to close the socket
|
||||
client_tls_conn
|
||||
.write_all(&pad("close_socket".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
client_tls_conn.flush().await.unwrap();
|
||||
|
||||
// give enough time to close the socket
|
||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||
|
||||
// try to read some more data
|
||||
let mut buf = vec![0u8; 10];
|
||||
let n = client_tls_conn.read(&mut buf).await.unwrap();
|
||||
|
||||
assert_eq!(std::str::from_utf8(&buf[0..n]).unwrap(), "hello");
|
||||
}
|
||||
|
||||
// Expect there to be no error when server DOES NOT send close_notify but just
|
||||
// closes the socket
|
||||
#[rstest]
|
||||
#[tokio::test]
|
||||
async fn test_ok_server_no_close_notify(set_up_tls: impl Future<Output = TlsFixture>) {
|
||||
let TlsFixture {
|
||||
mut client_tls_conn,
|
||||
closed_tls_task,
|
||||
} = set_up_tls.await;
|
||||
|
||||
// instruct the server to close the socket
|
||||
client_tls_conn
|
||||
.write_all(&pad("close_socket".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
client_tls_conn.flush().await.unwrap();
|
||||
|
||||
// give enough time to close the socket
|
||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||
|
||||
client_tls_conn.close().await.unwrap();
|
||||
|
||||
let closed_conn = closed_tls_task.await.unwrap().unwrap();
|
||||
|
||||
assert!(!closed_conn.client.received_close_notify());
|
||||
}
|
||||
|
||||
// Expect to register a delay when the server delays closing the socket
|
||||
#[rstest]
|
||||
#[tokio::test]
|
||||
async fn test_ok_delay_close(set_up_tls: impl Future<Output = TlsFixture>) {
|
||||
let TlsFixture {
|
||||
mut client_tls_conn,
|
||||
closed_tls_task,
|
||||
} = set_up_tls.await;
|
||||
|
||||
client_tls_conn
|
||||
.write_all(&pad("must_delay_when_closing".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
client_tls_conn.flush().await.unwrap();
|
||||
|
||||
// closing `client_tls_conn` will cause close_notify to be sent by the client
|
||||
client_tls_conn.close().await.unwrap();
|
||||
|
||||
use std::time::Instant;
|
||||
let now = Instant::now();
|
||||
// this will resolve when the server stops delaying closing the socket
|
||||
let closed_conn = closed_tls_task.await.unwrap().unwrap();
|
||||
let elapsed = now.elapsed();
|
||||
|
||||
// the elapsed time must be roughly equal to the server's delay
|
||||
// (give or take timing variations)
|
||||
assert!(elapsed.as_millis() as u64 > CLOSE_DELAY - 50);
|
||||
|
||||
assert!(!closed_conn.client.received_close_notify());
|
||||
}
|
||||
|
||||
// Expect client to error when server sends a corrupted message
|
||||
#[rstest]
|
||||
#[tokio::test]
|
||||
async fn test_err_corrupted(set_up_tls: impl Future<Output = TlsFixture>) {
|
||||
let TlsFixture {
|
||||
mut client_tls_conn,
|
||||
closed_tls_task,
|
||||
} = set_up_tls.await;
|
||||
|
||||
// instruct the server to send a corrupted message
|
||||
client_tls_conn
|
||||
.write_all(&pad("send_corrupted_message".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
client_tls_conn.flush().await.unwrap();
|
||||
|
||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||
client_tls_conn.close().await.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
closed_tls_task.await.unwrap().err().unwrap().to_string(),
|
||||
"received corrupt message"
|
||||
);
|
||||
}
|
||||
|
||||
// Expect client to error when server sends a TLS record with a bad MAC
|
||||
#[rstest]
|
||||
#[tokio::test]
|
||||
async fn test_err_bad_mac(set_up_tls: impl Future<Output = TlsFixture>) {
|
||||
let TlsFixture {
|
||||
mut client_tls_conn,
|
||||
closed_tls_task,
|
||||
} = set_up_tls.await;
|
||||
|
||||
// instruct the server to send us a TLS record with a bad MAC
|
||||
client_tls_conn
|
||||
.write_all(&pad("send_record_with_bad_mac".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
client_tls_conn.flush().await.unwrap();
|
||||
|
||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||
client_tls_conn.close().await.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
closed_tls_task.await.unwrap().err().unwrap().to_string(),
|
||||
"backend error: Decryption error: \"aead::Error\""
|
||||
);
|
||||
}
|
||||
|
||||
// Expect client to error when server sends a fatal alert
|
||||
#[rstest]
|
||||
#[tokio::test]
|
||||
async fn test_err_alert(set_up_tls: impl Future<Output = TlsFixture>) {
|
||||
let TlsFixture {
|
||||
mut client_tls_conn,
|
||||
closed_tls_task,
|
||||
} = set_up_tls.await;
|
||||
|
||||
// instruct the server to send us a TLS record with a bad MAC
|
||||
client_tls_conn
|
||||
.write_all(&pad("send_alert".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
client_tls_conn.flush().await.unwrap();
|
||||
|
||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||
client_tls_conn.close().await.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
closed_tls_task.await.unwrap().err().unwrap().to_string(),
|
||||
"received fatal alert: BadRecordMac"
|
||||
);
|
||||
}
|
||||
|
||||
// Expect an error when trying to write data to a connection which server closed
|
||||
// abruptly
|
||||
#[rstest]
|
||||
#[tokio::test]
|
||||
async fn test_err_write_after_close(set_up_tls: impl Future<Output = TlsFixture>) {
|
||||
let TlsFixture {
|
||||
mut client_tls_conn,
|
||||
..
|
||||
} = set_up_tls.await;
|
||||
|
||||
// instruct the server to close the socket
|
||||
client_tls_conn
|
||||
.write_all(&pad("close_socket".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
client_tls_conn.flush().await.unwrap();
|
||||
|
||||
// give enough time to close the socket
|
||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||
|
||||
// try to send some more data
|
||||
let res = client_tls_conn
|
||||
.write_all(&pad("more data".to_string()))
|
||||
.await;
|
||||
|
||||
assert_eq!(res.err().unwrap().kind(), std::io::ErrorKind::BrokenPipe);
|
||||
}
|
||||
|
||||
// Converts a string into a slice zero-padded to APP_RECORD_LENGTH
|
||||
fn pad(s: String) -> Vec<u8> {
|
||||
assert!(s.len() <= APP_RECORD_LENGTH);
|
||||
let mut buf = vec![0u8; APP_RECORD_LENGTH];
|
||||
buf[..s.len()].copy_from_slice(s.as_bytes());
|
||||
buf
|
||||
}
|
||||
@@ -227,7 +227,6 @@ impl ConnectionCommon {
|
||||
|
||||
/// Signals that the server has closed the connection.
|
||||
pub async fn server_closed(&mut self) -> Result<(), Error> {
|
||||
self.common_state.has_seen_eof = true;
|
||||
self.common_state.backend.server_closed().await?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -458,9 +457,6 @@ impl ConnectionCommon {
|
||||
return Err(Error::CorruptMessage);
|
||||
}
|
||||
|
||||
// Process outgoing plaintext buffer and encrypt messages.
|
||||
self.flush_plaintext().await?;
|
||||
|
||||
// Process new messages.
|
||||
while let Some(msg) = self.message_deframer.frames.pop_front() {
|
||||
// If we're not decrypting yet, we process it immediately. Otherwise it will be
|
||||
@@ -512,22 +508,25 @@ impl ConnectionCommon {
|
||||
Ok(state)
|
||||
}
|
||||
|
||||
/// Writes plaintext `buf` into an internal buffer. May not fully process the
|
||||
/// whole buffer and returns the processed length.
|
||||
pub fn write_plaintext(&mut self, buf: &[u8]) -> Result<usize, Error> {
|
||||
if buf.is_empty() {
|
||||
// Don't send empty fragments.
|
||||
return Ok(0);
|
||||
/// Write buffer into connection.
|
||||
pub async fn write_plaintext(&mut self, buf: &[u8]) -> Result<usize, Error> {
|
||||
if let Ok(st) = &mut self.state {
|
||||
st.perhaps_write_key_update(&mut self.common_state).await;
|
||||
}
|
||||
|
||||
let len = self.sendable_plaintext.append_limited_copy(buf);
|
||||
Ok(len)
|
||||
self.common_state.send_some_plaintext(buf).await
|
||||
}
|
||||
|
||||
/// Writes the entire plaintext `buf` into an internal buffer.
|
||||
pub fn write_all_plaintext(&mut self, buf: &[u8]) -> Result<(), Error> {
|
||||
self.sendable_plaintext.append(buf.to_vec());
|
||||
Ok(())
|
||||
/// Write entire buffer into connection.
|
||||
pub async fn write_all_plaintext(&mut self, buf: &[u8]) -> Result<usize, Error> {
|
||||
let mut pos = 0;
|
||||
while pos < buf.len() {
|
||||
pos += self.write_plaintext(&buf[pos..]).await?;
|
||||
}
|
||||
self.backend.flush().await?;
|
||||
while let Some(msg) = self.backend.next_outgoing().await? {
|
||||
self.queue_tls_message(msg);
|
||||
}
|
||||
Ok(pos)
|
||||
}
|
||||
|
||||
/// Read TLS content from `rd`. This method does internal
|
||||
@@ -691,11 +690,6 @@ impl CommonState {
|
||||
self.received_plaintext.is_empty()
|
||||
}
|
||||
|
||||
/// Returns true if the buffer for sendable plaintext is full.
|
||||
pub fn sendable_plaintext_is_full(&self) -> bool {
|
||||
self.sendable_plaintext.is_full()
|
||||
}
|
||||
|
||||
/// Returns true if the connection is currently performing the TLS
|
||||
/// handshake.
|
||||
///
|
||||
@@ -788,6 +782,15 @@ impl CommonState {
|
||||
}
|
||||
}
|
||||
|
||||
/// Send plaintext application data, fragmenting and
|
||||
/// encrypting it as it goes out.
|
||||
///
|
||||
/// If internal buffers are too small, this function will not accept
|
||||
/// all the data.
|
||||
pub(crate) async fn send_some_plaintext(&mut self, data: &[u8]) -> Result<usize, Error> {
|
||||
self.send_plain(data, Limit::Yes).await
|
||||
}
|
||||
|
||||
// Changing the keys must not span any fragmented handshake
|
||||
// messages. Otherwise the defragmented messages will have
|
||||
// been protected with two different record layer protections,
|
||||
@@ -928,6 +931,32 @@ impl CommonState {
|
||||
self.sendable_tls.write_to_async(wr).await
|
||||
}
|
||||
|
||||
/// Encrypt and send some plaintext `data`. `limit` controls
|
||||
/// whether the per-connection buffer limits apply.
|
||||
///
|
||||
/// Returns the number of bytes written from `data`: this might
|
||||
/// be less than `data.len()` if buffer limits were exceeded.
|
||||
async fn send_plain(&mut self, data: &[u8], limit: Limit) -> Result<usize, Error> {
|
||||
if !self.may_send_application_data {
|
||||
// If we haven't completed handshaking, buffer
|
||||
// plaintext to send once we do.
|
||||
let len = match limit {
|
||||
Limit::Yes => self.sendable_plaintext.append_limited_copy(data),
|
||||
Limit::No => self.sendable_plaintext.append(data.to_vec()),
|
||||
};
|
||||
return Ok(len);
|
||||
}
|
||||
|
||||
debug_assert!(self.record_layer.is_encrypting());
|
||||
|
||||
if data.is_empty() {
|
||||
// Don't send empty fragments.
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
self.send_appdata_encrypt(data, limit).await
|
||||
}
|
||||
|
||||
pub(crate) async fn start_outgoing_traffic(&mut self) -> Result<(), Error> {
|
||||
self.may_send_application_data = true;
|
||||
self.flush_plaintext().await
|
||||
@@ -983,14 +1012,15 @@ impl CommonState {
|
||||
self.sendable_tls.set_limit(limit);
|
||||
}
|
||||
|
||||
/// Send and encrypt any buffered plaintext. Does nothing during handshake.
|
||||
pub async fn flush_plaintext(&mut self) -> Result<(), Error> {
|
||||
/// Send any buffered plaintext. Plaintext is buffered if
|
||||
/// written during handshake.
|
||||
async fn flush_plaintext(&mut self) -> Result<(), Error> {
|
||||
if !self.may_send_application_data {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
while let Some(buf) = self.sendable_plaintext.pop() {
|
||||
self.send_appdata_encrypt(&buf, Limit::No).await?;
|
||||
self.send_plain(&buf, Limit::No).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -35,15 +35,6 @@ impl ChunkVecBuffer {
|
||||
self.chunks.is_empty()
|
||||
}
|
||||
|
||||
/// If the buffer has reached limit.
|
||||
pub(crate) fn is_full(&self) -> bool {
|
||||
if let Some(limit) = self.limit {
|
||||
self.len() >= limit
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// How many bytes we're storing
|
||||
pub(crate) fn len(&self) -> usize {
|
||||
let mut len = 0;
|
||||
|
||||
@@ -247,8 +247,7 @@ async fn servered_client_data_sent() {
|
||||
let (mut client, mut server) =
|
||||
make_pair_for_arc_configs(&Arc::new(client_config), &server_config).await;
|
||||
|
||||
assert_eq!(5, client.write_plaintext(b"hello").unwrap());
|
||||
client.flush_plaintext().await.unwrap();
|
||||
assert_eq!(5, client.write_plaintext(b"hello").await.unwrap());
|
||||
|
||||
do_handshake(&mut client, &mut server).await;
|
||||
send(&mut client, &mut server);
|
||||
@@ -287,7 +286,7 @@ async fn servered_both_data_sent() {
|
||||
make_pair_for_arc_configs(&Arc::new(client_config), &server_config).await;
|
||||
|
||||
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
|
||||
assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
|
||||
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
|
||||
|
||||
do_handshake(&mut client, &mut server).await;
|
||||
|
||||
@@ -433,7 +432,7 @@ async fn server_close_notify() {
|
||||
|
||||
// check that alerts don't overtake appdata
|
||||
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
|
||||
assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
|
||||
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
|
||||
server.send_close_notify();
|
||||
|
||||
receive(&mut server, &mut client);
|
||||
@@ -461,8 +460,7 @@ async fn client_close_notify() {
|
||||
|
||||
// check that alerts don't overtake appdata
|
||||
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
|
||||
assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
|
||||
client.flush_plaintext().await.unwrap();
|
||||
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
|
||||
client.send_close_notify().await.unwrap();
|
||||
|
||||
send(&mut client, &mut server);
|
||||
@@ -489,7 +487,7 @@ async fn server_closes_uncleanly() {
|
||||
|
||||
// check that unclean EOF reporting does not overtake appdata
|
||||
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
|
||||
assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
|
||||
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
|
||||
|
||||
receive(&mut server, &mut client);
|
||||
transfer_eof(&mut client);
|
||||
@@ -520,7 +518,7 @@ async fn client_closes_uncleanly() {
|
||||
|
||||
// check that unclean EOF reporting does not overtake appdata
|
||||
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
|
||||
assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
|
||||
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
|
||||
client.process_new_packets().await.unwrap();
|
||||
|
||||
send(&mut client, &mut server);
|
||||
@@ -902,9 +900,20 @@ async fn client_respects_buffer_limit_pre_handshake() {
|
||||
|
||||
client.set_buffer_limit(Some(32));
|
||||
|
||||
assert_eq!(client.write_plaintext(b"01234567890123456789").unwrap(), 20);
|
||||
assert_eq!(client.write_plaintext(b"01234567890123456789").unwrap(), 12);
|
||||
client.flush_plaintext().await.unwrap();
|
||||
assert_eq!(
|
||||
client
|
||||
.write_plaintext(b"01234567890123456789")
|
||||
.await
|
||||
.unwrap(),
|
||||
20
|
||||
);
|
||||
assert_eq!(
|
||||
client
|
||||
.write_plaintext(b"01234567890123456789")
|
||||
.await
|
||||
.unwrap(),
|
||||
12
|
||||
);
|
||||
|
||||
do_handshake(&mut client, &mut server).await;
|
||||
send(&mut client, &mut server);
|
||||
@@ -944,9 +953,20 @@ async fn client_respects_buffer_limit_post_handshake() {
|
||||
do_handshake(&mut client, &mut server).await;
|
||||
client.set_buffer_limit(Some(48));
|
||||
|
||||
assert_eq!(client.write_plaintext(b"01234567890123456789").unwrap(), 20);
|
||||
assert_eq!(client.write_plaintext(b"01234567890123456789").unwrap(), 6);
|
||||
client.flush_plaintext().await.unwrap();
|
||||
assert_eq!(
|
||||
client
|
||||
.write_plaintext(b"01234567890123456789")
|
||||
.await
|
||||
.unwrap(),
|
||||
20
|
||||
);
|
||||
assert_eq!(
|
||||
client
|
||||
.write_plaintext(b"01234567890123456789")
|
||||
.await
|
||||
.unwrap(),
|
||||
6
|
||||
);
|
||||
|
||||
send(&mut client, &mut server);
|
||||
server.process_new_packets().unwrap();
|
||||
@@ -1191,8 +1211,14 @@ async fn client_complete_io_for_write() {
|
||||
|
||||
do_handshake(&mut client, &mut server).await;
|
||||
|
||||
client.write_plaintext(b"01234567890123456789").unwrap();
|
||||
client.write_plaintext(b"01234567890123456789").unwrap();
|
||||
client
|
||||
.write_plaintext(b"01234567890123456789")
|
||||
.await
|
||||
.unwrap();
|
||||
client
|
||||
.write_plaintext(b"01234567890123456789")
|
||||
.await
|
||||
.unwrap();
|
||||
{
|
||||
let mut pipe = ServerSession::new(&mut server);
|
||||
let (rdlen, wrlen) = client
|
||||
@@ -1324,8 +1350,7 @@ async fn server_stream_read() {
|
||||
for kt in ALL_KEY_TYPES.iter() {
|
||||
let (mut client, mut server) = make_pair(*kt).await;
|
||||
|
||||
client.write_all_plaintext(b"world").unwrap();
|
||||
client.process_new_packets().await.unwrap();
|
||||
client.write_all_plaintext(b"world").await.unwrap();
|
||||
|
||||
{
|
||||
let mut pipe = ClientSession::new(&mut client);
|
||||
@@ -1341,8 +1366,7 @@ async fn server_streamowned_read() {
|
||||
for kt in ALL_KEY_TYPES.iter() {
|
||||
let (mut client, server) = make_pair(*kt).await;
|
||||
|
||||
client.write_all_plaintext(b"world").unwrap();
|
||||
client.process_new_packets().await.unwrap();
|
||||
client.write_all_plaintext(b"world").await.unwrap();
|
||||
|
||||
{
|
||||
let pipe = ClientSession::new(&mut client);
|
||||
@@ -1361,9 +1385,7 @@ async fn server_streamowned_read() {
|
||||
// errkind: io::ErrorKind::ConnectionAborted,
|
||||
// after: 0,
|
||||
// };
|
||||
// client.write_all_plaintext(b"hello").unwrap();
|
||||
// client.process_new_packets().await.unwrap();
|
||||
//
|
||||
// client.write_all_plaintext(b"hello").await.unwrap();
|
||||
// let mut client_stream = Stream::new(&mut client, &mut pipe);
|
||||
// let rc = client_stream.write(b"world");
|
||||
// assert!(rc.is_err());
|
||||
@@ -1380,9 +1402,7 @@ async fn server_streamowned_read() {
|
||||
// errkind: io::ErrorKind::ConnectionAborted,
|
||||
// after: 1,
|
||||
// };
|
||||
// client.write_all_plaintext(b"hello").unwrap();
|
||||
// client.process_new_packets().await.unwrap();
|
||||
//
|
||||
// client.write_all_plaintext(b"hello").await.unwrap();
|
||||
// let mut client_stream = Stream::new(&mut client, &mut pipe);
|
||||
// let rc = client_stream.write(b"world");
|
||||
// assert_eq!(format!("{:?}", rc), "Ok(5)");
|
||||
@@ -1880,9 +1900,14 @@ async fn servered_write_for_client_appdata() {
|
||||
let (mut client, mut server) = make_pair(KeyType::Rsa).await;
|
||||
do_handshake(&mut client, &mut server).await;
|
||||
|
||||
client.write_all_plaintext(b"01234567890123456789").unwrap();
|
||||
client.write_all_plaintext(b"01234567890123456789").unwrap();
|
||||
client.process_new_packets().await.unwrap();
|
||||
client
|
||||
.write_all_plaintext(b"01234567890123456789")
|
||||
.await
|
||||
.unwrap();
|
||||
client
|
||||
.write_all_plaintext(b"01234567890123456789")
|
||||
.await
|
||||
.unwrap();
|
||||
{
|
||||
let mut pipe = ServerSession::new(&mut server);
|
||||
let wrlen = client.write_tls(&mut pipe).unwrap();
|
||||
@@ -1994,10 +2019,11 @@ async fn servered_write_for_server_handshake_no_half_rtt_by_default() {
|
||||
async fn servered_write_for_client_handshake() {
|
||||
let (mut client, mut server) = make_pair(KeyType::Rsa).await;
|
||||
|
||||
client.write_all_plaintext(b"01234567890123456789").unwrap();
|
||||
client.write_all_plaintext(b"0123456789").unwrap();
|
||||
client.process_new_packets().await.unwrap();
|
||||
|
||||
client
|
||||
.write_all_plaintext(b"01234567890123456789")
|
||||
.await
|
||||
.unwrap();
|
||||
client.write_all_plaintext(b"0123456789").await.unwrap();
|
||||
{
|
||||
let mut pipe = ServerSession::new(&mut server);
|
||||
let wrlen = client.write_tls(&mut pipe).unwrap();
|
||||
|
||||
@@ -21,11 +21,11 @@ tlsn-attestation = { workspace = true }
|
||||
tlsn-core = { workspace = true }
|
||||
tlsn-deap = { workspace = true }
|
||||
tlsn-tls-client = { workspace = true }
|
||||
tlsn-tls-client-async = { workspace = true }
|
||||
tlsn-tls-core = { workspace = true }
|
||||
tlsn-mpc-tls = { workspace = true }
|
||||
tlsn-cipher = { workspace = true }
|
||||
|
||||
futures-plex = { workspace = true }
|
||||
serio = { workspace = true, features = ["compat"] }
|
||||
uid-mux = { workspace = true, features = ["serio"] }
|
||||
web-spawn = { workspace = true, optional = true }
|
||||
|
||||
@@ -22,9 +22,6 @@ use std::sync::LazyLock;
|
||||
|
||||
use semver::Version;
|
||||
|
||||
// Size for internal buffers.
|
||||
const BUF_CAP: usize = 8 * 1024;
|
||||
|
||||
// Package version.
|
||||
pub(crate) static VERSION: LazyLock<Version> = LazyLock::new(|| {
|
||||
Version::parse(env!("CARGO_PKG_VERSION")).expect("cargo pkg version should be a valid semver")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::ops::Range;
|
||||
|
||||
use mpz_memory_core::{Vector, binary::U8};
|
||||
use rangeset::RangeSet;
|
||||
use rangeset::set::RangeSet;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub(crate) struct RangeMap<T> {
|
||||
@@ -77,7 +77,7 @@ where
|
||||
|
||||
pub(crate) fn index(&self, idx: &RangeSet<usize>) -> Option<Self> {
|
||||
let mut map = Vec::new();
|
||||
for idx in idx.iter_ranges() {
|
||||
for idx in idx.iter() {
|
||||
let pos = match self.map.binary_search_by(|(base, _)| base.cmp(&idx.start)) {
|
||||
Ok(i) => i,
|
||||
Err(0) => return None,
|
||||
|
||||
@@ -1,31 +1,31 @@
|
||||
//! Prover.
|
||||
|
||||
mod client;
|
||||
mod error;
|
||||
mod future;
|
||||
mod prove;
|
||||
pub mod state;
|
||||
|
||||
pub use error::ProverError;
|
||||
pub use future::ProverFuture;
|
||||
pub use tlsn_core::ProverOutput;
|
||||
|
||||
use crate::{
|
||||
BUF_CAP, Role,
|
||||
Role,
|
||||
context::build_mt_context,
|
||||
mpz::{ProverDeps, build_prover_deps, translate_keys},
|
||||
msg::{ProveRequestMsg, Response, TlsCommitRequestMsg},
|
||||
mux::attach_mux,
|
||||
prover::client::{MpcTlsClient, TlsOutput},
|
||||
tag::verify_tags,
|
||||
};
|
||||
|
||||
use futures::{FutureExt, TryFutureExt};
|
||||
use futures::{AsyncRead, AsyncWrite, TryFutureExt};
|
||||
use mpc_tls::LeaderCtrl;
|
||||
use mpz_vm_core::prelude::*;
|
||||
use rustls_pki_types::CertificateDer;
|
||||
use serio::{SinkExt, stream::IoStreamExt};
|
||||
use std::{
|
||||
io::{Read, Write},
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use tls_client::{ClientConnection, ServerName as TlsServerName};
|
||||
use tls_client_async::{TlsConnection, bind_client};
|
||||
use tlsn_core::{
|
||||
config::{
|
||||
prove::ProveConfig,
|
||||
@@ -36,9 +36,10 @@ use tlsn_core::{
|
||||
connection::{HandshakeData, ServerName},
|
||||
transcript::{TlsTranscript, Transcript},
|
||||
};
|
||||
use tracing::{Span, debug, info_span, instrument};
|
||||
use webpki::anchor_from_trusted_cert;
|
||||
|
||||
use tracing::{Instrument, Span, debug, info, info_span, instrument};
|
||||
|
||||
/// A prover instance.
|
||||
#[derive(Debug)]
|
||||
pub struct Prover<T: state::ProverState = state::Initialized> {
|
||||
@@ -70,16 +71,15 @@ impl Prover<state::Initialized> {
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `config` - The TLS commitment configuration.
|
||||
/// * `socket` - The socket to the TLS verifier.
|
||||
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
|
||||
pub async fn commit(
|
||||
pub async fn commit<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
|
||||
self,
|
||||
config: TlsCommitConfig,
|
||||
socket: S,
|
||||
) -> Result<Prover<state::CommitAccepted>, ProverError> {
|
||||
let (duplex_a, duplex_b) = futures_plex::duplex(BUF_CAP);
|
||||
|
||||
let (mut mux_fut, mux_ctrl) = attach_mux(duplex_b, Role::Prover);
|
||||
let (mut mux_fut, mux_ctrl) = attach_mux(socket, Role::Prover);
|
||||
let mut mt = build_mt_context(mux_ctrl.clone());
|
||||
|
||||
let mut ctx = mux_fut.poll_with(mt.new_context()).await?;
|
||||
|
||||
// Sends protocol configuration to verifier for compatibility check.
|
||||
@@ -118,48 +118,47 @@ impl Prover<state::Initialized> {
|
||||
|
||||
debug!("mpc-tls setup complete");
|
||||
|
||||
let prover = Prover {
|
||||
Ok(Prover {
|
||||
config: self.config,
|
||||
span: self.span,
|
||||
state: state::CommitAccepted {
|
||||
mpc_duplex: duplex_a,
|
||||
mux_ctrl,
|
||||
mux_fut,
|
||||
mpc_tls,
|
||||
keys,
|
||||
vm,
|
||||
},
|
||||
};
|
||||
|
||||
Ok(prover)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Prover<state::CommitAccepted> {
|
||||
/// Connects the prover.
|
||||
/// Connects to the server using the provided socket.
|
||||
///
|
||||
/// Returns a connected prover, which can be used to read and write from/to
|
||||
/// the active TLS connection.
|
||||
/// Returns a handle to the TLS connection, a future which returns the
|
||||
/// prover once the connection is closed and the TLS transcript is
|
||||
/// committed.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `config` - The TLS client configuration.
|
||||
/// * `socket` - The socket to the server.
|
||||
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
|
||||
pub async fn connect(
|
||||
pub async fn connect<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
|
||||
self,
|
||||
config: TlsClientConfig,
|
||||
) -> Result<Prover<state::Connected>, ProverError> {
|
||||
socket: S,
|
||||
) -> Result<(TlsConnection, ProverFuture), ProverError> {
|
||||
let state::CommitAccepted {
|
||||
mpc_duplex,
|
||||
mux_ctrl,
|
||||
mux_fut,
|
||||
mut mux_fut,
|
||||
mpc_tls,
|
||||
keys,
|
||||
vm,
|
||||
..
|
||||
} = self.state;
|
||||
|
||||
let decrypt = mpc_tls.is_decrypting();
|
||||
let (mpc_ctrl, mpc_fut) = mpc_tls.run();
|
||||
|
||||
let ServerName::Dns(server_name) = config.server_name();
|
||||
let server_name =
|
||||
@@ -196,194 +195,102 @@ impl Prover<state::CommitAccepted> {
|
||||
rustls_config.with_no_client_auth()
|
||||
};
|
||||
|
||||
let client = ClientConnection::new(Arc::new(rustls_config), Box::new(mpc_tls), server_name)
|
||||
.map_err(ProverError::config)?;
|
||||
let client = ClientConnection::new(
|
||||
Arc::new(rustls_config),
|
||||
Box::new(mpc_ctrl.clone()),
|
||||
server_name,
|
||||
)
|
||||
.map_err(ProverError::config)?;
|
||||
|
||||
let span = self.span.clone();
|
||||
let (conn, conn_fut) = bind_client(socket, client);
|
||||
|
||||
let mpc_tls = MpcTlsClient::new(keys, vm, span, client, decrypt);
|
||||
let fut = Box::pin({
|
||||
let span = self.span.clone();
|
||||
let mpc_ctrl = mpc_ctrl.clone();
|
||||
async move {
|
||||
let conn_fut = async {
|
||||
mux_fut
|
||||
.poll_with(conn_fut.map_err(ProverError::from))
|
||||
.await?;
|
||||
|
||||
let prover = Prover {
|
||||
config: self.config,
|
||||
span: self.span,
|
||||
state: state::Connected {
|
||||
mpc_duplex,
|
||||
mux_ctrl,
|
||||
mux_fut,
|
||||
server_name: config.server_name().clone(),
|
||||
tls_client: Box::new(mpc_tls),
|
||||
output: None,
|
||||
},
|
||||
};
|
||||
Ok(prover)
|
||||
}
|
||||
mpc_ctrl.stop().await?;
|
||||
|
||||
/// Writes bytes for the verifier into a buffer.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `buf` - The buffer.
|
||||
pub fn write_mpc(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
|
||||
self.state.mpc_duplex.read(buf)
|
||||
}
|
||||
Ok::<_, ProverError>(())
|
||||
};
|
||||
|
||||
/// Reads bytes for the prover from a buffer.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `buf` - The buffer.
|
||||
pub fn read_mpc(&mut self, buf: &[u8]) -> std::io::Result<usize> {
|
||||
self.state.mpc_duplex.write(buf)
|
||||
}
|
||||
}
|
||||
info!("starting MPC-TLS");
|
||||
|
||||
impl Prover<state::Connected> {
|
||||
/// Returns `true` if the prover wants to read TLS data from the server.
|
||||
pub fn wants_read_tls(&self) -> bool {
|
||||
self.state.tls_client.wants_read_tls()
|
||||
}
|
||||
let (_, (mut ctx, tls_transcript)) = futures::try_join!(
|
||||
conn_fut,
|
||||
mpc_fut.in_current_span().map_err(ProverError::from)
|
||||
)?;
|
||||
|
||||
/// Returns `true` if the prover wants to write TLS data to the server.
|
||||
pub fn wants_write_tls(&self) -> bool {
|
||||
self.state.tls_client.wants_write_tls()
|
||||
}
|
||||
info!("finished MPC-TLS");
|
||||
|
||||
/// Reads TLS data from the server.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `buf` - The buffer to read the TLS data from.
|
||||
pub fn read_tls(&mut self, buf: &[u8]) -> Result<usize, ProverError> {
|
||||
self.state.tls_client.read_tls(buf)
|
||||
}
|
||||
{
|
||||
let mut vm = vm.try_lock().expect("VM should not be locked");
|
||||
|
||||
/// Writes TLS data for the server into the provided buffer.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `buf` - The buffer to write the TLS data to.
|
||||
pub fn write_tls(&mut self, buf: &mut [u8]) -> Result<usize, ProverError> {
|
||||
self.state.tls_client.write_tls(buf)
|
||||
}
|
||||
debug!("finalizing mpc");
|
||||
|
||||
/// Returns `true` if the prover wants to read plaintext data.
|
||||
pub fn wants_read(&self) -> bool {
|
||||
self.state.tls_client.wants_read()
|
||||
}
|
||||
// Finalize DEAP.
|
||||
mux_fut
|
||||
.poll_with(vm.finalize(&mut ctx))
|
||||
.await
|
||||
.map_err(ProverError::mpc)?;
|
||||
|
||||
/// Returns `true` if the prover wants to write plaintext data.
|
||||
pub fn wants_write(&self) -> bool {
|
||||
self.state.tls_client.wants_write()
|
||||
}
|
||||
debug!("mpc finalized");
|
||||
}
|
||||
|
||||
/// Reads plaintext data from the server into the provided buffer.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `buf` - The buffer where the plaintext data gets written to.
|
||||
pub fn read(&mut self, buf: &mut [u8]) -> Result<usize, ProverError> {
|
||||
self.state.tls_client.read(buf)
|
||||
}
|
||||
// Pull out ZK VM.
|
||||
let (_, mut vm) = Arc::into_inner(vm)
|
||||
.expect("vm should have only 1 reference")
|
||||
.into_inner()
|
||||
.into_inner();
|
||||
|
||||
/// Writes plaintext data to be sent to the server.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `buf` - The buffer to read the plaintext data from.
|
||||
pub fn write(&mut self, buf: &[u8]) -> Result<usize, ProverError> {
|
||||
self.state.tls_client.write(buf)
|
||||
}
|
||||
// Prove tag verification of received records.
|
||||
// The prover drops the proof output.
|
||||
let _ = verify_tags(
|
||||
&mut vm,
|
||||
(keys.server_write_key, keys.server_write_iv),
|
||||
keys.server_write_mac_key,
|
||||
*tls_transcript.version(),
|
||||
tls_transcript.recv().to_vec(),
|
||||
)
|
||||
.map_err(ProverError::zk)?;
|
||||
|
||||
/// Writes bytes for the verifier into a buffer.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `buf` - The buffer.
|
||||
pub fn write_mpc(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
|
||||
self.state.mpc_duplex.read(buf)
|
||||
}
|
||||
mux_fut
|
||||
.poll_with(vm.execute_all(&mut ctx).map_err(ProverError::zk))
|
||||
.await?;
|
||||
|
||||
/// Reads bytes for the prover from a buffer.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `buf` - The buffer.
|
||||
pub fn read_mpc(&mut self, buf: &[u8]) -> std::io::Result<usize> {
|
||||
self.state.mpc_duplex.write(buf)
|
||||
}
|
||||
let transcript = tls_transcript
|
||||
.to_transcript()
|
||||
.expect("transcript is complete");
|
||||
|
||||
/// Closes the connection from the client side.
|
||||
pub fn client_close(&mut self) -> Result<(), ProverError> {
|
||||
self.state.tls_client.client_close()
|
||||
}
|
||||
|
||||
/// Closes the connection from the server side.
|
||||
pub fn server_close(&mut self) -> Result<(), ProverError> {
|
||||
self.state.tls_client.server_close()
|
||||
}
|
||||
|
||||
/// Enables or disables the decryption of data from the server until the
|
||||
/// server has closed the connection.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `enable` - Whether to enable or disable decryption.
|
||||
pub fn enable_decryption(&mut self, enable: bool) -> Result<(), ProverError> {
|
||||
self.state.tls_client.enable_decryption(enable)
|
||||
}
|
||||
|
||||
/// Returns `true` if decryption of TLS traffic from the server is active.
|
||||
pub fn is_decrypting(&self) -> bool {
|
||||
self.state.tls_client.is_decrypting()
|
||||
}
|
||||
|
||||
/// Polls the prover to make progress.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `cx` - The async context.
|
||||
pub fn poll(&mut self, cx: &mut Context) -> Poll<Result<(), ProverError>> {
|
||||
let _ = self.state.mux_fut.poll_unpin(cx)?;
|
||||
|
||||
match self.state.tls_client.poll(cx)? {
|
||||
Poll::Ready(output) => {
|
||||
let _ = self.state.mux_fut.poll_unpin(cx)?;
|
||||
self.state.output = Some(output);
|
||||
Poll::Ready(Ok(()))
|
||||
Ok(Prover {
|
||||
config: self.config,
|
||||
span: self.span,
|
||||
state: state::Committed {
|
||||
mux_ctrl,
|
||||
mux_fut,
|
||||
ctx,
|
||||
vm,
|
||||
server_name: config.server_name().clone(),
|
||||
keys,
|
||||
tls_transcript,
|
||||
transcript,
|
||||
},
|
||||
})
|
||||
}
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
.instrument(span)
|
||||
});
|
||||
|
||||
/// Returns a committed prover after the TLS session has completed.
|
||||
pub fn finish(self) -> Result<Prover<state::Committed>, ProverError> {
|
||||
let TlsOutput {
|
||||
ctx,
|
||||
vm,
|
||||
keys,
|
||||
tls_transcript,
|
||||
transcript,
|
||||
} = self.state.output.ok_or(ProverError::state(
|
||||
"prover has not yet closed the connection",
|
||||
))?;
|
||||
|
||||
let prover = Prover {
|
||||
config: self.config,
|
||||
span: self.span,
|
||||
state: state::Committed {
|
||||
mpc_duplex: self.state.mpc_duplex,
|
||||
mux_ctrl: self.state.mux_ctrl,
|
||||
mux_fut: self.state.mux_fut,
|
||||
ctx,
|
||||
vm,
|
||||
server_name: self.state.server_name,
|
||||
keys,
|
||||
tls_transcript,
|
||||
transcript,
|
||||
Ok((
|
||||
conn,
|
||||
ProverFuture {
|
||||
fut,
|
||||
ctrl: ProverControl { mpc_ctrl },
|
||||
},
|
||||
};
|
||||
|
||||
Ok(prover)
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -398,24 +305,6 @@ impl Prover<state::Committed> {
|
||||
&self.state.transcript
|
||||
}
|
||||
|
||||
/// Writes bytes for the verifier into a buffer.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `buf` - The buffer.
|
||||
pub fn write_mpc(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
|
||||
self.state.mpc_duplex.read(buf)
|
||||
}
|
||||
|
||||
/// Reads bytes for the prover from a buffer.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `buf` - The buffer.
|
||||
pub fn read_mpc(&mut self, buf: &[u8]) -> std::io::Result<usize> {
|
||||
self.state.mpc_duplex.write(buf)
|
||||
}
|
||||
|
||||
/// Proves information to the verifier.
|
||||
///
|
||||
/// # Arguments
|
||||
@@ -478,19 +367,41 @@ impl Prover<state::Committed> {
|
||||
#[instrument(parent = &self.span, level = "info", skip_all, err)]
|
||||
pub async fn close(self) -> Result<(), ProverError> {
|
||||
let state::Committed {
|
||||
mut mpc_duplex,
|
||||
mux_ctrl,
|
||||
mux_fut,
|
||||
..
|
||||
mux_ctrl, mux_fut, ..
|
||||
} = self.state;
|
||||
|
||||
// Wait for the verifier to correctly close the connection.
|
||||
if !mux_fut.is_complete() {
|
||||
mux_ctrl.close();
|
||||
mux_fut.await?;
|
||||
futures::AsyncWriteExt::close(&mut mpc_duplex).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// A controller for the prover.
|
||||
#[derive(Clone)]
|
||||
pub struct ProverControl {
|
||||
mpc_ctrl: LeaderCtrl,
|
||||
}
|
||||
|
||||
impl ProverControl {
|
||||
/// Defers decryption of data from the server until the server has closed
|
||||
/// the connection.
|
||||
///
|
||||
/// This is a performance optimization which will significantly reduce the
|
||||
/// amount of upload bandwidth used by the prover.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// * The prover may need to close the connection to the server in order for
|
||||
/// it to close the connection on its end. If neither the prover or server
|
||||
/// close the connection this will cause a deadlock.
|
||||
pub async fn defer_decryption(&self) -> Result<(), ProverError> {
|
||||
self.mpc_ctrl
|
||||
.defer_decryption()
|
||||
.await
|
||||
.map_err(ProverError::from)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
//! Provides a TLS client.
|
||||
|
||||
use crate::mpz::ProverZk;
|
||||
use mpc_tls::SessionKeys;
|
||||
use std::task::{Context, Poll};
|
||||
use tlsn_core::transcript::{TlsTranscript, Transcript};
|
||||
|
||||
mod mpc;
|
||||
|
||||
pub(crate) use mpc::MpcTlsClient;
|
||||
|
||||
/// TLS client for MPC and proxy-based TLS implementations.
|
||||
pub(crate) trait TlsClient {
|
||||
type Error: std::error::Error + Send + Sync + Unpin + 'static;
|
||||
|
||||
/// Returns `true` if the client wants to read TLS data from the server.
|
||||
fn wants_read_tls(&self) -> bool;
|
||||
|
||||
/// Returns `true` if the client wants to write TLS data to the server.
|
||||
fn wants_write_tls(&self) -> bool;
|
||||
|
||||
/// Reads TLS data from the server.
|
||||
fn read_tls(&mut self, buf: &[u8]) -> Result<usize, Self::Error>;
|
||||
|
||||
/// Writes TLS data for the server into the provided buffer.
|
||||
fn write_tls(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error>;
|
||||
|
||||
/// Returns `true` if the client wants to read plaintext data.
|
||||
fn wants_read(&self) -> bool;
|
||||
|
||||
/// Returns `true` if the client wants to write plaintext data.
|
||||
fn wants_write(&self) -> bool;
|
||||
|
||||
/// Reads plaintext data from the server into the provided buffer.
|
||||
fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error>;
|
||||
|
||||
/// Writes plaintext data to be sent to the server.
|
||||
fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error>;
|
||||
|
||||
/// Client closes the connection.
|
||||
fn client_close(&mut self) -> Result<(), Self::Error>;
|
||||
|
||||
/// Server closes the connection.
|
||||
fn server_close(&mut self) -> Result<(), Self::Error>;
|
||||
|
||||
/// Enables or disables decryption of TLS traffic sent by the server.
|
||||
fn enable_decryption(&mut self, enable: bool) -> Result<(), Self::Error>;
|
||||
|
||||
/// Returns `true` if decryption of TLS traffic from the server is active.
|
||||
fn is_decrypting(&self) -> bool;
|
||||
|
||||
/// Polls the client to make progress.
|
||||
fn poll(&mut self, cx: &mut Context) -> Poll<Result<TlsOutput, Self::Error>>;
|
||||
}
|
||||
|
||||
/// Output of a TLS session.
|
||||
pub(crate) struct TlsOutput {
|
||||
pub(crate) ctx: mpz_common::Context,
|
||||
pub(crate) vm: ProverZk,
|
||||
pub(crate) keys: SessionKeys,
|
||||
pub(crate) tls_transcript: TlsTranscript,
|
||||
pub(crate) transcript: Transcript,
|
||||
}
|
||||
@@ -1,391 +0,0 @@
|
||||
//! Implementation of an MPC-TLS client.
|
||||
|
||||
use crate::{
|
||||
mpz::{ProverMpc, ProverZk},
|
||||
prover::{
|
||||
ProverError,
|
||||
client::{TlsClient, TlsOutput},
|
||||
},
|
||||
tag::verify_tags,
|
||||
};
|
||||
use futures::{Future, FutureExt};
|
||||
use mpc_tls::{MpcTlsLeader, SessionKeys};
|
||||
use mpz_common::Context;
|
||||
use mpz_vm_core::Execute;
|
||||
use std::{collections::VecDeque, pin::Pin, sync::Arc, task::Poll};
|
||||
use tls_client::ClientConnection;
|
||||
use tlsn_core::transcript::TlsTranscript;
|
||||
use tlsn_deap::Deap;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::{Span, debug, instrument, trace, warn};
|
||||
|
||||
pub(crate) type MpcFuture =
|
||||
Box<dyn Future<Output = Result<(Context, TlsTranscript), ProverError>> + Send>;
|
||||
|
||||
type FinalizeFuture =
|
||||
Box<dyn Future<Output = Result<(InnerState, Context, TlsTranscript), ProverError>> + Send>;
|
||||
|
||||
pub(crate) struct MpcTlsClient {
|
||||
state: State,
|
||||
decrypt: bool,
|
||||
cmds: VecDeque<Command>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub(crate) enum Command {
|
||||
ClientClose,
|
||||
ServerClose,
|
||||
Decrypt(bool),
|
||||
}
|
||||
|
||||
enum State {
|
||||
Start {
|
||||
inner: Box<InnerState>,
|
||||
},
|
||||
Active {
|
||||
inner: Box<InnerState>,
|
||||
},
|
||||
Busy {
|
||||
fut: Pin<Box<dyn Future<Output = Result<Box<InnerState>, ProverError>> + Send>>,
|
||||
},
|
||||
CloseActive {
|
||||
inner: Box<InnerState>,
|
||||
},
|
||||
CloseBusy {
|
||||
fut: Pin<Box<dyn Future<Output = Result<Box<InnerState>, ProverError>> + Send>>,
|
||||
},
|
||||
Finalizing {
|
||||
fut: Pin<FinalizeFuture>,
|
||||
},
|
||||
Finished,
|
||||
Error,
|
||||
}
|
||||
|
||||
impl MpcTlsClient {
|
||||
pub(crate) fn new(
|
||||
keys: SessionKeys,
|
||||
vm: Arc<Mutex<Deap<ProverMpc, ProverZk>>>,
|
||||
span: Span,
|
||||
tls: ClientConnection,
|
||||
) -> Self {
|
||||
let inner = InnerState {
|
||||
span,
|
||||
tls,
|
||||
vm,
|
||||
keys,
|
||||
mpc_stopped: false,
|
||||
};
|
||||
let decrypt = tls.backend().is_decrypting();
|
||||
|
||||
Self {
|
||||
state: State::Start {
|
||||
inner: Box::new(inner),
|
||||
},
|
||||
decrypt,
|
||||
cmds: VecDeque::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn inner_client_mut(&mut self) -> Option<&mut ClientConnection> {
|
||||
if let State::Active { inner, .. } | State::CloseActive { inner, .. } = &mut self.state {
|
||||
Some(&mut inner.tls)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn inner_client(&self) -> Option<&ClientConnection> {
|
||||
if let State::Active { inner, .. } | State::CloseActive { inner, .. } = &self.state {
|
||||
Some(&inner.tls)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TlsClient for MpcTlsClient {
|
||||
type Error = ProverError;
|
||||
|
||||
fn wants_read_tls(&self) -> bool {
|
||||
if let Some(client) = self.inner_client() {
|
||||
client.wants_read()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn wants_write_tls(&self) -> bool {
|
||||
if let Some(client) = self.inner_client() {
|
||||
client.wants_write()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn read_tls(&mut self, mut buf: &[u8]) -> Result<usize, Self::Error> {
|
||||
if let Some(client) = self.inner_client_mut()
|
||||
&& client.wants_read()
|
||||
{
|
||||
client.read_tls(&mut buf).map_err(ProverError::from)
|
||||
} else {
|
||||
Ok(0)
|
||||
}
|
||||
}
|
||||
|
||||
fn write_tls(&mut self, mut buf: &mut [u8]) -> Result<usize, Self::Error> {
|
||||
if let Some(client) = self.inner_client_mut()
|
||||
&& client.wants_write()
|
||||
{
|
||||
client.write_tls(&mut buf).map_err(ProverError::from)
|
||||
} else {
|
||||
Ok(0)
|
||||
}
|
||||
}
|
||||
|
||||
fn wants_read(&self) -> bool {
|
||||
if let Some(client) = self.inner_client() {
|
||||
!client.plaintext_is_empty()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn wants_write(&self) -> bool {
|
||||
if let Some(client) = self.inner_client() {
|
||||
!client.sendable_plaintext_is_full()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
|
||||
if let Some(client) = self.inner_client_mut()
|
||||
&& !client.plaintext_is_empty()
|
||||
{
|
||||
client.read_plaintext(buf).map_err(ProverError::from)
|
||||
} else {
|
||||
Ok(0)
|
||||
}
|
||||
}
|
||||
|
||||
fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
|
||||
if let Some(client) = self.inner_client_mut()
|
||||
&& !client.sendable_plaintext_is_full()
|
||||
{
|
||||
client.write_plaintext(buf).map_err(ProverError::from)
|
||||
} else {
|
||||
Ok(0)
|
||||
}
|
||||
}
|
||||
|
||||
fn client_close(&mut self) -> Result<(), Self::Error> {
|
||||
self.cmds.push_back(Command::ClientClose);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn server_close(&mut self) -> Result<(), Self::Error> {
|
||||
self.cmds.push_back(Command::ServerClose);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn enable_decryption(&mut self, enable: bool) -> Result<(), Self::Error> {
|
||||
self.cmds.push_back(Command::Decrypt(enable));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn is_decrypting(&self) -> bool {
|
||||
self.decrypt
|
||||
}
|
||||
|
||||
fn poll(&mut self, cx: &mut std::task::Context) -> Poll<Result<TlsOutput, Self::Error>> {
|
||||
match std::mem::replace(&mut self.state, State::Error) {
|
||||
State::Start { inner } => {
|
||||
trace!("inner client is starting");
|
||||
self.state = State::Busy {
|
||||
fut: Box::pin(inner.start()),
|
||||
};
|
||||
self.poll(cx)
|
||||
}
|
||||
State::Active { mut inner } => {
|
||||
trace!("inner client is active");
|
||||
|
||||
if !inner.tls.is_handshaking()
|
||||
&& let Some(cmd) = self.cmds.pop_front()
|
||||
{
|
||||
match cmd {
|
||||
Command::ClientClose => {
|
||||
self.state = State::Busy {
|
||||
fut: Box::pin(inner.client_close()),
|
||||
};
|
||||
}
|
||||
Command::ServerClose => {
|
||||
self.state = State::CloseBusy {
|
||||
fut: Box::pin(inner.server_close()),
|
||||
};
|
||||
}
|
||||
Command::Decrypt(enable) => {
|
||||
inner.tls.backend_mut().enable_decryption(enable)?;
|
||||
self.decrypt = enable;
|
||||
self.state = State::Busy {
|
||||
fut: Box::pin(inner.run()),
|
||||
};
|
||||
}
|
||||
}
|
||||
} else {
|
||||
self.state = State::Busy {
|
||||
fut: Box::pin(inner.run()),
|
||||
};
|
||||
}
|
||||
self.poll(cx)
|
||||
}
|
||||
State::Busy { mut fut } => {
|
||||
trace!("inner client is busy");
|
||||
match fut.as_mut().poll(cx)? {
|
||||
Poll::Ready(inner) => {
|
||||
self.state = State::Active { inner };
|
||||
}
|
||||
Poll::Pending => self.state = State::Busy { fut },
|
||||
}
|
||||
Poll::Pending
|
||||
}
|
||||
State::CloseActive { mut inner } => {
|
||||
trace!("inner client is close active");
|
||||
if let Some((ctx, transcript)) = inner.tls.backend_mut().finish() {
|
||||
self.state = State::Finalizing {
|
||||
fut: Box::pin(inner.finalize(ctx, transcript)),
|
||||
};
|
||||
} else {
|
||||
self.state = State::CloseBusy {
|
||||
fut: Box::pin(inner.server_close()),
|
||||
};
|
||||
}
|
||||
self.poll(cx)
|
||||
}
|
||||
State::CloseBusy { mut fut } => {
|
||||
trace!("inner client is busy closing");
|
||||
match fut.as_mut().poll(cx)? {
|
||||
Poll::Ready(inner) => {
|
||||
self.state = State::CloseActive { inner };
|
||||
}
|
||||
Poll::Pending => self.state = State::CloseBusy { fut },
|
||||
}
|
||||
Poll::Pending
|
||||
}
|
||||
State::Finalizing { mut fut } => match fut.poll_unpin(cx) {
|
||||
Poll::Ready(output) => {
|
||||
let (inner, ctx, tls_transcript) = output?;
|
||||
let InnerState { vm, keys, .. } = inner;
|
||||
|
||||
let transcript = tls_transcript
|
||||
.to_transcript()
|
||||
.expect("transcript is complete");
|
||||
|
||||
let (_, vm) = Arc::into_inner(vm)
|
||||
.expect("vm should have only 1 reference")
|
||||
.into_inner()
|
||||
.into_inner();
|
||||
|
||||
let output = TlsOutput {
|
||||
ctx,
|
||||
vm,
|
||||
keys,
|
||||
tls_transcript,
|
||||
transcript,
|
||||
};
|
||||
|
||||
self.state = State::Finished;
|
||||
Poll::Ready(Ok(output))
|
||||
}
|
||||
Poll::Pending => {
|
||||
self.state = State::Finalizing { fut };
|
||||
Poll::Pending
|
||||
}
|
||||
},
|
||||
State::Finished => Poll::Ready(Err(ProverError::state(
|
||||
"mpc tls client polled again in finished state",
|
||||
))),
|
||||
State::Error => {
|
||||
Poll::Ready(Err(ProverError::state("mpc tls client is in error state")))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct InnerState {
|
||||
span: Span,
|
||||
tls: ClientConnection,
|
||||
vm: Arc<Mutex<Deap<ProverMpc, ProverZk>>>,
|
||||
keys: SessionKeys,
|
||||
mpc_stopped: bool,
|
||||
}
|
||||
|
||||
impl InnerState {
|
||||
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
|
||||
async fn start(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
|
||||
self.tls.start().await?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
#[instrument(parent = &self.span, level = "trace", skip_all, err)]
|
||||
async fn run(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
|
||||
self.tls.process_new_packets().await?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
|
||||
async fn client_close(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
|
||||
debug!("sending close notify");
|
||||
if let Err(e) = self.tls.send_close_notify().await {
|
||||
warn!("failed to send close_notify to server: {}", e);
|
||||
}
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
|
||||
async fn server_close(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
|
||||
self.tls.process_new_packets().await?;
|
||||
if !self.mpc_stopped && self.tls.plaintext_is_empty() && self.tls.is_empty().await? {
|
||||
self.tls.server_closed().await?;
|
||||
self.mpc_stopped = true;
|
||||
debug!("closed connection serverside");
|
||||
}
|
||||
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
|
||||
async fn finalize(
|
||||
self,
|
||||
mut ctx: Context,
|
||||
transcript: TlsTranscript,
|
||||
) -> Result<(Self, Context, TlsTranscript), ProverError> {
|
||||
{
|
||||
let mut vm = self.vm.try_lock().expect("VM should not be locked");
|
||||
|
||||
// Finalize DEAP.
|
||||
vm.finalize(&mut ctx).await.map_err(ProverError::mpc)?;
|
||||
|
||||
debug!("mpc finalized");
|
||||
|
||||
// Pull out ZK VM.
|
||||
let mut zk = vm.zk();
|
||||
|
||||
// Prove tag verification of received records.
|
||||
// The prover drops the proof output.
|
||||
let _ = verify_tags(
|
||||
&mut *zk,
|
||||
(self.keys.server_write_key, self.keys.server_write_iv),
|
||||
self.keys.server_write_mac_key,
|
||||
*transcript.version(),
|
||||
transcript.recv().to_vec(),
|
||||
)
|
||||
.map_err(ProverError::zk)?;
|
||||
debug!("verified tags from server");
|
||||
|
||||
zk.execute_all(&mut ctx).await.map_err(ProverError::zk)?
|
||||
}
|
||||
|
||||
debug!("MPC-TLS done");
|
||||
Ok((self, ctx, transcript))
|
||||
}
|
||||
}
|
||||
@@ -49,13 +49,6 @@ impl ProverError {
|
||||
{
|
||||
Self::new(ErrorKind::Commit, source)
|
||||
}
|
||||
|
||||
pub(crate) fn state<E>(source: E) -> Self
|
||||
where
|
||||
E: Into<Box<dyn Error + Send + Sync + 'static>>,
|
||||
{
|
||||
Self::new(ErrorKind::State, source)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -65,7 +58,6 @@ enum ErrorKind {
|
||||
Zk,
|
||||
Config,
|
||||
Commit,
|
||||
State,
|
||||
}
|
||||
|
||||
impl fmt::Display for ProverError {
|
||||
@@ -78,7 +70,6 @@ impl fmt::Display for ProverError {
|
||||
ErrorKind::Zk => f.write_str("zk error")?,
|
||||
ErrorKind::Config => f.write_str("config error")?,
|
||||
ErrorKind::Commit => f.write_str("commit error")?,
|
||||
ErrorKind::State => f.write_str("state error")?,
|
||||
}
|
||||
|
||||
if let Some(source) = &self.source {
|
||||
@@ -95,8 +86,8 @@ impl From<std::io::Error> for ProverError {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<tls_client::Error> for ProverError {
|
||||
fn from(e: tls_client::Error) -> Self {
|
||||
impl From<tls_client_async::ConnectionError> for ProverError {
|
||||
fn from(e: tls_client_async::ConnectionError) -> Self {
|
||||
Self::new(ErrorKind::Io, e)
|
||||
}
|
||||
}
|
||||
@@ -124,9 +115,3 @@ impl From<EncodingError> for ProverError {
|
||||
Self::new(ErrorKind::Commit, e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ProverError> for std::io::Error {
|
||||
fn from(value: ProverError) -> Self {
|
||||
Self::other(value)
|
||||
}
|
||||
}
|
||||
|
||||
32
crates/tlsn/src/prover/future.rs
Normal file
32
crates/tlsn/src/prover/future.rs
Normal file
@@ -0,0 +1,32 @@
|
||||
//! This module collects futures which are used by the [Prover].
|
||||
|
||||
use super::{Prover, ProverControl, ProverError, state};
|
||||
use futures::Future;
|
||||
use std::pin::Pin;
|
||||
|
||||
/// Prover future which must be polled for the TLS connection to make progress.
|
||||
pub struct ProverFuture {
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub(crate) fut: Pin<
|
||||
Box<dyn Future<Output = Result<Prover<state::Committed>, ProverError>> + Send + 'static>,
|
||||
>,
|
||||
pub(crate) ctrl: ProverControl,
|
||||
}
|
||||
|
||||
impl ProverFuture {
|
||||
/// Returns a controller for the prover for advanced functionality.
|
||||
pub fn control(&self) -> ProverControl {
|
||||
self.ctrl.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl Future for ProverFuture {
|
||||
type Output = Result<Prover<state::Committed>, ProverError>;
|
||||
|
||||
fn poll(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Self::Output> {
|
||||
self.fut.as_mut().poll(cx)
|
||||
}
|
||||
}
|
||||
@@ -2,7 +2,7 @@ use mpc_tls::SessionKeys;
|
||||
use mpz_common::Context;
|
||||
use mpz_memory_core::binary::Binary;
|
||||
use mpz_vm_core::Vm;
|
||||
use rangeset::{RangeSet, UnionMut};
|
||||
use rangeset::set::RangeSet;
|
||||
use tlsn_core::{
|
||||
ProverOutput,
|
||||
config::prove::ProveConfig,
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures_plex::DuplexStream;
|
||||
use mpc_tls::{MpcTlsLeader, SessionKeys};
|
||||
use mpz_common::Context;
|
||||
use tlsn_core::{
|
||||
@@ -15,10 +14,6 @@ use tokio::sync::Mutex;
|
||||
use crate::{
|
||||
mpz::{ProverMpc, ProverZk},
|
||||
mux::{MuxControl, MuxFuture},
|
||||
prover::{
|
||||
ProverError,
|
||||
client::{TlsClient, TlsOutput},
|
||||
},
|
||||
};
|
||||
|
||||
/// Entry state
|
||||
@@ -29,7 +24,6 @@ opaque_debug::implement!(Initialized);
|
||||
/// State after the verifier has accepted the proposed TLS commitment protocol
|
||||
/// configuration and preprocessing has completed.
|
||||
pub struct CommitAccepted {
|
||||
pub(crate) mpc_duplex: DuplexStream,
|
||||
pub(crate) mux_ctrl: MuxControl,
|
||||
pub(crate) mux_fut: MuxFuture,
|
||||
pub(crate) mpc_tls: MpcTlsLeader,
|
||||
@@ -39,21 +33,8 @@ pub struct CommitAccepted {
|
||||
|
||||
opaque_debug::implement!(CommitAccepted);
|
||||
|
||||
/// State during the MPC-TLS connection.
|
||||
pub struct Connected {
|
||||
pub(crate) mpc_duplex: DuplexStream,
|
||||
pub(crate) mux_ctrl: MuxControl,
|
||||
pub(crate) mux_fut: MuxFuture,
|
||||
pub(crate) server_name: ServerName,
|
||||
pub(crate) tls_client: Box<dyn TlsClient<Error = ProverError> + Send>,
|
||||
pub(crate) output: Option<TlsOutput>,
|
||||
}
|
||||
|
||||
opaque_debug::implement!(Connected);
|
||||
|
||||
/// State after the TLS transcript has been committed.
|
||||
pub struct Committed {
|
||||
pub(crate) mpc_duplex: DuplexStream,
|
||||
pub(crate) mux_ctrl: MuxControl,
|
||||
pub(crate) mux_fut: MuxFuture,
|
||||
pub(crate) ctx: Context,
|
||||
@@ -71,13 +52,11 @@ pub trait ProverState: sealed::Sealed {}
|
||||
|
||||
impl ProverState for Initialized {}
|
||||
impl ProverState for CommitAccepted {}
|
||||
impl ProverState for Connected {}
|
||||
impl ProverState for Committed {}
|
||||
|
||||
mod sealed {
|
||||
pub trait Sealed {}
|
||||
impl Sealed for super::Initialized {}
|
||||
impl Sealed for super::CommitAccepted {}
|
||||
impl Sealed for super::Connected {}
|
||||
impl Sealed for super::Committed {}
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ use mpz_memory_core::{
|
||||
binary::{Binary, U8},
|
||||
};
|
||||
use mpz_vm_core::{Call, CallableExt, Vm};
|
||||
use rangeset::{Difference, RangeSet, Union};
|
||||
use rangeset::{iter::RangeIterator, ops::Set, set::RangeSet};
|
||||
use tlsn_core::transcript::Record;
|
||||
|
||||
use crate::transcript_internal::ReferenceMap;
|
||||
@@ -32,7 +32,7 @@ pub(crate) fn prove_plaintext<'a>(
|
||||
commit.clone()
|
||||
} else {
|
||||
// The plaintext is only partially revealed, so we need to authenticate in ZK.
|
||||
commit.union(reveal)
|
||||
commit.union(reveal).into_set()
|
||||
};
|
||||
|
||||
let plaintext_refs = alloc_plaintext(vm, &alloc_ranges)?;
|
||||
@@ -49,7 +49,7 @@ pub(crate) fn prove_plaintext<'a>(
|
||||
vm.commit(*slice).map_err(PlaintextAuthError::vm)?;
|
||||
}
|
||||
} else {
|
||||
let private = commit.difference(reveal);
|
||||
let private = commit.difference(reveal).into_set();
|
||||
for (_, slice) in plaintext_refs
|
||||
.index(&private)
|
||||
.expect("all ranges are allocated")
|
||||
@@ -98,7 +98,7 @@ pub(crate) fn verify_plaintext<'a>(
|
||||
commit.clone()
|
||||
} else {
|
||||
// The plaintext is only partially revealed, so we need to authenticate in ZK.
|
||||
commit.union(reveal)
|
||||
commit.union(reveal).into_set()
|
||||
};
|
||||
|
||||
let plaintext_refs = alloc_plaintext(vm, &alloc_ranges)?;
|
||||
@@ -123,7 +123,7 @@ pub(crate) fn verify_plaintext<'a>(
|
||||
ciphertext,
|
||||
})
|
||||
} else {
|
||||
let private = commit.difference(reveal);
|
||||
let private = commit.difference(reveal).into_set();
|
||||
for (_, slice) in plaintext_refs
|
||||
.index(&private)
|
||||
.expect("all ranges are allocated")
|
||||
@@ -175,15 +175,13 @@ fn alloc_plaintext(
|
||||
let plaintext = vm.alloc_vec::<U8>(len).map_err(PlaintextAuthError::vm)?;
|
||||
|
||||
let mut pos = 0;
|
||||
Ok(ReferenceMap::from_iter(ranges.iter_ranges().map(
|
||||
move |range| {
|
||||
let chunk = plaintext
|
||||
.get(pos..pos + range.len())
|
||||
.expect("length was checked");
|
||||
pos += range.len();
|
||||
(range.start, chunk)
|
||||
},
|
||||
)))
|
||||
Ok(ReferenceMap::from_iter(ranges.iter().map(move |range| {
|
||||
let chunk = plaintext
|
||||
.get(pos..pos + range.len())
|
||||
.expect("length was checked");
|
||||
pos += range.len();
|
||||
(range.start, chunk)
|
||||
})))
|
||||
}
|
||||
|
||||
fn alloc_ciphertext<'a>(
|
||||
@@ -212,15 +210,13 @@ fn alloc_ciphertext<'a>(
|
||||
let ciphertext: Vector<U8> = vm.call(call).map_err(PlaintextAuthError::vm)?;
|
||||
|
||||
let mut pos = 0;
|
||||
Ok(ReferenceMap::from_iter(ranges.iter_ranges().map(
|
||||
move |range| {
|
||||
let chunk = ciphertext
|
||||
.get(pos..pos + range.len())
|
||||
.expect("length was checked");
|
||||
pos += range.len();
|
||||
(range.start, chunk)
|
||||
},
|
||||
)))
|
||||
Ok(ReferenceMap::from_iter(ranges.iter().map(move |range| {
|
||||
let chunk = ciphertext
|
||||
.get(pos..pos + range.len())
|
||||
.expect("length was checked");
|
||||
pos += range.len();
|
||||
(range.start, chunk)
|
||||
})))
|
||||
}
|
||||
|
||||
fn alloc_keystream<'a>(
|
||||
@@ -233,7 +229,7 @@ fn alloc_keystream<'a>(
|
||||
let mut keystream = Vec::new();
|
||||
|
||||
let mut pos = 0;
|
||||
let mut range_iter = ranges.iter_ranges();
|
||||
let mut range_iter = ranges.iter();
|
||||
let mut current_range = range_iter.next();
|
||||
for record in records {
|
||||
let mut explicit_nonce = None;
|
||||
@@ -508,7 +504,7 @@ mod tests {
|
||||
for record in records {
|
||||
let mut record_keystream = vec![0u8; record.len];
|
||||
aes_ctr_apply_keystream(&key, &iv, &record.explicit_nonce, &mut record_keystream);
|
||||
for mut range in ranges.iter_ranges() {
|
||||
for mut range in ranges.iter() {
|
||||
range.start = range.start.max(pos);
|
||||
range.end = range.end.min(pos + record.len);
|
||||
if range.start < range.end {
|
||||
|
||||
@@ -9,7 +9,7 @@ use mpz_memory_core::{
|
||||
correlated::{Delta, Key, Mac},
|
||||
};
|
||||
use rand::Rng;
|
||||
use rangeset::RangeSet;
|
||||
use rangeset::set::RangeSet;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serio::{SinkExt, stream::IoStreamExt};
|
||||
use tlsn_core::{
|
||||
|
||||
@@ -9,7 +9,7 @@ use mpz_memory_core::{
|
||||
binary::{Binary, U8},
|
||||
};
|
||||
use mpz_vm_core::{Vm, VmError, prelude::*};
|
||||
use rangeset::RangeSet;
|
||||
use rangeset::set::RangeSet;
|
||||
use tlsn_core::{
|
||||
hash::{Blinder, Hash, HashAlgId, TypedHash},
|
||||
transcript::{
|
||||
@@ -155,7 +155,7 @@ fn hash_commit_inner(
|
||||
Direction::Received => &refs.recv,
|
||||
};
|
||||
|
||||
for range in idx.iter_ranges() {
|
||||
for range in idx.iter() {
|
||||
hasher.update(&refs.get(range).expect("plaintext refs are valid"));
|
||||
}
|
||||
|
||||
@@ -176,7 +176,7 @@ fn hash_commit_inner(
|
||||
Direction::Received => &refs.recv,
|
||||
};
|
||||
|
||||
for range in idx.iter_ranges() {
|
||||
for range in idx.iter() {
|
||||
hasher
|
||||
.update(vm, &refs.get(range).expect("plaintext refs are valid"))
|
||||
.map_err(HashCommitError::hasher)?;
|
||||
@@ -201,7 +201,7 @@ fn hash_commit_inner(
|
||||
Direction::Received => &refs.recv,
|
||||
};
|
||||
|
||||
for range in idx.iter_ranges() {
|
||||
for range in idx.iter() {
|
||||
hasher
|
||||
.update(vm, &refs.get(range).expect("plaintext refs are valid"))
|
||||
.map_err(HashCommitError::hasher)?;
|
||||
|
||||
@@ -10,17 +10,16 @@ pub use error::VerifierError;
|
||||
pub use tlsn_core::{VerifierOutput, webpki::ServerCertVerifier};
|
||||
|
||||
use crate::{
|
||||
BUF_CAP, Role,
|
||||
Role,
|
||||
context::build_mt_context,
|
||||
mpz::{VerifierDeps, build_verifier_deps, translate_keys},
|
||||
msg::{ProveRequestMsg, Response, TlsCommitRequestMsg},
|
||||
mux::attach_mux,
|
||||
tag::verify_tags,
|
||||
};
|
||||
use futures::TryFutureExt;
|
||||
use futures::{AsyncRead, AsyncWrite, TryFutureExt};
|
||||
use mpz_vm_core::prelude::*;
|
||||
use serio::{SinkExt, stream::IoStreamExt};
|
||||
use std::io::{Read, Write};
|
||||
use tlsn_core::{
|
||||
config::{
|
||||
prove::ProveRequest,
|
||||
@@ -69,10 +68,11 @@ impl Verifier<state::Initialized> {
|
||||
///
|
||||
/// * `socket` - The socket to the prover.
|
||||
#[instrument(parent = &self.span, level = "info", skip_all, err)]
|
||||
pub async fn commit(self) -> Result<Verifier<state::CommitStart>, VerifierError> {
|
||||
let (duplex_a, duplex_b) = futures_plex::duplex(BUF_CAP);
|
||||
|
||||
let (mut mux_fut, mux_ctrl) = attach_mux(duplex_b, Role::Verifier);
|
||||
pub async fn commit<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
|
||||
self,
|
||||
socket: S,
|
||||
) -> Result<Verifier<state::CommitStart>, VerifierError> {
|
||||
let (mut mux_fut, mux_ctrl) = attach_mux(socket, Role::Verifier);
|
||||
let mut mt = build_mt_context(mux_ctrl.clone());
|
||||
let mut ctx = mux_fut.poll_with(mt.new_context()).await?;
|
||||
|
||||
@@ -102,7 +102,6 @@ impl Verifier<state::Initialized> {
|
||||
config: self.config,
|
||||
span: self.span,
|
||||
state: state::CommitStart {
|
||||
mpc_duplex: duplex_a,
|
||||
mux_ctrl,
|
||||
mux_fut,
|
||||
ctx,
|
||||
@@ -122,7 +121,6 @@ impl Verifier<state::CommitStart> {
|
||||
#[instrument(parent = &self.span, level = "info", skip_all, err)]
|
||||
pub async fn accept(self) -> Result<Verifier<state::CommitAccepted>, VerifierError> {
|
||||
let state::CommitStart {
|
||||
mpc_duplex,
|
||||
mux_ctrl,
|
||||
mut mux_fut,
|
||||
mut ctx,
|
||||
@@ -153,7 +151,6 @@ impl Verifier<state::CommitStart> {
|
||||
config: self.config,
|
||||
span: self.span,
|
||||
state: state::CommitAccepted {
|
||||
mpc_duplex,
|
||||
mux_ctrl,
|
||||
mux_fut,
|
||||
mpc_tls,
|
||||
@@ -185,24 +182,6 @@ impl Verifier<state::CommitStart> {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Writes bytes for the prover into a buffer.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `buf` - The buffer.
|
||||
pub fn write_mpc(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
|
||||
self.state.mpc_duplex.read(buf)
|
||||
}
|
||||
|
||||
/// Reads bytes for the verifier from a buffer.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `buf` - The buffer.
|
||||
pub fn read_mpc(&mut self, buf: &[u8]) -> std::io::Result<usize> {
|
||||
self.state.mpc_duplex.write(buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl Verifier<state::CommitAccepted> {
|
||||
@@ -210,7 +189,6 @@ impl Verifier<state::CommitAccepted> {
|
||||
#[instrument(parent = &self.span, level = "info", skip_all, err)]
|
||||
pub async fn run(self) -> Result<Verifier<state::Committed>, VerifierError> {
|
||||
let state::CommitAccepted {
|
||||
mpc_duplex,
|
||||
mux_ctrl,
|
||||
mut mux_fut,
|
||||
mpc_tls,
|
||||
@@ -267,7 +245,6 @@ impl Verifier<state::CommitAccepted> {
|
||||
config: self.config,
|
||||
span: self.span,
|
||||
state: state::Committed {
|
||||
mpc_duplex,
|
||||
mux_ctrl,
|
||||
mux_fut,
|
||||
ctx,
|
||||
@@ -277,24 +254,6 @@ impl Verifier<state::CommitAccepted> {
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
/// Writes bytes for the prover into a buffer.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `buf` - The buffer.
|
||||
pub fn write_mpc(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
|
||||
self.state.mpc_duplex.read(buf)
|
||||
}
|
||||
|
||||
/// Reads bytes for the verifier from a buffer.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `buf` - The buffer.
|
||||
pub fn read_mpc(&mut self, buf: &[u8]) -> std::io::Result<usize> {
|
||||
self.state.mpc_duplex.write(buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl Verifier<state::Committed> {
|
||||
@@ -307,7 +266,6 @@ impl Verifier<state::Committed> {
|
||||
#[instrument(parent = &self.span, level = "info", skip_all, err)]
|
||||
pub async fn verify(self) -> Result<Verifier<state::Verify>, VerifierError> {
|
||||
let state::Committed {
|
||||
mpc_duplex,
|
||||
mux_ctrl,
|
||||
mut mux_fut,
|
||||
mut ctx,
|
||||
@@ -328,7 +286,6 @@ impl Verifier<state::Committed> {
|
||||
config: self.config,
|
||||
span: self.span,
|
||||
state: state::Verify {
|
||||
mpc_duplex,
|
||||
mux_ctrl,
|
||||
mux_fut,
|
||||
ctx,
|
||||
@@ -342,39 +299,17 @@ impl Verifier<state::Committed> {
|
||||
})
|
||||
}
|
||||
|
||||
/// Writes bytes for the prover into a buffer.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `buf` - The buffer.
|
||||
pub fn write_mpc(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
|
||||
self.state.mpc_duplex.read(buf)
|
||||
}
|
||||
|
||||
/// Reads bytes for the verifier from a buffer.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `buf` - The buffer.
|
||||
pub fn read_mpc(&mut self, buf: &[u8]) -> std::io::Result<usize> {
|
||||
self.state.mpc_duplex.write(buf)
|
||||
}
|
||||
|
||||
/// Closes the connection with the prover.
|
||||
#[instrument(parent = &self.span, level = "info", skip_all, err)]
|
||||
pub async fn close(self) -> Result<(), VerifierError> {
|
||||
let state::Committed {
|
||||
mut mpc_duplex,
|
||||
mux_ctrl,
|
||||
mux_fut,
|
||||
..
|
||||
mux_ctrl, mux_fut, ..
|
||||
} = self.state;
|
||||
|
||||
// Wait for the prover to correctly close the connection.
|
||||
if !mux_fut.is_complete() {
|
||||
mux_ctrl.close();
|
||||
mux_fut.await?;
|
||||
futures::AsyncWriteExt::close(&mut mpc_duplex).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -392,7 +327,6 @@ impl Verifier<state::Verify> {
|
||||
self,
|
||||
) -> Result<(VerifierOutput, Verifier<state::Committed>), VerifierError> {
|
||||
let state::Verify {
|
||||
mpc_duplex,
|
||||
mux_ctrl,
|
||||
mut mux_fut,
|
||||
mut ctx,
|
||||
@@ -428,7 +362,6 @@ impl Verifier<state::Verify> {
|
||||
config: self.config,
|
||||
span: self.span,
|
||||
state: state::Committed {
|
||||
mpc_duplex,
|
||||
mux_ctrl,
|
||||
mux_fut,
|
||||
ctx,
|
||||
@@ -446,7 +379,6 @@ impl Verifier<state::Verify> {
|
||||
msg: Option<&str>,
|
||||
) -> Result<Verifier<state::Committed>, VerifierError> {
|
||||
let state::Verify {
|
||||
mpc_duplex,
|
||||
mux_ctrl,
|
||||
mut mux_fut,
|
||||
mut ctx,
|
||||
@@ -464,7 +396,6 @@ impl Verifier<state::Verify> {
|
||||
config: self.config,
|
||||
span: self.span,
|
||||
state: state::Committed {
|
||||
mpc_duplex,
|
||||
mux_ctrl,
|
||||
mux_fut,
|
||||
ctx,
|
||||
@@ -474,22 +405,4 @@ impl Verifier<state::Verify> {
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
/// Writes bytes for the prover into a buffer.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `buf` - The buffer.
|
||||
pub fn write_mpc(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
|
||||
self.state.mpc_duplex.read(buf)
|
||||
}
|
||||
|
||||
/// Reads bytes for the verifier from a buffer.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `buf` - The buffer.
|
||||
pub fn read_mpc(&mut self, buf: &[u8]) -> std::io::Result<usize> {
|
||||
self.state.mpc_duplex.write(buf)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::mux::{MuxControl, MuxFuture};
|
||||
use futures_plex::DuplexStream;
|
||||
use mpc_tls::{MpcTlsFollower, SessionKeys};
|
||||
use mpz_common::Context;
|
||||
use tlsn_core::{
|
||||
@@ -26,7 +25,6 @@ opaque_debug::implement!(Initialized);
|
||||
|
||||
/// State after receiving protocol configuration from the prover.
|
||||
pub struct CommitStart {
|
||||
pub(crate) mpc_duplex: DuplexStream,
|
||||
pub(crate) mux_ctrl: MuxControl,
|
||||
pub(crate) mux_fut: MuxFuture,
|
||||
pub(crate) ctx: Context,
|
||||
@@ -38,7 +36,6 @@ opaque_debug::implement!(CommitStart);
|
||||
/// State after accepting the proposed TLS commitment protocol configuration and
|
||||
/// performing preprocessing.
|
||||
pub struct CommitAccepted {
|
||||
pub(crate) mpc_duplex: DuplexStream,
|
||||
pub(crate) mux_ctrl: MuxControl,
|
||||
pub(crate) mux_fut: MuxFuture,
|
||||
pub(crate) mpc_tls: MpcTlsFollower,
|
||||
@@ -50,7 +47,6 @@ opaque_debug::implement!(CommitAccepted);
|
||||
|
||||
/// State after the TLS transcript has been committed.
|
||||
pub struct Committed {
|
||||
pub(crate) mpc_duplex: DuplexStream,
|
||||
pub(crate) mux_ctrl: MuxControl,
|
||||
pub(crate) mux_fut: MuxFuture,
|
||||
pub(crate) ctx: Context,
|
||||
@@ -63,7 +59,6 @@ opaque_debug::implement!(Committed);
|
||||
|
||||
/// State after receiving a proving request.
|
||||
pub struct Verify {
|
||||
pub(crate) mpc_duplex: DuplexStream,
|
||||
pub(crate) mux_ctrl: MuxControl,
|
||||
pub(crate) mux_fut: MuxFuture,
|
||||
pub(crate) ctx: Context,
|
||||
|
||||
@@ -2,7 +2,7 @@ use mpc_tls::SessionKeys;
|
||||
use mpz_common::Context;
|
||||
use mpz_memory_core::binary::Binary;
|
||||
use mpz_vm_core::Vm;
|
||||
use rangeset::{RangeSet, UnionMut};
|
||||
use rangeset::set::RangeSet;
|
||||
use tlsn_core::{
|
||||
VerifierOutput,
|
||||
config::prove::ProveRequest,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use futures::{AsyncReadExt, AsyncWriteExt};
|
||||
use rangeset::RangeSet;
|
||||
use rangeset::set::RangeSet;
|
||||
use tlsn::{
|
||||
config::{
|
||||
prove::ProveConfig,
|
||||
@@ -51,19 +51,11 @@ async fn test() {
|
||||
assert_eq!(server_name.as_str(), SERVER_DOMAIN);
|
||||
assert!(!partial_transcript.is_complete());
|
||||
assert_eq!(
|
||||
partial_transcript
|
||||
.sent_authed()
|
||||
.iter_ranges()
|
||||
.next()
|
||||
.unwrap(),
|
||||
partial_transcript.sent_authed().iter().next().unwrap(),
|
||||
0..10
|
||||
);
|
||||
assert_eq!(
|
||||
partial_transcript
|
||||
.received_authed()
|
||||
.iter_ranges()
|
||||
.next()
|
||||
.unwrap(),
|
||||
partial_transcript.received_authed().iter().next().unwrap(),
|
||||
0..10
|
||||
);
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ no-bundler = ["web-spawn/no-bundler"]
|
||||
tlsn-core = { workspace = true }
|
||||
tlsn = { workspace = true, features = ["web", "mozilla-certs"] }
|
||||
tlsn-server-fixture-certs = { workspace = true }
|
||||
tlsn-tls-client-async = { workspace = true }
|
||||
tlsn-tls-core = { workspace = true }
|
||||
|
||||
bincode = { workspace = true }
|
||||
|
||||
@@ -151,9 +151,9 @@ impl From<tlsn::transcript::PartialTranscript> for PartialTranscript {
|
||||
fn from(value: tlsn::transcript::PartialTranscript) -> Self {
|
||||
Self {
|
||||
sent: value.sent_unsafe().to_vec(),
|
||||
sent_authed: value.sent_authed().iter_ranges().collect(),
|
||||
sent_authed: value.sent_authed().iter().collect(),
|
||||
recv: value.received_unsafe().to_vec(),
|
||||
recv_authed: value.received_authed().iter_ranges().collect(),
|
||||
recv_authed: value.received_authed().iter().collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user