Compare commits

..

4 Commits

Author SHA1 Message Date
th4s
4727df6fd4 WIP: removing actor... 2025-12-08 14:26:48 +01:00
th4s
a0c7d469f6 feat(tlsn): implement sans-io api for prover and verifier 2025-12-08 12:23:05 +01:00
th4s
0bf4a857b9 refactor: sans-io TLS IO (#1036)
* refactor: add tls-client trait

* cleanup

* fix clippy

* add start state

* assert that mpc future is pending
2025-12-01 18:55:10 +01:00
th4s
4d449b2a1d refactor(tls-client): make tls_client compatible with a synchronous API (#1027)
* refactor(tls-client): make `write_plaintext` sync and remove async api

* restore `complete_io`

* do not potentially block in `write_all_plaintext`
2025-12-01 18:55:10 +01:00
75 changed files with 3521 additions and 4460 deletions

View File

@@ -21,7 +21,7 @@ env:
# - https://github.com/privacy-ethereum/mpz/issues/178
# 32 seems to be big enough for the foreseeable future
RAYON_NUM_THREADS: 32
RUST_VERSION: 1.92.0
RUST_VERSION: 1.91.1
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
jobs:

1099
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -13,7 +13,6 @@ members = [
"crates/server-fixture/server",
"crates/tls/backend",
"crates/tls/client",
"crates/tls/client-async",
"crates/tls/core",
"crates/mpc-tls",
"crates/tls/server-fixture",
@@ -57,7 +56,6 @@ tlsn-server-fixture = { path = "crates/server-fixture/server" }
tlsn-server-fixture-certs = { path = "crates/server-fixture/certs" }
tlsn-tls-backend = { path = "crates/tls/backend" }
tlsn-tls-client = { path = "crates/tls/client" }
tlsn-tls-client-async = { path = "crates/tls/client-async" }
tlsn-tls-core = { path = "crates/tls/core" }
tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6168663" }
tlsn-harness-core = { path = "crates/harness/core" }
@@ -66,27 +64,28 @@ tlsn-harness-runner = { path = "crates/harness/runner" }
tlsn-wasm = { path = "crates/wasm" }
tlsn = { path = "crates/tlsn" }
mpz-circuits = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-circuits-data = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-memory-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-common = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-vm-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-garble = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-garble-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-ole = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-ot = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-share-conversion = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-fields = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-ideal-vm = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-circuits = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-circuits-data = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-memory-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-common = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-vm-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-garble = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-garble-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-ole = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-ot = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-share-conversion = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-fields = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-ideal-vm = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
rangeset = { version = "0.4" }
futures-plex = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "0b46dc0" }
rangeset = { version = "0.2" }
serio = { version = "0.2" }
spansy = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6f1a934" }
spansy = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6168663" }
uid-mux = { version = "0.2" }
websocket-relay = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6f1a934" }
websocket-relay = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6168663" }
aead = { version = "0.4" }
aes = { version = "0.8" }

View File

@@ -27,7 +27,6 @@ alloy-primitives = { version = "1.3.1", default-features = false }
alloy-signer = { version = "1.0", default-features = false }
alloy-signer-local = { version = "1.0", default-features = false }
rand06-compat = { workspace = true }
rangeset = { workspace = true }
rstest = { workspace = true }
tlsn-core = { workspace = true, features = ["fixtures"] }
tlsn-data-fixtures = { workspace = true }

View File

@@ -5,7 +5,7 @@ use rand::{Rng, rng};
use tlsn_core::{
connection::{ConnectionInfo, ServerEphemKey},
hash::HashAlgId,
transcript::TranscriptCommitment,
transcript::{TranscriptCommitment, encoding::EncoderSecret},
};
use crate::{
@@ -25,6 +25,7 @@ pub struct Sign {
connection_info: Option<ConnectionInfo>,
server_ephemeral_key: Option<ServerEphemKey>,
cert_commitment: ServerCertCommitment,
encoder_secret: Option<EncoderSecret>,
extensions: Vec<Extension>,
transcript_commitments: Vec<TranscriptCommitment>,
}
@@ -86,6 +87,7 @@ impl<'a> AttestationBuilder<'a, Accept> {
connection_info: None,
server_ephemeral_key: None,
cert_commitment,
encoder_secret: None,
transcript_commitments: Vec::new(),
extensions,
},
@@ -106,6 +108,12 @@ impl AttestationBuilder<'_, Sign> {
self
}
/// Sets the secret for encoding commitments.
pub fn encoder_secret(&mut self, secret: EncoderSecret) -> &mut Self {
self.state.encoder_secret = Some(secret);
self
}
/// Adds an extension to the attestation.
pub fn extension(&mut self, extension: Extension) -> &mut Self {
self.state.extensions.push(extension);
@@ -129,6 +137,7 @@ impl AttestationBuilder<'_, Sign> {
connection_info,
server_ephemeral_key,
cert_commitment,
encoder_secret,
extensions,
transcript_commitments,
} = self.state;
@@ -159,6 +168,7 @@ impl AttestationBuilder<'_, Sign> {
AttestationBuilderError::new(ErrorKind::Field, "handshake data was not set")
})?),
cert_commitment: field_id.next(cert_commitment),
encoder_secret: encoder_secret.map(|secret| field_id.next(secret)),
extensions: extensions
.into_iter()
.map(|extension| field_id.next(extension))
@@ -243,7 +253,8 @@ mod test {
use rstest::{fixture, rstest};
use tlsn_core::{
connection::{CertBinding, CertBindingV1_2},
fixtures::ConnectionFixture,
fixtures::{ConnectionFixture, encoding_provider},
hash::Blake3,
transcript::Transcript,
};
use tlsn_data_fixtures::http::{request::GET_WITH_HEADER, response::OK_JSON};
@@ -274,7 +285,13 @@ mod test {
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
let connection = ConnectionFixture::tlsnotary(transcript.length());
let RequestFixture { request, .. } = request_fixture(transcript, connection, Vec::new());
let RequestFixture { request, .. } = request_fixture(
transcript,
encoding_provider(GET_WITH_HEADER, OK_JSON),
connection,
Blake3::default(),
Vec::new(),
);
let attestation_config = AttestationConfig::builder()
.supported_signature_algs([SignatureAlgId::SECP256R1])
@@ -293,7 +310,13 @@ mod test {
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
let connection = ConnectionFixture::tlsnotary(transcript.length());
let RequestFixture { request, .. } = request_fixture(transcript, connection, Vec::new());
let RequestFixture { request, .. } = request_fixture(
transcript,
encoding_provider(GET_WITH_HEADER, OK_JSON),
connection,
Blake3::default(),
Vec::new(),
);
let attestation_config = AttestationConfig::builder()
.supported_signature_algs([SignatureAlgId::SECP256K1])
@@ -313,7 +336,13 @@ mod test {
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
let connection = ConnectionFixture::tlsnotary(transcript.length());
let RequestFixture { request, .. } = request_fixture(transcript, connection, Vec::new());
let RequestFixture { request, .. } = request_fixture(
transcript,
encoding_provider(GET_WITH_HEADER, OK_JSON),
connection,
Blake3::default(),
Vec::new(),
);
let attestation_builder = Attestation::builder(attestation_config)
.accept_request(request)
@@ -334,8 +363,13 @@ mod test {
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
let connection = ConnectionFixture::tlsnotary(transcript.length());
let RequestFixture { request, .. } =
request_fixture(transcript, connection.clone(), Vec::new());
let RequestFixture { request, .. } = request_fixture(
transcript,
encoding_provider(GET_WITH_HEADER, OK_JSON),
connection.clone(),
Blake3::default(),
Vec::new(),
);
let mut attestation_builder = Attestation::builder(attestation_config)
.accept_request(request)
@@ -359,8 +393,13 @@ mod test {
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
let connection = ConnectionFixture::tlsnotary(transcript.length());
let RequestFixture { request, .. } =
request_fixture(transcript, connection.clone(), Vec::new());
let RequestFixture { request, .. } = request_fixture(
transcript,
encoding_provider(GET_WITH_HEADER, OK_JSON),
connection.clone(),
Blake3::default(),
Vec::new(),
);
let mut attestation_builder = Attestation::builder(attestation_config)
.accept_request(request)
@@ -393,7 +432,9 @@ mod test {
let RequestFixture { request, .. } = request_fixture(
transcript,
encoding_provider(GET_WITH_HEADER, OK_JSON),
connection.clone(),
Blake3::default(),
vec![Extension {
id: b"foo".to_vec(),
value: b"bar".to_vec(),
@@ -420,7 +461,9 @@ mod test {
let RequestFixture { request, .. } = request_fixture(
transcript,
encoding_provider(GET_WITH_HEADER, OK_JSON),
connection.clone(),
Blake3::default(),
vec![Extension {
id: b"foo".to_vec(),
value: b"bar".to_vec(),

View File

@@ -2,7 +2,11 @@
use tlsn_core::{
connection::{CertBinding, CertBindingV1_2},
fixtures::ConnectionFixture,
transcript::{Transcript, TranscriptCommitConfigBuilder, TranscriptCommitment},
hash::HashAlgorithm,
transcript::{
Transcript, TranscriptCommitConfigBuilder, TranscriptCommitment,
encoding::{EncodingProvider, EncodingTree},
},
};
use crate::{
@@ -17,13 +21,16 @@ use crate::{
/// A Request fixture used for testing.
#[allow(missing_docs)]
pub struct RequestFixture {
pub encoding_tree: EncodingTree,
pub request: Request,
}
/// Returns a request fixture for testing.
pub fn request_fixture(
transcript: Transcript,
encodings_provider: impl EncodingProvider,
connection: ConnectionFixture,
encoding_hasher: impl HashAlgorithm,
extensions: Vec<Extension>,
) -> RequestFixture {
let provider = CryptoProvider::default();
@@ -43,9 +50,15 @@ pub fn request_fixture(
.unwrap();
let transcripts_commitment_config = transcript_commitment_builder.build().unwrap();
let mut builder = RequestConfig::builder();
// Prover constructs encoding tree.
let encoding_tree = EncodingTree::new(
&encoding_hasher,
transcripts_commitment_config.iter_encoding(),
&encodings_provider,
)
.unwrap();
builder.transcript_commit(transcripts_commitment_config);
let mut builder = RequestConfig::builder();
for extension in extensions {
builder.extension(extension);
@@ -61,7 +74,10 @@ pub fn request_fixture(
let (request, _) = request_builder.build(&provider).unwrap();
RequestFixture { request }
RequestFixture {
encoding_tree,
request,
}
}
/// Returns an attestation fixture for testing.

View File

@@ -79,6 +79,8 @@
//!
//! // Specify all the transcript commitments we want to make.
//! builder
//! // Use BLAKE3 for encoding commitments.
//! .encoding_hash_alg(HashAlgId::BLAKE3)
//! // Commit to all sent data.
//! .commit_sent(&(0..sent_len))?
//! // Commit to the first 10 bytes of sent data.
@@ -127,7 +129,7 @@
//!
//! ```no_run
//! # use tlsn_attestation::{Attestation, CryptoProvider, Secrets, presentation::Presentation};
//! # use tlsn_core::transcript::Direction;
//! # use tlsn_core::transcript::{TranscriptCommitmentKind, Direction};
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! # let attestation: Attestation = unimplemented!();
//! # let secrets: Secrets = unimplemented!();
@@ -138,6 +140,8 @@
//! let mut builder = secrets.transcript_proof_builder();
//!
//! builder
//! // Use transcript encoding commitments.
//! .commitment_kinds(&[TranscriptCommitmentKind::Encoding])
//! // Disclose the first 10 bytes of the sent data.
//! .reveal(&(0..10), Direction::Sent)?
//! // Disclose all of the received data.
@@ -215,7 +219,7 @@ use tlsn_core::{
connection::{ConnectionInfo, ServerEphemKey},
hash::{Hash, HashAlgorithm, TypedHash},
merkle::MerkleTree,
transcript::TranscriptCommitment,
transcript::{TranscriptCommitment, encoding::EncoderSecret},
};
use crate::{
@@ -297,6 +301,8 @@ pub enum FieldKind {
ServerEphemKey = 0x02,
/// Server identity commitment.
ServerIdentityCommitment = 0x03,
/// Encoding commitment.
EncodingCommitment = 0x04,
/// Plaintext hash commitment.
PlaintextHash = 0x05,
}
@@ -321,6 +327,7 @@ pub struct Body {
connection_info: Field<ConnectionInfo>,
server_ephemeral_key: Field<ServerEphemKey>,
cert_commitment: Field<ServerCertCommitment>,
encoder_secret: Option<Field<EncoderSecret>>,
extensions: Vec<Field<Extension>>,
transcript_commitments: Vec<Field<TranscriptCommitment>>,
}
@@ -366,6 +373,7 @@ impl Body {
connection_info: conn_info,
server_ephemeral_key,
cert_commitment,
encoder_secret,
extensions,
transcript_commitments,
} = self;
@@ -383,6 +391,13 @@ impl Body {
),
];
if let Some(encoder_secret) = encoder_secret {
fields.push((
encoder_secret.id,
hasher.hash_separated(&encoder_secret.data),
));
}
for field in extensions.iter() {
fields.push((field.id, hasher.hash_separated(&field.data)));
}

View File

@@ -91,6 +91,11 @@ impl Presentation {
transcript.verify_with_provider(
&provider.hash,
&attestation.body.connection_info().transcript_length,
attestation
.body
.encoder_secret
.as_ref()
.map(|field| &field.data),
attestation.body.transcript_commitments(),
)
})

View File

@@ -144,7 +144,9 @@ impl std::fmt::Display for ErrorKind {
#[cfg(test)]
mod test {
use tlsn_core::{
connection::TranscriptLength, fixtures::ConnectionFixture, hash::HashAlgId,
connection::TranscriptLength,
fixtures::{ConnectionFixture, encoding_provider},
hash::{Blake3, HashAlgId},
transcript::Transcript,
};
use tlsn_data_fixtures::http::{request::GET_WITH_HEADER, response::OK_JSON};
@@ -162,8 +164,13 @@ mod test {
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
let connection = ConnectionFixture::tlsnotary(transcript.length());
let RequestFixture { request, .. } =
request_fixture(transcript, connection.clone(), Vec::new());
let RequestFixture { request, .. } = request_fixture(
transcript,
encoding_provider(GET_WITH_HEADER, OK_JSON),
connection.clone(),
Blake3::default(),
Vec::new(),
);
let attestation =
attestation_fixture(request.clone(), connection, SignatureAlgId::SECP256K1, &[]);
@@ -178,8 +185,13 @@ mod test {
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
let connection = ConnectionFixture::tlsnotary(transcript.length());
let RequestFixture { mut request, .. } =
request_fixture(transcript, connection.clone(), Vec::new());
let RequestFixture { mut request, .. } = request_fixture(
transcript,
encoding_provider(GET_WITH_HEADER, OK_JSON),
connection.clone(),
Blake3::default(),
Vec::new(),
);
let attestation =
attestation_fixture(request.clone(), connection, SignatureAlgId::SECP256K1, &[]);
@@ -197,8 +209,13 @@ mod test {
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
let connection = ConnectionFixture::tlsnotary(transcript.length());
let RequestFixture { mut request, .. } =
request_fixture(transcript, connection.clone(), Vec::new());
let RequestFixture { mut request, .. } = request_fixture(
transcript,
encoding_provider(GET_WITH_HEADER, OK_JSON),
connection.clone(),
Blake3::default(),
Vec::new(),
);
let attestation =
attestation_fixture(request.clone(), connection, SignatureAlgId::SECP256K1, &[]);
@@ -216,8 +233,13 @@ mod test {
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
let connection = ConnectionFixture::tlsnotary(transcript.length());
let RequestFixture { mut request, .. } =
request_fixture(transcript, connection.clone(), Vec::new());
let RequestFixture { mut request, .. } = request_fixture(
transcript,
encoding_provider(GET_WITH_HEADER, OK_JSON),
connection.clone(),
Blake3::default(),
Vec::new(),
);
let attestation =
attestation_fixture(request.clone(), connection, SignatureAlgId::SECP256K1, &[]);
@@ -243,8 +265,13 @@ mod test {
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
let connection = ConnectionFixture::tlsnotary(transcript.length());
let RequestFixture { request, .. } =
request_fixture(transcript, connection.clone(), Vec::new());
let RequestFixture { request, .. } = request_fixture(
transcript,
encoding_provider(GET_WITH_HEADER, OK_JSON),
connection.clone(),
Blake3::default(),
Vec::new(),
);
let mut attestation =
attestation_fixture(request.clone(), connection, SignatureAlgId::SECP256K1, &[]);
@@ -262,8 +289,13 @@ mod test {
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
let connection = ConnectionFixture::tlsnotary(transcript.length());
let RequestFixture { request, .. } =
request_fixture(transcript, connection.clone(), Vec::new());
let RequestFixture { request, .. } = request_fixture(
transcript,
encoding_provider(GET_WITH_HEADER, OK_JSON),
connection.clone(),
Blake3::default(),
Vec::new(),
);
let attestation =
attestation_fixture(request.clone(), connection, SignatureAlgId::SECP256K1, &[]);

View File

@@ -49,4 +49,6 @@ impl_domain_separator!(tlsn_core::connection::ConnectionInfo);
impl_domain_separator!(tlsn_core::connection::CertBinding);
impl_domain_separator!(tlsn_core::transcript::TranscriptCommitment);
impl_domain_separator!(tlsn_core::transcript::TranscriptSecret);
impl_domain_separator!(tlsn_core::transcript::encoding::EncoderSecret);
impl_domain_separator!(tlsn_core::transcript::encoding::EncodingCommitment);
impl_domain_separator!(tlsn_core::transcript::hash::PlaintextHash);

View File

@@ -1,5 +1,3 @@
use rand::{Rng, SeedableRng, rngs::StdRng};
use rangeset::set::RangeSet;
use tlsn_attestation::{
Attestation, AttestationConfig, CryptoProvider,
presentation::PresentationOutput,
@@ -8,11 +6,12 @@ use tlsn_attestation::{
};
use tlsn_core::{
connection::{CertBinding, CertBindingV1_2},
fixtures::ConnectionFixture,
hash::{Blake3, Blinder, HashAlgId},
fixtures::{self, ConnectionFixture, encoder_secret},
hash::Blake3,
transcript::{
Direction, Transcript, TranscriptCommitment, TranscriptSecret,
hash::{PlaintextHash, PlaintextHashSecret, hash_plaintext},
Direction, Transcript, TranscriptCommitConfigBuilder, TranscriptCommitment,
TranscriptSecret,
encoding::{EncodingCommitment, EncodingTree},
},
};
use tlsn_data_fixtures::http::{request::GET_WITH_HEADER, response::OK_JSON};
@@ -20,7 +19,6 @@ use tlsn_data_fixtures::http::{request::GET_WITH_HEADER, response::OK_JSON};
/// Tests that the attestation protocol and verification work end-to-end
#[test]
fn test_api() {
let mut rng = StdRng::seed_from_u64(0);
let mut provider = CryptoProvider::default();
// Configure signer for Notary
@@ -28,6 +26,8 @@ fn test_api() {
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
let (sent_len, recv_len) = transcript.len();
// Plaintext encodings which the Prover obtained from GC evaluation
let encodings_provider = fixtures::encoding_provider(GET_WITH_HEADER, OK_JSON);
// At the end of the TLS connection the Prover holds the:
let ConnectionFixture {
@@ -44,38 +44,26 @@ fn test_api() {
unreachable!()
};
// Create hash commitments
let hasher = Blake3::default();
let sent_blinder: Blinder = rng.random();
let recv_blinder: Blinder = rng.random();
// Prover specifies the ranges it wants to commit to.
let mut transcript_commitment_builder = TranscriptCommitConfigBuilder::new(&transcript);
transcript_commitment_builder
.commit_sent(&(0..sent_len))
.unwrap()
.commit_recv(&(0..recv_len))
.unwrap();
let sent_idx = RangeSet::from(0..sent_len);
let recv_idx = RangeSet::from(0..recv_len);
let transcripts_commitment_config = transcript_commitment_builder.build().unwrap();
let sent_hash_commitment = PlaintextHash {
direction: Direction::Sent,
idx: sent_idx.clone(),
hash: hash_plaintext(&hasher, transcript.sent(), &sent_blinder),
};
// Prover constructs encoding tree.
let encoding_tree = EncodingTree::new(
&Blake3::default(),
transcripts_commitment_config.iter_encoding(),
&encodings_provider,
)
.unwrap();
let recv_hash_commitment = PlaintextHash {
direction: Direction::Received,
idx: recv_idx.clone(),
hash: hash_plaintext(&hasher, transcript.received(), &recv_blinder),
};
let sent_hash_secret = PlaintextHashSecret {
direction: Direction::Sent,
idx: sent_idx,
alg: HashAlgId::BLAKE3,
blinder: sent_blinder,
};
let recv_hash_secret = PlaintextHashSecret {
direction: Direction::Received,
idx: recv_idx,
alg: HashAlgId::BLAKE3,
blinder: recv_blinder,
let encoding_commitment = EncodingCommitment {
root: encoding_tree.root(),
};
let request_config = RequestConfig::default();
@@ -86,14 +74,8 @@ fn test_api() {
.handshake_data(server_cert_data)
.transcript(transcript)
.transcript_commitments(
vec![
TranscriptSecret::Hash(sent_hash_secret),
TranscriptSecret::Hash(recv_hash_secret),
],
vec![
TranscriptCommitment::Hash(sent_hash_commitment.clone()),
TranscriptCommitment::Hash(recv_hash_commitment.clone()),
],
vec![TranscriptSecret::Encoding(encoding_tree)],
vec![TranscriptCommitment::Encoding(encoding_commitment.clone())],
);
let (request, secrets) = request_builder.build(&provider).unwrap();
@@ -113,10 +95,8 @@ fn test_api() {
.connection_info(connection_info.clone())
// Server key Notary received during handshake
.server_ephemeral_key(server_ephemeral_key)
.transcript_commitments(vec![
TranscriptCommitment::Hash(sent_hash_commitment),
TranscriptCommitment::Hash(recv_hash_commitment),
]);
.encoder_secret(encoder_secret())
.transcript_commitments(vec![TranscriptCommitment::Encoding(encoding_commitment)]);
let attestation = attestation_builder.build(&provider).unwrap();

View File

@@ -15,7 +15,7 @@ use mpz_vm_core::{
memory::{binary::Binary, DecodeFuture, Memory, Repr, Slice, View},
Call, Callable, Execute, Vm, VmError,
};
use rangeset::{ops::Set, set::RangeSet};
use rangeset::{Difference, RangeSet, UnionMut};
use tokio::sync::{Mutex, MutexGuard, OwnedMutexGuard};
type Error = DeapError;
@@ -210,12 +210,10 @@ where
}
fn commit_raw(&mut self, slice: Slice) -> Result<(), VmError> {
let slice_range = slice.to_range();
// Follower's private inputs are not committed in the ZK VM until finalization.
let input_minus_follower = slice_range.difference(&self.follower_input_ranges);
let input_minus_follower = slice.to_range().difference(&self.follower_input_ranges);
let mut zk = self.zk.try_lock().unwrap();
for input in input_minus_follower {
for input in input_minus_follower.iter_ranges() {
zk.commit_raw(
self.memory_map
.try_get(Slice::from_range_unchecked(input))?,
@@ -268,7 +266,7 @@ where
mpc.mark_private_raw(slice)?;
// Follower's private inputs will become public during finalization.
zk.mark_public_raw(self.memory_map.try_get(slice)?)?;
self.follower_input_ranges.union_mut(slice.to_range());
self.follower_input_ranges.union_mut(&slice.to_range());
self.follower_inputs.push(slice);
}
}
@@ -284,7 +282,7 @@ where
mpc.mark_blind_raw(slice)?;
// Follower's private inputs will become public during finalization.
zk.mark_public_raw(self.memory_map.try_get(slice)?)?;
self.follower_input_ranges.union_mut(slice.to_range());
self.follower_input_ranges.union_mut(&slice.to_range());
self.follower_inputs.push(slice);
}
Role::Follower => {

View File

@@ -1,7 +1,7 @@
use std::ops::Range;
use mpz_vm_core::{memory::Slice, VmError};
use rangeset::ops::Set;
use rangeset::Subset;
/// A mapping between the memories of the MPC and ZK VMs.
#[derive(Debug, Default)]

View File

@@ -59,7 +59,5 @@ generic-array = { workspace = true }
bincode = { workspace = true }
hex = { workspace = true }
rstest = { workspace = true }
tlsn-core = { workspace = true, features = ["fixtures"] }
tlsn-attestation = { workspace = true, features = ["fixtures"] }
tlsn-data-fixtures = { workspace = true }
webpki-root-certs = { workspace = true }

View File

@@ -1,6 +1,6 @@
//! Proving configuration.
use rangeset::set::{RangeSet, ToRangeSet};
use rangeset::{RangeSet, ToRangeSet, UnionMut};
use serde::{Deserialize, Serialize};
use crate::transcript::{Direction, Transcript, TranscriptCommitConfig, TranscriptCommitRequest};

View File

@@ -1,11 +1,11 @@
use rangeset::set::RangeSet;
use rangeset::RangeSet;
pub(crate) struct FmtRangeSet<'a>(pub &'a RangeSet<usize>);
impl<'a> std::fmt::Display for FmtRangeSet<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("{")?;
for range in self.0.iter() {
for range in self.0.iter_ranges() {
write!(f, "{}..{}", range.start, range.end)?;
if range.end < self.0.end().unwrap_or(0) {
f.write_str(", ")?;

View File

@@ -1,7 +1,10 @@
//! Fixtures for testing
mod provider;
pub mod transcript;
pub use provider::FixtureEncodingProvider;
use hex::FromHex;
use crate::{
@@ -10,6 +13,10 @@ use crate::{
ServerEphemKey, ServerName, ServerSignature, SignatureAlgorithm, TlsVersion,
TranscriptLength,
},
transcript::{
encoding::{EncoderSecret, EncodingProvider},
Transcript,
},
webpki::CertificateDer,
};
@@ -122,3 +129,27 @@ impl ConnectionFixture {
server_ephemeral_key
}
}
/// Returns an encoding provider fixture.
pub fn encoding_provider(tx: &[u8], rx: &[u8]) -> impl EncodingProvider {
let secret = encoder_secret();
FixtureEncodingProvider::new(&secret, Transcript::new(tx, rx))
}
/// Seed fixture.
const SEED: [u8; 32] = [0; 32];
/// Delta fixture.
const DELTA: [u8; 16] = [1; 16];
/// Returns an encoder secret fixture.
pub fn encoder_secret() -> EncoderSecret {
EncoderSecret::new(SEED, DELTA)
}
/// Returns a tampered encoder secret fixture.
pub fn encoder_secret_tampered_seed() -> EncoderSecret {
let mut seed = SEED;
seed[0] += 1;
EncoderSecret::new(seed, DELTA)
}

View File

@@ -0,0 +1,41 @@
use std::ops::Range;
use crate::transcript::{
encoding::{new_encoder, Encoder, EncoderSecret, EncodingProvider, EncodingProviderError},
Direction, Transcript,
};
/// A encoding provider fixture.
pub struct FixtureEncodingProvider {
encoder: Box<dyn Encoder>,
transcript: Transcript,
}
impl FixtureEncodingProvider {
/// Creates a new encoding provider fixture.
pub(crate) fn new(secret: &EncoderSecret, transcript: Transcript) -> Self {
Self {
encoder: Box::new(new_encoder(secret)),
transcript,
}
}
}
impl EncodingProvider for FixtureEncodingProvider {
fn provide_encoding(
&self,
direction: Direction,
range: Range<usize>,
dest: &mut Vec<u8>,
) -> Result<(), EncodingProviderError> {
let transcript = match direction {
Direction::Sent => &self.transcript.sent(),
Direction::Received => &self.transcript.received(),
};
let data = transcript.get(range.clone()).ok_or(EncodingProviderError)?;
self.encoder.encode_data(direction, range, data, dest);
Ok(())
}
}

View File

@@ -19,7 +19,9 @@ use serde::{Deserialize, Serialize};
use crate::{
connection::ServerName,
transcript::{PartialTranscript, TranscriptCommitment, TranscriptSecret},
transcript::{
encoding::EncoderSecret, PartialTranscript, TranscriptCommitment, TranscriptSecret,
},
};
/// Prover output.
@@ -40,6 +42,8 @@ pub struct VerifierOutput {
pub server_name: Option<ServerName>,
/// Transcript data.
pub transcript: Option<PartialTranscript>,
/// Encoding commitment secret.
pub encoder_secret: Option<EncoderSecret>,
/// Transcript commitments.
pub transcript_commitments: Vec<TranscriptCommitment>,
}

View File

@@ -63,6 +63,11 @@ impl MerkleProof {
Ok(())
}
/// Returns the leaf count of the Merkle tree associated with the proof.
pub(crate) fn leaf_count(&self) -> usize {
self.leaf_count
}
}
#[derive(Clone)]

View File

@@ -19,17 +19,14 @@
//! withheld.
mod commit;
pub mod encoding;
pub mod hash;
mod proof;
mod tls;
use std::{fmt, ops::Range};
use rangeset::{
iter::RangeIterator,
ops::{Index, Set},
set::RangeSet,
};
use rangeset::{Difference, IndexRanges, RangeSet, Union};
use serde::{Deserialize, Serialize};
use crate::connection::TranscriptLength;
@@ -109,14 +106,8 @@ impl Transcript {
}
Some(
Subsequence::new(
idx.clone(),
data.index(idx).fold(Vec::new(), |mut acc, s| {
acc.extend_from_slice(s);
acc
}),
)
.expect("data is same length as index"),
Subsequence::new(idx.clone(), data.index_ranges(idx))
.expect("data is same length as index"),
)
}
@@ -138,11 +129,11 @@ impl Transcript {
let mut sent = vec![0; self.sent.len()];
let mut received = vec![0; self.received.len()];
for range in sent_idx.iter() {
for range in sent_idx.iter_ranges() {
sent[range.clone()].copy_from_slice(&self.sent[range]);
}
for range in recv_idx.iter() {
for range in recv_idx.iter_ranges() {
received[range.clone()].copy_from_slice(&self.received[range]);
}
@@ -195,20 +186,12 @@ pub struct CompressedPartialTranscript {
impl From<PartialTranscript> for CompressedPartialTranscript {
fn from(uncompressed: PartialTranscript) -> Self {
Self {
sent_authed: uncompressed.sent.index(&uncompressed.sent_authed_idx).fold(
Vec::new(),
|mut acc, s| {
acc.extend_from_slice(s);
acc
},
),
sent_authed: uncompressed
.sent
.index_ranges(&uncompressed.sent_authed_idx),
received_authed: uncompressed
.received
.index(&uncompressed.received_authed_idx)
.fold(Vec::new(), |mut acc, s| {
acc.extend_from_slice(s);
acc
}),
.index_ranges(&uncompressed.received_authed_idx),
sent_idx: uncompressed.sent_authed_idx,
recv_idx: uncompressed.received_authed_idx,
sent_total: uncompressed.sent.len(),
@@ -224,7 +207,7 @@ impl From<CompressedPartialTranscript> for PartialTranscript {
let mut offset = 0;
for range in compressed.sent_idx.iter() {
for range in compressed.sent_idx.iter_ranges() {
sent[range.clone()]
.copy_from_slice(&compressed.sent_authed[offset..offset + range.len()]);
offset += range.len();
@@ -232,7 +215,7 @@ impl From<CompressedPartialTranscript> for PartialTranscript {
let mut offset = 0;
for range in compressed.recv_idx.iter() {
for range in compressed.recv_idx.iter_ranges() {
received[range.clone()]
.copy_from_slice(&compressed.received_authed[offset..offset + range.len()]);
offset += range.len();
@@ -321,16 +304,12 @@ impl PartialTranscript {
/// Returns the index of sent data which haven't been authenticated.
pub fn sent_unauthed(&self) -> RangeSet<usize> {
(0..self.sent.len())
.difference(&self.sent_authed_idx)
.into_set()
(0..self.sent.len()).difference(&self.sent_authed_idx)
}
/// Returns the index of received data which haven't been authenticated.
pub fn received_unauthed(&self) -> RangeSet<usize> {
(0..self.received.len())
.difference(&self.received_authed_idx)
.into_set()
(0..self.received.len()).difference(&self.received_authed_idx)
}
/// Returns an iterator over the authenticated data in the transcript.
@@ -340,7 +319,7 @@ impl PartialTranscript {
Direction::Received => (&self.received, &self.received_authed_idx),
};
authed.iter_values().map(move |i| data[i])
authed.iter().map(|i| data[i])
}
/// Unions the authenticated data of this transcript with another.
@@ -360,20 +339,24 @@ impl PartialTranscript {
"received data are not the same length"
);
for range in other.sent_authed_idx.difference(&self.sent_authed_idx) {
for range in other
.sent_authed_idx
.difference(&self.sent_authed_idx)
.iter_ranges()
{
self.sent[range.clone()].copy_from_slice(&other.sent[range]);
}
for range in other
.received_authed_idx
.difference(&self.received_authed_idx)
.iter_ranges()
{
self.received[range.clone()].copy_from_slice(&other.received[range]);
}
self.sent_authed_idx.union_mut(&other.sent_authed_idx);
self.received_authed_idx
.union_mut(&other.received_authed_idx);
self.sent_authed_idx = self.sent_authed_idx.union(&other.sent_authed_idx);
self.received_authed_idx = self.received_authed_idx.union(&other.received_authed_idx);
}
/// Unions an authenticated subsequence into this transcript.
@@ -385,11 +368,11 @@ impl PartialTranscript {
match direction {
Direction::Sent => {
seq.copy_to(&mut self.sent);
self.sent_authed_idx.union_mut(&seq.idx);
self.sent_authed_idx = self.sent_authed_idx.union(&seq.idx);
}
Direction::Received => {
seq.copy_to(&mut self.received);
self.received_authed_idx.union_mut(&seq.idx);
self.received_authed_idx = self.received_authed_idx.union(&seq.idx);
}
}
}
@@ -400,10 +383,10 @@ impl PartialTranscript {
///
/// * `value` - The value to set the unauthenticated bytes to
pub fn set_unauthed(&mut self, value: u8) {
for range in self.sent_unauthed().iter() {
for range in self.sent_unauthed().iter_ranges() {
self.sent[range].fill(value);
}
for range in self.received_unauthed().iter() {
for range in self.received_unauthed().iter_ranges() {
self.received[range].fill(value);
}
}
@@ -418,13 +401,13 @@ impl PartialTranscript {
pub fn set_unauthed_range(&mut self, value: u8, direction: Direction, range: Range<usize>) {
match direction {
Direction::Sent => {
for r in range.difference(&self.sent_authed_idx) {
self.sent[r].fill(value);
for range in range.difference(&self.sent_authed_idx).iter_ranges() {
self.sent[range].fill(value);
}
}
Direction::Received => {
for r in range.difference(&self.received_authed_idx) {
self.received[r].fill(value);
for range in range.difference(&self.received_authed_idx).iter_ranges() {
self.received[range].fill(value);
}
}
}
@@ -502,7 +485,7 @@ impl Subsequence {
/// Panics if the subsequence ranges are out of bounds.
pub(crate) fn copy_to(&self, dest: &mut [u8]) {
let mut offset = 0;
for range in self.idx.iter() {
for range in self.idx.iter_ranges() {
dest[range.clone()].copy_from_slice(&self.data[offset..offset + range.len()]);
offset += range.len();
}
@@ -627,7 +610,12 @@ mod validation {
mut partial_transcript: CompressedPartialTranscriptUnchecked,
) {
// Change the total to be less than the last range's end bound.
let end = partial_transcript.sent_idx.iter().next_back().unwrap().end;
let end = partial_transcript
.sent_idx
.iter_ranges()
.next_back()
.unwrap()
.end;
partial_transcript.sent_total = end - 1;

View File

@@ -2,21 +2,33 @@
use std::{collections::HashSet, fmt};
use rangeset::set::ToRangeSet;
use rangeset::{ToRangeSet, UnionMut};
use serde::{Deserialize, Serialize};
use crate::{
hash::HashAlgId,
transcript::{
encoding::{EncodingCommitment, EncodingTree},
hash::{PlaintextHash, PlaintextHashSecret},
Direction, RangeSet, Transcript,
},
};
/// The maximum allowed total bytelength of committed data for a single
/// commitment kind. Used to prevent DoS during verification. (May cause the
/// verifier to hash up to a max of 1GB * 128 = 128GB of data for certain kinds
/// of encoding commitments.)
///
/// This value must not exceed bcs's MAX_SEQUENCE_LENGTH limit (which is (1 <<
/// 31) - 1 by default)
pub(crate) const MAX_TOTAL_COMMITTED_DATA: usize = 1_000_000_000;
/// Kind of transcript commitment.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum TranscriptCommitmentKind {
/// A commitment to encodings of the transcript.
Encoding,
/// A hash commitment to plaintext in the transcript.
Hash {
/// The hash algorithm used.
@@ -27,6 +39,7 @@ pub enum TranscriptCommitmentKind {
impl fmt::Display for TranscriptCommitmentKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Encoding => f.write_str("encoding"),
Self::Hash { alg } => write!(f, "hash ({alg})"),
}
}
@@ -36,6 +49,8 @@ impl fmt::Display for TranscriptCommitmentKind {
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub enum TranscriptCommitment {
/// Encoding commitment.
Encoding(EncodingCommitment),
/// Plaintext hash commitment.
Hash(PlaintextHash),
}
@@ -44,6 +59,8 @@ pub enum TranscriptCommitment {
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub enum TranscriptSecret {
/// Encoding tree.
Encoding(EncodingTree),
/// Plaintext hash secret.
Hash(PlaintextHashSecret),
}
@@ -51,6 +68,9 @@ pub enum TranscriptSecret {
/// Configuration for transcript commitments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptCommitConfig {
encoding_hash_alg: HashAlgId,
has_encoding: bool,
has_hash: bool,
commits: Vec<((Direction, RangeSet<usize>), TranscriptCommitmentKind)>,
}
@@ -60,23 +80,53 @@ impl TranscriptCommitConfig {
TranscriptCommitConfigBuilder::new(transcript)
}
/// Returns the hash algorithm to use for encoding commitments.
pub fn encoding_hash_alg(&self) -> &HashAlgId {
&self.encoding_hash_alg
}
/// Returns `true` if the configuration has any encoding commitments.
pub fn has_encoding(&self) -> bool {
self.has_encoding
}
/// Returns `true` if the configuration has any hash commitments.
pub fn has_hash(&self) -> bool {
self.commits
.iter()
.any(|(_, kind)| matches!(kind, TranscriptCommitmentKind::Hash { .. }))
self.has_hash
}
/// Returns an iterator over the encoding commitment indices.
pub fn iter_encoding(&self) -> impl Iterator<Item = &(Direction, RangeSet<usize>)> {
self.commits.iter().filter_map(|(idx, kind)| match kind {
TranscriptCommitmentKind::Encoding => Some(idx),
_ => None,
})
}
/// Returns an iterator over the hash commitment indices.
pub fn iter_hash(&self) -> impl Iterator<Item = (&(Direction, RangeSet<usize>), &HashAlgId)> {
self.commits.iter().map(|(idx, kind)| match kind {
TranscriptCommitmentKind::Hash { alg } => (idx, alg),
self.commits.iter().filter_map(|(idx, kind)| match kind {
TranscriptCommitmentKind::Hash { alg } => Some((idx, alg)),
_ => None,
})
}
/// Returns a request for the transcript commitments.
pub fn to_request(&self) -> TranscriptCommitRequest {
TranscriptCommitRequest {
encoding: self.has_encoding.then(|| {
let mut sent = RangeSet::default();
let mut recv = RangeSet::default();
for (dir, idx) in self.iter_encoding() {
match dir {
Direction::Sent => sent.union_mut(idx),
Direction::Received => recv.union_mut(idx),
}
}
(sent, recv)
}),
hash: self
.iter_hash()
.map(|((dir, idx), alg)| (*dir, idx.clone(), *alg))
@@ -86,9 +136,15 @@ impl TranscriptCommitConfig {
}
/// A builder for [`TranscriptCommitConfig`].
///
/// The default hash algorithm is [`HashAlgId::BLAKE3`] and the default kind
/// is [`TranscriptCommitmentKind::Encoding`].
#[derive(Debug)]
pub struct TranscriptCommitConfigBuilder<'a> {
transcript: &'a Transcript,
encoding_hash_alg: HashAlgId,
has_encoding: bool,
has_hash: bool,
default_kind: TranscriptCommitmentKind,
commits: HashSet<((Direction, RangeSet<usize>), TranscriptCommitmentKind)>,
}
@@ -98,13 +154,20 @@ impl<'a> TranscriptCommitConfigBuilder<'a> {
pub fn new(transcript: &'a Transcript) -> Self {
Self {
transcript,
default_kind: TranscriptCommitmentKind::Hash {
alg: HashAlgId::BLAKE3,
},
encoding_hash_alg: HashAlgId::BLAKE3,
has_encoding: false,
has_hash: false,
default_kind: TranscriptCommitmentKind::Encoding,
commits: HashSet::default(),
}
}
/// Sets the hash algorithm to use for encoding commitments.
pub fn encoding_hash_alg(&mut self, alg: HashAlgId) -> &mut Self {
self.encoding_hash_alg = alg;
self
}
/// Sets the default kind of commitment to use.
pub fn default_kind(&mut self, default_kind: TranscriptCommitmentKind) -> &mut Self {
self.default_kind = default_kind;
@@ -138,6 +201,11 @@ impl<'a> TranscriptCommitConfigBuilder<'a> {
));
}
match kind {
TranscriptCommitmentKind::Encoding => self.has_encoding = true,
TranscriptCommitmentKind::Hash { .. } => self.has_hash = true,
}
self.commits.insert(((direction, idx), kind));
Ok(self)
@@ -184,6 +252,9 @@ impl<'a> TranscriptCommitConfigBuilder<'a> {
/// Builds the configuration.
pub fn build(self) -> Result<TranscriptCommitConfig, TranscriptCommitConfigBuilderError> {
Ok(TranscriptCommitConfig {
encoding_hash_alg: self.encoding_hash_alg,
has_encoding: self.has_encoding,
has_hash: self.has_hash,
commits: Vec::from_iter(self.commits),
})
}
@@ -230,10 +301,16 @@ impl fmt::Display for TranscriptCommitConfigBuilderError {
/// Request to compute transcript commitments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptCommitRequest {
encoding: Option<(RangeSet<usize>, RangeSet<usize>)>,
hash: Vec<(Direction, RangeSet<usize>, HashAlgId)>,
}
impl TranscriptCommitRequest {
/// Returns `true` if an encoding commitment is requested.
pub fn has_encoding(&self) -> bool {
self.encoding.is_some()
}
/// Returns `true` if a hash commitment is requested.
pub fn has_hash(&self) -> bool {
!self.hash.is_empty()
@@ -243,6 +320,11 @@ impl TranscriptCommitRequest {
pub fn iter_hash(&self) -> impl Iterator<Item = &(Direction, RangeSet<usize>, HashAlgId)> {
self.hash.iter()
}
/// Returns the ranges of the encoding commitments.
pub fn encoding(&self) -> Option<&(RangeSet<usize>, RangeSet<usize>)> {
self.encoding.as_ref()
}
}
#[cfg(test)]

View File

@@ -0,0 +1,22 @@
//! Transcript encoding commitments and proofs.
mod encoder;
mod proof;
mod provider;
mod tree;
pub use encoder::{new_encoder, Encoder, EncoderSecret};
pub use proof::{EncodingProof, EncodingProofError};
pub use provider::{EncodingProvider, EncodingProviderError};
pub use tree::{EncodingTree, EncodingTreeError};
use serde::{Deserialize, Serialize};
use crate::hash::TypedHash;
/// Transcript encoding commitment.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct EncodingCommitment {
/// Merkle root of the encoding commitments.
pub root: TypedHash,
}

View File

@@ -0,0 +1,137 @@
use std::ops::Range;
use crate::transcript::Direction;
use itybity::ToBits;
use rand::{RngCore, SeedableRng};
use rand_chacha::ChaCha12Rng;
use serde::{Deserialize, Serialize};
/// The size of the encoding for 1 bit, in bytes.
const BIT_ENCODING_SIZE: usize = 16;
/// The size of the encoding for 1 byte, in bytes.
const BYTE_ENCODING_SIZE: usize = 128;
/// Secret used by an encoder to generate encodings.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct EncoderSecret {
seed: [u8; 32],
delta: [u8; BIT_ENCODING_SIZE],
}
opaque_debug::implement!(EncoderSecret);
impl EncoderSecret {
/// Creates a new secret.
///
/// # Arguments
///
/// * `seed` - The seed for the PRG.
/// * `delta` - Delta for deriving the one-encodings.
pub fn new(seed: [u8; 32], delta: [u8; 16]) -> Self {
Self { seed, delta }
}
/// Returns the seed.
pub fn seed(&self) -> &[u8; 32] {
&self.seed
}
/// Returns the delta.
pub fn delta(&self) -> &[u8; 16] {
&self.delta
}
}
/// Creates a new encoder.
pub fn new_encoder(secret: &EncoderSecret) -> impl Encoder {
ChaChaEncoder::new(secret)
}
pub(crate) struct ChaChaEncoder {
seed: [u8; 32],
delta: [u8; 16],
}
impl ChaChaEncoder {
pub(crate) fn new(secret: &EncoderSecret) -> Self {
let seed = *secret.seed();
let delta = *secret.delta();
Self { seed, delta }
}
pub(crate) fn new_prg(&self, stream_id: u64) -> ChaCha12Rng {
let mut prg = ChaCha12Rng::from_seed(self.seed);
prg.set_stream(stream_id);
prg.set_word_pos(0);
prg
}
}
/// A transcript encoder.
///
/// This is an internal implementation detail that should not be exposed to the
/// public API.
pub trait Encoder {
/// Writes the zero encoding for the given range of the transcript into the
/// destination buffer.
fn encode_range(&self, direction: Direction, range: Range<usize>, dest: &mut Vec<u8>);
/// Writes the encoding for the given data into the destination buffer.
fn encode_data(
&self,
direction: Direction,
range: Range<usize>,
data: &[u8],
dest: &mut Vec<u8>,
);
}
impl Encoder for ChaChaEncoder {
fn encode_range(&self, direction: Direction, range: Range<usize>, dest: &mut Vec<u8>) {
// ChaCha encoder works with 32-bit words. Each encoded bit is 128 bits long.
const WORDS_PER_BYTE: u128 = 8 * 128 / 32;
let stream_id: u64 = match direction {
Direction::Sent => 0,
Direction::Received => 1,
};
let mut prg = self.new_prg(stream_id);
let len = range.len() * BYTE_ENCODING_SIZE;
let pos = dest.len();
// Write 0s to the destination buffer.
dest.resize(pos + len, 0);
// Fill the destination buffer with the PRG.
prg.set_word_pos(range.start as u128 * WORDS_PER_BYTE);
prg.fill_bytes(&mut dest[pos..pos + len]);
}
fn encode_data(
&self,
direction: Direction,
range: Range<usize>,
data: &[u8],
dest: &mut Vec<u8>,
) {
const ZERO: [u8; 16] = [0; BIT_ENCODING_SIZE];
let pos = dest.len();
// Write the zero encoding for the given range.
self.encode_range(direction, range, dest);
let dest = &mut dest[pos..];
for (pos, bit) in data.iter_lsb0().enumerate() {
// Add the delta to the encoding whenever the encoded bit is 1,
// otherwise add a zero.
let summand = if bit { &self.delta } else { &ZERO };
dest[pos * BIT_ENCODING_SIZE..(pos + 1) * BIT_ENCODING_SIZE]
.iter_mut()
.zip(summand)
.for_each(|(a, b)| *a ^= *b);
}
}
}

View File

@@ -0,0 +1,361 @@
use std::{collections::HashMap, fmt};
use rangeset::{RangeSet, UnionMut};
use serde::{Deserialize, Serialize};
use crate::{
hash::{Blinder, HashProvider, HashProviderError},
merkle::{MerkleError, MerkleProof},
transcript::{
commit::MAX_TOTAL_COMMITTED_DATA,
encoding::{new_encoder, Encoder, EncoderSecret, EncodingCommitment},
Direction,
},
};
/// An opening of a leaf in the encoding tree.
#[derive(Clone, Serialize, Deserialize)]
pub(super) struct Opening {
pub(super) direction: Direction,
pub(super) idx: RangeSet<usize>,
pub(super) blinder: Blinder,
}
opaque_debug::implement!(Opening);
/// An encoding commitment proof.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(try_from = "validation::EncodingProofUnchecked")]
pub struct EncodingProof {
/// The proof of inclusion of the commitment(s) in the Merkle tree of
/// commitments.
pub(super) inclusion_proof: MerkleProof,
pub(super) openings: HashMap<usize, Opening>,
}
impl EncodingProof {
/// Verifies the proof against the commitment.
///
/// Returns the authenticated indices of the sent and received data,
/// respectively.
///
/// # Arguments
///
/// * `provider` - Hash provider.
/// * `commitment` - Encoding commitment to verify against.
/// * `sent` - Sent data to authenticate.
/// * `recv` - Received data to authenticate.
pub fn verify_with_provider(
&self,
provider: &HashProvider,
secret: &EncoderSecret,
commitment: &EncodingCommitment,
sent: &[u8],
recv: &[u8],
) -> Result<(RangeSet<usize>, RangeSet<usize>), EncodingProofError> {
let hasher = provider.get(&commitment.root.alg)?;
let encoder = new_encoder(secret);
let Self {
inclusion_proof,
openings,
} = self;
let mut leaves = Vec::with_capacity(openings.len());
let mut expected_leaf = Vec::default();
let mut total_opened = 0u128;
let mut auth_sent = RangeSet::default();
let mut auth_recv = RangeSet::default();
for (
id,
Opening {
direction,
idx,
blinder,
},
) in openings
{
// Make sure the amount of data being proved is bounded.
total_opened += idx.len() as u128;
if total_opened > MAX_TOTAL_COMMITTED_DATA as u128 {
return Err(EncodingProofError::new(
ErrorKind::Proof,
"exceeded maximum allowed data",
))?;
}
let (data, auth) = match direction {
Direction::Sent => (sent, &mut auth_sent),
Direction::Received => (recv, &mut auth_recv),
};
// Make sure the ranges are within the bounds of the transcript.
if idx.end().unwrap_or(0) > data.len() {
return Err(EncodingProofError::new(
ErrorKind::Proof,
format!(
"index out of bounds of the transcript ({}): {} > {}",
direction,
idx.end().unwrap_or(0),
data.len()
),
));
}
expected_leaf.clear();
for range in idx.iter_ranges() {
encoder.encode_data(*direction, range.clone(), &data[range], &mut expected_leaf);
}
expected_leaf.extend_from_slice(blinder.as_bytes());
// Compute the expected hash of the commitment to make sure it is
// present in the merkle tree.
leaves.push((*id, hasher.hash(&expected_leaf)));
auth.union_mut(idx);
}
// Verify that the expected hashes are present in the merkle tree.
//
// This proves the Prover committed to the purported data prior to the encoder
// seed being revealed. Ergo, if the encodings are authentic then the purported
// data is authentic.
inclusion_proof.verify(hasher, &commitment.root, leaves)?;
Ok((auth_sent, auth_recv))
}
}
/// Error for [`EncodingProof`].
#[derive(Debug, thiserror::Error)]
pub struct EncodingProofError {
kind: ErrorKind,
source: Option<Box<dyn std::error::Error + Send + Sync>>,
}
impl EncodingProofError {
fn new<E>(kind: ErrorKind, source: E) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync>>,
{
Self {
kind,
source: Some(source.into()),
}
}
}
#[derive(Debug)]
enum ErrorKind {
Provider,
Proof,
}
impl fmt::Display for EncodingProofError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("encoding proof error: ")?;
match self.kind {
ErrorKind::Provider => f.write_str("provider error")?,
ErrorKind::Proof => f.write_str("proof error")?,
}
if let Some(source) = &self.source {
write!(f, " caused by: {source}")?;
}
Ok(())
}
}
impl From<HashProviderError> for EncodingProofError {
fn from(error: HashProviderError) -> Self {
Self::new(ErrorKind::Provider, error)
}
}
impl From<MerkleError> for EncodingProofError {
fn from(error: MerkleError) -> Self {
Self::new(ErrorKind::Proof, error)
}
}
/// Invalid encoding proof error.
#[derive(Debug, thiserror::Error)]
#[error("invalid encoding proof: {0}")]
pub struct InvalidEncodingProof(&'static str);
mod validation {
use super::*;
/// The maximum allowed height of the Merkle tree of encoding commitments.
///
/// The statistical security parameter (SSP) of the encoding commitment
/// protocol is calculated as "the number of uniformly random bits in a
/// single bit's encoding minus `MAX_HEIGHT`".
///
/// For example, a bit encoding used in garbled circuits typically has 127
/// uniformly random bits, hence when using it in the encoding
/// commitment protocol, the SSP is 127 - 30 = 97 bits.
///
/// Leaving this validation here as a fail-safe in case we ever start
/// using shorter encodings.
const MAX_HEIGHT: usize = 30;
#[derive(Debug, Deserialize)]
pub(super) struct EncodingProofUnchecked {
inclusion_proof: MerkleProof,
openings: HashMap<usize, Opening>,
}
impl TryFrom<EncodingProofUnchecked> for EncodingProof {
type Error = InvalidEncodingProof;
fn try_from(unchecked: EncodingProofUnchecked) -> Result<Self, Self::Error> {
if unchecked.inclusion_proof.leaf_count() > 1 << MAX_HEIGHT {
return Err(InvalidEncodingProof(
"the height of the tree exceeds the maximum allowed",
));
}
Ok(Self {
inclusion_proof: unchecked.inclusion_proof,
openings: unchecked.openings,
})
}
}
}
#[cfg(test)]
mod test {
use tlsn_data_fixtures::http::{request::POST_JSON, response::OK_JSON};
use crate::{
fixtures::{encoder_secret, encoder_secret_tampered_seed, encoding_provider},
hash::Blake3,
transcript::{encoding::EncodingTree, Transcript},
};
use super::*;
struct EncodingFixture {
transcript: Transcript,
proof: EncodingProof,
commitment: EncodingCommitment,
}
fn new_encoding_fixture() -> EncodingFixture {
let transcript = Transcript::new(POST_JSON, OK_JSON);
let idx_0 = (Direction::Sent, RangeSet::from(0..POST_JSON.len()));
let idx_1 = (Direction::Received, RangeSet::from(0..OK_JSON.len()));
let provider = encoding_provider(transcript.sent(), transcript.received());
let tree = EncodingTree::new(&Blake3::default(), [&idx_0, &idx_1], &provider).unwrap();
let proof = tree.proof([&idx_0, &idx_1].into_iter()).unwrap();
let commitment = EncodingCommitment { root: tree.root() };
EncodingFixture {
transcript,
proof,
commitment,
}
}
#[test]
fn test_verify_encoding_proof_tampered_seed() {
let EncodingFixture {
transcript,
proof,
commitment,
} = new_encoding_fixture();
let err = proof
.verify_with_provider(
&HashProvider::default(),
&encoder_secret_tampered_seed(),
&commitment,
transcript.sent(),
transcript.received(),
)
.unwrap_err();
assert!(matches!(err.kind, ErrorKind::Proof));
}
#[test]
fn test_verify_encoding_proof_out_of_range() {
let EncodingFixture {
transcript,
proof,
commitment,
} = new_encoding_fixture();
let sent = &transcript.sent()[transcript.sent().len() - 1..];
let recv = &transcript.received()[transcript.received().len() - 2..];
let err = proof
.verify_with_provider(
&HashProvider::default(),
&encoder_secret(),
&commitment,
sent,
recv,
)
.unwrap_err();
assert!(matches!(err.kind, ErrorKind::Proof));
}
#[test]
fn test_verify_encoding_proof_tampered_idx() {
let EncodingFixture {
transcript,
mut proof,
commitment,
} = new_encoding_fixture();
let Opening { idx, .. } = proof.openings.values_mut().next().unwrap();
*idx = RangeSet::from([0..3, 13..15]);
let err = proof
.verify_with_provider(
&HashProvider::default(),
&encoder_secret(),
&commitment,
transcript.sent(),
transcript.received(),
)
.unwrap_err();
assert!(matches!(err.kind, ErrorKind::Proof));
}
#[test]
fn test_verify_encoding_proof_tampered_encoding_blinder() {
let EncodingFixture {
transcript,
mut proof,
commitment,
} = new_encoding_fixture();
let Opening { blinder, .. } = proof.openings.values_mut().next().unwrap();
*blinder = rand::random();
let err = proof
.verify_with_provider(
&HashProvider::default(),
&encoder_secret(),
&commitment,
transcript.sent(),
transcript.received(),
)
.unwrap_err();
assert!(matches!(err.kind, ErrorKind::Proof));
}
}

View File

@@ -0,0 +1,19 @@
use std::ops::Range;
use crate::transcript::Direction;
/// A provider of plaintext encodings.
pub trait EncodingProvider {
/// Writes the encoding of the given range into the destination buffer.
fn provide_encoding(
&self,
direction: Direction,
range: Range<usize>,
dest: &mut Vec<u8>,
) -> Result<(), EncodingProviderError>;
}
/// Error for [`EncodingProvider`].
#[derive(Debug, thiserror::Error)]
#[error("failed to provide encoding")]
pub struct EncodingProviderError;

View File

@@ -0,0 +1,327 @@
use std::collections::HashMap;
use bimap::BiMap;
use rangeset::{RangeSet, UnionMut};
use serde::{Deserialize, Serialize};
use crate::{
hash::{Blinder, HashAlgId, HashAlgorithm, TypedHash},
merkle::MerkleTree,
transcript::{
encoding::{
proof::{EncodingProof, Opening},
EncodingProvider,
},
Direction,
},
};
/// Encoding tree builder error.
#[derive(Debug, thiserror::Error)]
pub enum EncodingTreeError {
/// Index is out of bounds of the transcript.
#[error("index is out of bounds of the transcript")]
OutOfBounds {
/// The index.
index: RangeSet<usize>,
/// The transcript length.
transcript_length: usize,
},
/// Encoding provider is missing an encoding for an index.
#[error("encoding provider is missing an encoding for an index")]
MissingEncoding {
/// The index which is missing.
index: RangeSet<usize>,
},
/// Index is missing from the tree.
#[error("index is missing from the tree")]
MissingLeaf {
/// The index which is missing.
index: RangeSet<usize>,
},
}
/// A merkle tree of transcript encodings.
#[derive(Clone, Serialize, Deserialize)]
pub struct EncodingTree {
/// Merkle tree of the commitments.
tree: MerkleTree,
/// Nonces used to blind the hashes.
blinders: Vec<Blinder>,
/// Mapping between the index of a leaf and the transcript index it
/// corresponds to.
idxs: BiMap<usize, (Direction, RangeSet<usize>)>,
/// Union of all transcript indices in the sent direction.
sent_idx: RangeSet<usize>,
/// Union of all transcript indices in the received direction.
received_idx: RangeSet<usize>,
}
opaque_debug::implement!(EncodingTree);
impl EncodingTree {
/// Creates a new encoding tree.
///
/// # Arguments
///
/// * `hasher` - The hash algorithm to use.
/// * `idxs` - The subsequence indices to commit to.
/// * `provider` - The encoding provider.
pub fn new<'idx>(
hasher: &dyn HashAlgorithm,
idxs: impl IntoIterator<Item = &'idx (Direction, RangeSet<usize>)>,
provider: &dyn EncodingProvider,
) -> Result<Self, EncodingTreeError> {
let mut this = Self {
tree: MerkleTree::new(hasher.id()),
blinders: Vec::new(),
idxs: BiMap::new(),
sent_idx: RangeSet::default(),
received_idx: RangeSet::default(),
};
let mut leaves = Vec::new();
let mut encoding = Vec::new();
for dir_idx in idxs {
let direction = dir_idx.0;
let idx = &dir_idx.1;
// Ignore empty indices.
if idx.is_empty() {
continue;
}
if this.idxs.contains_right(dir_idx) {
// The subsequence is already in the tree.
continue;
}
let blinder: Blinder = rand::random();
encoding.clear();
for range in idx.iter_ranges() {
provider
.provide_encoding(direction, range, &mut encoding)
.map_err(|_| EncodingTreeError::MissingEncoding { index: idx.clone() })?;
}
encoding.extend_from_slice(blinder.as_bytes());
let leaf = hasher.hash(&encoding);
leaves.push(leaf);
this.blinders.push(blinder);
this.idxs.insert(this.idxs.len(), dir_idx.clone());
match direction {
Direction::Sent => this.sent_idx.union_mut(idx),
Direction::Received => this.received_idx.union_mut(idx),
}
}
this.tree.insert(hasher, leaves);
Ok(this)
}
/// Returns the root of the tree.
pub fn root(&self) -> TypedHash {
self.tree.root()
}
/// Returns the hash algorithm of the tree.
pub fn algorithm(&self) -> HashAlgId {
self.tree.algorithm()
}
/// Generates a proof for the given indices.
///
/// # Arguments
///
/// * `idxs` - The transcript indices to prove.
pub fn proof<'idx>(
&self,
idxs: impl Iterator<Item = &'idx (Direction, RangeSet<usize>)>,
) -> Result<EncodingProof, EncodingTreeError> {
let mut openings = HashMap::new();
for dir_idx in idxs {
let direction = dir_idx.0;
let idx = &dir_idx.1;
let leaf_idx = *self
.idxs
.get_by_right(dir_idx)
.ok_or_else(|| EncodingTreeError::MissingLeaf { index: idx.clone() })?;
let blinder = self.blinders[leaf_idx].clone();
openings.insert(
leaf_idx,
Opening {
direction,
idx: idx.clone(),
blinder,
},
);
}
let mut indices = openings.keys().copied().collect::<Vec<_>>();
indices.sort();
Ok(EncodingProof {
inclusion_proof: self.tree.proof(&indices),
openings,
})
}
/// Returns whether the tree contains the given transcript index.
pub fn contains(&self, idx: &(Direction, RangeSet<usize>)) -> bool {
self.idxs.contains_right(idx)
}
pub(crate) fn idx(&self, direction: Direction) -> &RangeSet<usize> {
match direction {
Direction::Sent => &self.sent_idx,
Direction::Received => &self.received_idx,
}
}
/// Returns the committed transcript indices.
pub(crate) fn transcript_indices(&self) -> impl Iterator<Item = &(Direction, RangeSet<usize>)> {
self.idxs.right_values()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
fixtures::{encoder_secret, encoding_provider},
hash::{Blake3, HashProvider},
transcript::{encoding::EncodingCommitment, Transcript},
};
use tlsn_data_fixtures::http::{request::POST_JSON, response::OK_JSON};
fn new_tree<'seq>(
transcript: &Transcript,
idxs: impl Iterator<Item = &'seq (Direction, RangeSet<usize>)>,
) -> Result<EncodingTree, EncodingTreeError> {
let provider = encoding_provider(transcript.sent(), transcript.received());
EncodingTree::new(&Blake3::default(), idxs, &provider)
}
#[test]
fn test_encoding_tree() {
let transcript = Transcript::new(POST_JSON, OK_JSON);
let idx_0 = (Direction::Sent, RangeSet::from(0..POST_JSON.len()));
let idx_1 = (Direction::Received, RangeSet::from(0..OK_JSON.len()));
let tree = new_tree(&transcript, [&idx_0, &idx_1].into_iter()).unwrap();
assert!(tree.contains(&idx_0));
assert!(tree.contains(&idx_1));
let proof = tree.proof([&idx_0, &idx_1].into_iter()).unwrap();
let commitment = EncodingCommitment { root: tree.root() };
let (auth_sent, auth_recv) = proof
.verify_with_provider(
&HashProvider::default(),
&encoder_secret(),
&commitment,
transcript.sent(),
transcript.received(),
)
.unwrap();
assert_eq!(auth_sent, idx_0.1);
assert_eq!(auth_recv, idx_1.1);
}
#[test]
fn test_encoding_tree_multiple_ranges() {
let transcript = Transcript::new(POST_JSON, OK_JSON);
let idx_0 = (Direction::Sent, RangeSet::from(0..1));
let idx_1 = (Direction::Sent, RangeSet::from(1..POST_JSON.len()));
let idx_2 = (Direction::Received, RangeSet::from(0..1));
let idx_3 = (Direction::Received, RangeSet::from(1..OK_JSON.len()));
let tree = new_tree(&transcript, [&idx_0, &idx_1, &idx_2, &idx_3].into_iter()).unwrap();
assert!(tree.contains(&idx_0));
assert!(tree.contains(&idx_1));
assert!(tree.contains(&idx_2));
assert!(tree.contains(&idx_3));
let proof = tree
.proof([&idx_0, &idx_1, &idx_2, &idx_3].into_iter())
.unwrap();
let commitment = EncodingCommitment { root: tree.root() };
let (auth_sent, auth_recv) = proof
.verify_with_provider(
&HashProvider::default(),
&encoder_secret(),
&commitment,
transcript.sent(),
transcript.received(),
)
.unwrap();
let mut expected_auth_sent = RangeSet::default();
expected_auth_sent.union_mut(&idx_0.1);
expected_auth_sent.union_mut(&idx_1.1);
let mut expected_auth_recv = RangeSet::default();
expected_auth_recv.union_mut(&idx_2.1);
expected_auth_recv.union_mut(&idx_3.1);
assert_eq!(auth_sent, expected_auth_sent);
assert_eq!(auth_recv, expected_auth_recv);
}
#[test]
fn test_encoding_tree_proof_missing_leaf() {
let transcript = Transcript::new(POST_JSON, OK_JSON);
let idx_0 = (Direction::Sent, RangeSet::from(0..POST_JSON.len()));
let idx_1 = (Direction::Received, RangeSet::from(0..4));
let idx_2 = (Direction::Received, RangeSet::from(4..OK_JSON.len()));
let tree = new_tree(&transcript, [&idx_0, &idx_1].into_iter()).unwrap();
let result = tree
.proof([&idx_0, &idx_1, &idx_2].into_iter())
.unwrap_err();
assert!(matches!(result, EncodingTreeError::MissingLeaf { .. }));
}
#[test]
fn test_encoding_tree_out_of_bounds() {
let transcript = Transcript::new(POST_JSON, OK_JSON);
let idx_0 = (Direction::Sent, RangeSet::from(0..POST_JSON.len() + 1));
let idx_1 = (Direction::Received, RangeSet::from(0..OK_JSON.len() + 1));
let result = new_tree(&transcript, [&idx_0].into_iter()).unwrap_err();
assert!(matches!(result, EncodingTreeError::MissingEncoding { .. }));
let result = new_tree(&transcript, [&idx_1].into_iter()).unwrap_err();
assert!(matches!(result, EncodingTreeError::MissingEncoding { .. }));
}
#[test]
fn test_encoding_tree_missing_encoding() {
let provider = encoding_provider(&[], &[]);
let result = EncodingTree::new(
&Blake3::default(),
[(Direction::Sent, RangeSet::from(0..8))].iter(),
&provider,
)
.unwrap_err();
assert!(matches!(result, EncodingTreeError::MissingEncoding { .. }));
}
}

View File

@@ -1,10 +1,6 @@
//! Transcript proofs.
use rangeset::{
iter::RangeIterator,
ops::{Cover, Set},
set::ToRangeSet,
};
use rangeset::{Cover, Difference, Subset, ToRangeSet, UnionMut};
use serde::{Deserialize, Serialize};
use std::{collections::HashSet, fmt};
@@ -14,6 +10,7 @@ use crate::{
hash::{HashAlgId, HashProvider},
transcript::{
commit::{TranscriptCommitment, TranscriptCommitmentKind},
encoding::{EncoderSecret, EncodingProof, EncodingProofError, EncodingTree},
hash::{hash_plaintext, PlaintextHash, PlaintextHashSecret},
Direction, PartialTranscript, RangeSet, Transcript, TranscriptSecret,
},
@@ -31,12 +28,14 @@ const DEFAULT_COMMITMENT_KINDS: &[TranscriptCommitmentKind] = &[
TranscriptCommitmentKind::Hash {
alg: HashAlgId::KECCAK256,
},
TranscriptCommitmentKind::Encoding,
];
/// Proof of the contents of a transcript.
#[derive(Clone, Serialize, Deserialize)]
pub struct TranscriptProof {
transcript: PartialTranscript,
encoding_proof: Option<EncodingProof>,
hash_secrets: Vec<PlaintextHashSecret>,
}
@@ -50,18 +49,27 @@ impl TranscriptProof {
/// # Arguments
///
/// * `provider` - The hash provider to use for verification.
/// * `length` - The transcript length.
/// * `commitments` - The commitments to verify against.
/// * `attestation_body` - The attestation body to verify against.
pub fn verify_with_provider<'a>(
self,
provider: &HashProvider,
length: &TranscriptLength,
encoder_secret: Option<&EncoderSecret>,
commitments: impl IntoIterator<Item = &'a TranscriptCommitment>,
) -> Result<PartialTranscript, TranscriptProofError> {
let mut encoding_commitment = None;
let mut hash_commitments = HashSet::new();
// Index commitments.
for commitment in commitments {
match commitment {
TranscriptCommitment::Encoding(commitment) => {
if encoding_commitment.replace(commitment).is_some() {
return Err(TranscriptProofError::new(
ErrorKind::Encoding,
"multiple encoding commitments are present.",
));
}
}
TranscriptCommitment::Hash(plaintext_hash) => {
hash_commitments.insert(plaintext_hash);
}
@@ -80,6 +88,34 @@ impl TranscriptProof {
let mut total_auth_sent = RangeSet::default();
let mut total_auth_recv = RangeSet::default();
// Verify encoding proof.
if let Some(proof) = self.encoding_proof {
let secret = encoder_secret.ok_or_else(|| {
TranscriptProofError::new(
ErrorKind::Encoding,
"contains an encoding proof but missing encoder secret",
)
})?;
let commitment = encoding_commitment.ok_or_else(|| {
TranscriptProofError::new(
ErrorKind::Encoding,
"contains an encoding proof but missing encoding commitment",
)
})?;
let (auth_sent, auth_recv) = proof.verify_with_provider(
provider,
secret,
commitment,
self.transcript.sent_unsafe(),
self.transcript.received_unsafe(),
)?;
total_auth_sent.union_mut(&auth_sent);
total_auth_recv.union_mut(&auth_recv);
}
let mut buffer = Vec::new();
for PlaintextHashSecret {
direction,
@@ -108,7 +144,7 @@ impl TranscriptProof {
}
buffer.clear();
for range in idx.iter() {
for range in idx.iter_ranges() {
buffer.extend_from_slice(&plaintext[range]);
}
@@ -163,6 +199,7 @@ impl TranscriptProofError {
#[derive(Debug)]
enum ErrorKind {
Encoding,
Hash,
Proof,
}
@@ -172,6 +209,7 @@ impl fmt::Display for TranscriptProofError {
f.write_str("transcript proof error: ")?;
match self.kind {
ErrorKind::Encoding => f.write_str("encoding error")?,
ErrorKind::Hash => f.write_str("hash error")?,
ErrorKind::Proof => f.write_str("proof error")?,
}
@@ -184,6 +222,12 @@ impl fmt::Display for TranscriptProofError {
}
}
impl From<EncodingProofError> for TranscriptProofError {
fn from(e: EncodingProofError) -> Self {
TranscriptProofError::new(ErrorKind::Encoding, e)
}
}
/// Union of ranges to reveal.
#[derive(Clone, Debug, PartialEq)]
struct QueryIdx {
@@ -228,6 +272,7 @@ pub struct TranscriptProofBuilder<'a> {
/// Commitment kinds in order of preference for building transcript proofs.
commitment_kinds: Vec<TranscriptCommitmentKind>,
transcript: &'a Transcript,
encoding_tree: Option<&'a EncodingTree>,
hash_secrets: Vec<&'a PlaintextHashSecret>,
committed_sent: RangeSet<usize>,
committed_recv: RangeSet<usize>,
@@ -243,9 +288,15 @@ impl<'a> TranscriptProofBuilder<'a> {
let mut committed_sent = RangeSet::default();
let mut committed_recv = RangeSet::default();
let mut encoding_tree = None;
let mut hash_secrets = Vec::new();
for secret in secrets {
match secret {
TranscriptSecret::Encoding(tree) => {
committed_sent.union_mut(tree.idx(Direction::Sent));
committed_recv.union_mut(tree.idx(Direction::Received));
encoding_tree = Some(tree);
}
TranscriptSecret::Hash(hash) => {
match hash.direction {
Direction::Sent => committed_sent.union_mut(&hash.idx),
@@ -259,6 +310,7 @@ impl<'a> TranscriptProofBuilder<'a> {
Self {
commitment_kinds: DEFAULT_COMMITMENT_KINDS.to_vec(),
transcript,
encoding_tree,
hash_secrets,
committed_sent,
committed_recv,
@@ -314,7 +366,7 @@ impl<'a> TranscriptProofBuilder<'a> {
if idx.is_subset(committed) {
self.query_idx.union(&direction, &idx);
} else {
let missing = idx.difference(committed).into_set();
let missing = idx.difference(committed);
return Err(TranscriptProofBuilderError::new(
BuilderErrorKind::MissingCommitment,
format!(
@@ -356,6 +408,7 @@ impl<'a> TranscriptProofBuilder<'a> {
transcript: self
.transcript
.to_partial(self.query_idx.sent.clone(), self.query_idx.recv.clone()),
encoding_proof: None,
hash_secrets: Vec::new(),
};
let mut uncovered_query_idx = self.query_idx.clone();
@@ -367,6 +420,46 @@ impl<'a> TranscriptProofBuilder<'a> {
// self.commitment_kinds.
if let Some(kind) = commitment_kinds_iter.next() {
match kind {
TranscriptCommitmentKind::Encoding => {
let Some(encoding_tree) = self.encoding_tree else {
// Proceeds to the next preferred commitment kind if encoding tree is
// not available.
continue;
};
let (sent_dir_idxs, sent_uncovered) = uncovered_query_idx.sent.cover_by(
encoding_tree
.transcript_indices()
.filter(|(dir, _)| *dir == Direction::Sent),
|(_, idx)| idx,
);
// Uncovered ranges will be checked with ranges of the next
// preferred commitment kind.
uncovered_query_idx.sent = sent_uncovered;
let (recv_dir_idxs, recv_uncovered) = uncovered_query_idx.recv.cover_by(
encoding_tree
.transcript_indices()
.filter(|(dir, _)| *dir == Direction::Received),
|(_, idx)| idx,
);
uncovered_query_idx.recv = recv_uncovered;
let dir_idxs = sent_dir_idxs
.into_iter()
.chain(recv_dir_idxs)
.collect::<Vec<_>>();
// Skip proof generation if there are no committed ranges that can cover the
// query ranges.
if !dir_idxs.is_empty() {
transcript_proof.encoding_proof = Some(
encoding_tree
.proof(dir_idxs.into_iter())
.expect("subsequences were checked to be in tree"),
);
}
}
TranscriptCommitmentKind::Hash { alg } => {
let (sent_hashes, sent_uncovered) = uncovered_query_idx.sent.cover_by(
self.hash_secrets.iter().filter(|hash| {
@@ -489,14 +582,50 @@ impl fmt::Display for TranscriptProofBuilderError {
#[cfg(test)]
mod tests {
use rand::{Rng, SeedableRng};
use rangeset::prelude::*;
use rangeset::RangeSet;
use rstest::rstest;
use tlsn_data_fixtures::http::{request::GET_WITH_HEADER, response::OK_JSON};
use crate::hash::{Blinder, HashAlgId};
use crate::{
fixtures::{encoder_secret, encoding_provider},
hash::{Blake3, Blinder, HashAlgId},
transcript::TranscriptCommitConfigBuilder,
};
use super::*;
#[rstest]
fn test_verify_missing_encoding_commitment_root() {
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
let idxs = vec![(Direction::Received, RangeSet::from(0..transcript.len().1))];
let encoding_tree = EncodingTree::new(
&Blake3::default(),
&idxs,
&encoding_provider(transcript.sent(), transcript.received()),
)
.unwrap();
let secrets = vec![TranscriptSecret::Encoding(encoding_tree)];
let mut builder = TranscriptProofBuilder::new(&transcript, &secrets);
builder.reveal_recv(&(0..transcript.len().1)).unwrap();
let transcript_proof = builder.build().unwrap();
let provider = HashProvider::default();
let err = transcript_proof
.verify_with_provider(
&provider,
&transcript.length(),
Some(&encoder_secret()),
&[],
)
.err()
.unwrap();
assert!(matches!(err.kind, ErrorKind::Encoding));
}
#[rstest]
fn test_reveal_range_out_of_bounds() {
let transcript = Transcript::new(
@@ -516,7 +645,7 @@ mod tests {
}
#[rstest]
fn test_reveal_missing_commitment() {
fn test_reveal_missing_encoding_tree() {
let transcript = Transcript::new(
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
@@ -565,6 +694,7 @@ mod tests {
.verify_with_provider(
&provider,
&transcript.length(),
None,
&[TranscriptCommitment::Hash(commitment)],
)
.unwrap();
@@ -614,6 +744,7 @@ mod tests {
.verify_with_provider(
&provider,
&transcript.length(),
None,
&[TranscriptCommitment::Hash(commitment)],
)
.unwrap_err();
@@ -629,19 +760,24 @@ mod tests {
TranscriptCommitmentKind::Hash {
alg: HashAlgId::SHA256,
},
TranscriptCommitmentKind::Encoding,
TranscriptCommitmentKind::Hash {
alg: HashAlgId::SHA256,
},
TranscriptCommitmentKind::Hash {
alg: HashAlgId::SHA256,
},
TranscriptCommitmentKind::Encoding,
]);
assert_eq!(
builder.commitment_kinds,
vec![TranscriptCommitmentKind::Hash {
alg: HashAlgId::SHA256
},]
vec![
TranscriptCommitmentKind::Hash {
alg: HashAlgId::SHA256
},
TranscriptCommitmentKind::Encoding
]
);
}
@@ -651,7 +787,7 @@ mod tests {
RangeSet::from([0..10, 12..30]),
true,
)]
#[case::reveal_all_rangesets_with_single_superset_range(
#[case::reveal_all_rangesets_with_superset_ranges(
vec![RangeSet::from([0..1]), RangeSet::from([1..2, 8..9]), RangeSet::from([2..4, 6..8]), RangeSet::from([2..3, 6..7]), RangeSet::from([9..12])],
RangeSet::from([0..4, 6..9]),
true,
@@ -682,30 +818,29 @@ mod tests {
false,
)]
#[allow(clippy::single_range_in_vec_init)]
fn test_reveal_multiple_rangesets_with_one_rangeset(
fn test_reveal_mutliple_rangesets_with_one_rangeset(
#[case] commit_recv_rangesets: Vec<RangeSet<usize>>,
#[case] reveal_recv_rangeset: RangeSet<usize>,
#[case] success: bool,
) {
use rand::{Rng, SeedableRng};
let mut rng = rand::rngs::StdRng::seed_from_u64(0);
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
// Create hash commitments for each rangeset
let mut secrets = Vec::new();
// Encoding commitment kind
let mut transcript_commitment_builder = TranscriptCommitConfigBuilder::new(&transcript);
for rangeset in commit_recv_rangesets.iter() {
let blinder: crate::hash::Blinder = rng.random();
let secret = PlaintextHashSecret {
direction: Direction::Received,
idx: rangeset.clone(),
alg: HashAlgId::BLAKE3,
blinder,
};
secrets.push(TranscriptSecret::Hash(secret));
transcript_commitment_builder.commit_recv(rangeset).unwrap();
}
let transcripts_commitment_config = transcript_commitment_builder.build().unwrap();
let encoding_tree = EncodingTree::new(
&Blake3::default(),
transcripts_commitment_config.iter_encoding(),
&encoding_provider(GET_WITH_HEADER, OK_JSON),
)
.unwrap();
let secrets = vec![TranscriptSecret::Encoding(encoding_tree)];
let mut builder = TranscriptProofBuilder::new(&transcript, &secrets);
if success {
@@ -758,34 +893,27 @@ mod tests {
#[case] uncovered_sent_rangeset: RangeSet<usize>,
#[case] uncovered_recv_rangeset: RangeSet<usize>,
) {
use rand::{Rng, SeedableRng};
let mut rng = rand::rngs::StdRng::seed_from_u64(0);
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
// Create hash commitments for each rangeset
let mut secrets = Vec::new();
// Encoding commitment kind
let mut transcript_commitment_builder = TranscriptCommitConfigBuilder::new(&transcript);
for rangeset in commit_sent_rangesets.iter() {
let blinder: crate::hash::Blinder = rng.random();
let secret = PlaintextHashSecret {
direction: Direction::Sent,
idx: rangeset.clone(),
alg: HashAlgId::BLAKE3,
blinder,
};
secrets.push(TranscriptSecret::Hash(secret));
transcript_commitment_builder.commit_sent(rangeset).unwrap();
}
for rangeset in commit_recv_rangesets.iter() {
let blinder: crate::hash::Blinder = rng.random();
let secret = PlaintextHashSecret {
direction: Direction::Received,
idx: rangeset.clone(),
alg: HashAlgId::BLAKE3,
blinder,
};
secrets.push(TranscriptSecret::Hash(secret));
transcript_commitment_builder.commit_recv(rangeset).unwrap();
}
let transcripts_commitment_config = transcript_commitment_builder.build().unwrap();
let encoding_tree = EncodingTree::new(
&Blake3::default(),
transcripts_commitment_config.iter_encoding(),
&encoding_provider(GET_WITH_HEADER, OK_JSON),
)
.unwrap();
let secrets = vec![TranscriptSecret::Encoding(encoding_tree)];
let mut builder = TranscriptProofBuilder::new(&transcript, &secrets);
builder.reveal_sent(&reveal_sent_rangeset).unwrap();
builder.reveal_recv(&reveal_recv_rangeset).unwrap();

View File

@@ -332,6 +332,7 @@ async fn notary<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
let (
VerifierOutput {
transcript_commitments,
encoder_secret,
..
},
verifier,
@@ -392,6 +393,10 @@ async fn notary<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
.server_ephemeral_key(tls_transcript.server_ephemeral_key().clone())
.transcript_commitments(transcript_commitments);
if let Some(encoder_secret) = encoder_secret {
builder.encoder_secret(encoder_secret);
}
let attestation = builder.build(&provider)?;
// Send attestation to prover.

View File

@@ -1,59 +1,51 @@
#### Default Representative Benchmarks ####
#
# This benchmark measures TLSNotary performance on three representative network scenarios.
# Each scenario is run multiple times to produce statistical metrics (median, std dev, etc.)
# rather than plots. Use this for quick performance checks and CI regression testing.
#
# Payload sizes:
# - upload-size: 1KB (typical HTTP request)
# - download-size: 2KB (typical HTTP response/API data)
#
# Network scenarios are chosen to represent real-world user conditions where
# TLSNotary is primarily bottlenecked by upload bandwidth.
#### Cable/DSL Home Internet ####
# Most common residential internet connection
# - Asymmetric: high download, limited upload (typical bottleneck)
# - Upload bandwidth: 20 Mbps (realistic cable/DSL upload speed)
# - Latency: 20ms (typical ISP latency)
#### Latency ####
[[group]]
name = "cable"
bandwidth = 20
protocol_latency = 20
upload-size = 1024
download-size = 2048
name = "latency"
bandwidth = 1000
[[bench]]
group = "cable"
#### Mobile 5G ####
# Modern mobile connection with good coverage
# - Upload bandwidth: 30 Mbps (typical 5G upload in good conditions)
# - Latency: 30ms (higher than wired due to mobile tower hops)
[[group]]
name = "mobile_5g"
bandwidth = 30
protocol_latency = 30
upload-size = 1024
download-size = 2048
group = "latency"
protocol_latency = 10
[[bench]]
group = "mobile_5g"
group = "latency"
protocol_latency = 25
#### Fiber Home Internet ####
# High-end residential connection (best case scenario)
# - Symmetric: equal upload/download bandwidth
# - Upload bandwidth: 100 Mbps (typical fiber upload)
# - Latency: 15ms (lower latency than cable)
[[bench]]
group = "latency"
protocol_latency = 50
[[bench]]
group = "latency"
protocol_latency = 100
[[bench]]
group = "latency"
protocol_latency = 200
#### Bandwidth ####
[[group]]
name = "fiber"
name = "bandwidth"
protocol_latency = 25
[[bench]]
group = "bandwidth"
bandwidth = 10
[[bench]]
group = "bandwidth"
bandwidth = 50
[[bench]]
group = "bandwidth"
bandwidth = 100
protocol_latency = 15
upload-size = 1024
download-size = 2048
[[bench]]
group = "fiber"
group = "bandwidth"
bandwidth = 250
[[bench]]
group = "bandwidth"
bandwidth = 1000

View File

@@ -1,52 +0,0 @@
#### Bandwidth Sweep Benchmark ####
#
# Measures how network bandwidth affects TLSNotary runtime.
# Keeps latency and payload sizes fixed while varying upload bandwidth.
#
# Fixed parameters:
# - Latency: 25ms (typical internet latency)
# - Upload: 1KB (typical request)
# - Download: 2KB (typical response)
#
# Variable: Bandwidth from 5 Mbps to 1000 Mbps
#
# Use this to plot "Bandwidth vs Runtime" and understand bandwidth sensitivity.
# Focus on upload bandwidth as TLSNotary is primarily upload-bottlenecked
[[group]]
name = "bandwidth_sweep"
protocol_latency = 25
upload-size = 1024
download-size = 2048
[[bench]]
group = "bandwidth_sweep"
bandwidth = 5
[[bench]]
group = "bandwidth_sweep"
bandwidth = 10
[[bench]]
group = "bandwidth_sweep"
bandwidth = 20
[[bench]]
group = "bandwidth_sweep"
bandwidth = 50
[[bench]]
group = "bandwidth_sweep"
bandwidth = 100
[[bench]]
group = "bandwidth_sweep"
bandwidth = 250
[[bench]]
group = "bandwidth_sweep"
bandwidth = 500
[[bench]]
group = "bandwidth_sweep"
bandwidth = 1000

View File

@@ -1,53 +0,0 @@
#### Download Size Sweep Benchmark ####
#
# Measures how download payload size affects TLSNotary runtime.
# Keeps network conditions fixed while varying the response size.
#
# Fixed parameters:
# - Bandwidth: 100 Mbps (typical good connection)
# - Latency: 25ms (typical internet latency)
# - Upload: 1KB (typical request size)
#
# Variable: Download size from 1KB to 100KB
#
# Use this to plot "Download Size vs Runtime" and understand how much data
# TLSNotary can efficiently notarize. Useful for determining optimal
# chunking strategies for large responses.
[[group]]
name = "download_sweep"
bandwidth = 100
protocol_latency = 25
upload-size = 1024
[[bench]]
group = "download_sweep"
download-size = 1024
[[bench]]
group = "download_sweep"
download-size = 2048
[[bench]]
group = "download_sweep"
download-size = 5120
[[bench]]
group = "download_sweep"
download-size = 10240
[[bench]]
group = "download_sweep"
download-size = 20480
[[bench]]
group = "download_sweep"
download-size = 30720
[[bench]]
group = "download_sweep"
download-size = 40960
[[bench]]
group = "download_sweep"
download-size = 51200

View File

@@ -1,47 +0,0 @@
#### Latency Sweep Benchmark ####
#
# Measures how network latency affects TLSNotary runtime.
# Keeps bandwidth and payload sizes fixed while varying protocol latency.
#
# Fixed parameters:
# - Bandwidth: 100 Mbps (typical good connection)
# - Upload: 1KB (typical request)
# - Download: 2KB (typical response)
#
# Variable: Protocol latency from 10ms to 200ms
#
# Use this to plot "Latency vs Runtime" and understand latency sensitivity.
[[group]]
name = "latency_sweep"
bandwidth = 100
upload-size = 1024
download-size = 2048
[[bench]]
group = "latency_sweep"
protocol_latency = 10
[[bench]]
group = "latency_sweep"
protocol_latency = 25
[[bench]]
group = "latency_sweep"
protocol_latency = 50
[[bench]]
group = "latency_sweep"
protocol_latency = 75
[[bench]]
group = "latency_sweep"
protocol_latency = 100
[[bench]]
group = "latency_sweep"
protocol_latency = 150
[[bench]]
group = "latency_sweep"
protocol_latency = 200

View File

@@ -9,7 +9,6 @@ pub const DEFAULT_UPLOAD_SIZE: usize = 1024;
pub const DEFAULT_DOWNLOAD_SIZE: usize = 4096;
pub const DEFAULT_DEFER_DECRYPTION: bool = true;
pub const DEFAULT_MEMORY_PROFILE: bool = false;
pub const DEFAULT_REVEAL_ALL: bool = false;
pub const WARM_UP_BENCH: Bench = Bench {
group: None,
@@ -21,7 +20,6 @@ pub const WARM_UP_BENCH: Bench = Bench {
download_size: 4096,
defer_decryption: true,
memory_profile: false,
reveal_all: true,
};
#[derive(Deserialize)]
@@ -81,8 +79,6 @@ pub struct BenchGroupItem {
pub defer_decryption: Option<bool>,
#[serde(rename = "memory-profile")]
pub memory_profile: Option<bool>,
#[serde(rename = "reveal-all")]
pub reveal_all: Option<bool>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -101,8 +97,6 @@ pub struct BenchItem {
pub defer_decryption: Option<bool>,
#[serde(rename = "memory-profile")]
pub memory_profile: Option<bool>,
#[serde(rename = "reveal-all")]
pub reveal_all: Option<bool>,
}
impl BenchItem {
@@ -138,10 +132,6 @@ impl BenchItem {
if self.memory_profile.is_none() {
self.memory_profile = group.memory_profile;
}
if self.reveal_all.is_none() {
self.reveal_all = group.reveal_all;
}
}
pub fn into_bench(&self) -> Bench {
@@ -155,7 +145,6 @@ impl BenchItem {
download_size: self.download_size.unwrap_or(DEFAULT_DOWNLOAD_SIZE),
defer_decryption: self.defer_decryption.unwrap_or(DEFAULT_DEFER_DECRYPTION),
memory_profile: self.memory_profile.unwrap_or(DEFAULT_MEMORY_PROFILE),
reveal_all: self.reveal_all.unwrap_or(DEFAULT_REVEAL_ALL),
}
}
}
@@ -175,8 +164,6 @@ pub struct Bench {
pub defer_decryption: bool,
#[serde(rename = "memory-profile")]
pub memory_profile: bool,
#[serde(rename = "reveal-all")]
pub reveal_all: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]

View File

@@ -22,10 +22,7 @@ pub enum CmdOutput {
GetTests(Vec<String>),
Test(TestOutput),
Bench(BenchOutput),
#[cfg(target_arch = "wasm32")]
Fail {
reason: Option<String>,
},
Fail { reason: Option<String> },
}
#[derive(Debug, Clone, Serialize, Deserialize)]

View File

@@ -98,27 +98,14 @@ pub async fn bench_prover(provider: &IoProvider, config: &Bench) -> Result<Prove
let mut builder = ProveConfig::builder(prover.transcript());
// When reveal_all is false (the default), we exclude 1 byte to avoid the
// reveal-all optimization and benchmark the realistic ZK authentication path.
let reveal_sent_range = if config.reveal_all {
0..sent_len
} else {
0..sent_len.saturating_sub(1)
};
let reveal_recv_range = if config.reveal_all {
0..recv_len
} else {
0..recv_len.saturating_sub(1)
};
builder
.server_identity()
.reveal_sent(&reveal_sent_range)?
.reveal_recv(&reveal_recv_range)?;
.reveal_sent(&(0..sent_len))?
.reveal_recv(&(0..recv_len))?;
let prove_config = builder.build()?;
let config = builder.build()?;
prover.prove(&prove_config).await?;
prover.prove(&config).await?;
prover.close().await?;
let time_total = time_start.elapsed().as_millis();

View File

@@ -22,7 +22,6 @@ clap = { workspace = true, features = ["derive", "env"] }
csv = { version = "1.3" }
duct = { version = "1" }
futures = { workspace = true }
indicatif = { version = "0.17" }
ipnet = { workspace = true }
serio = { workspace = true }
serde_json = { workspace = true }

View File

@@ -16,10 +16,6 @@ pub struct Cli {
/// Subnet to assign harness network interfaces.
#[arg(long, default_value = "10.250.0.0/24", env = "SUBNET")]
pub subnet: Ipv4Net,
/// Run browser in headed mode (visible window) for debugging.
/// Works with both X11 and Wayland.
#[arg(long)]
pub headed: bool,
}
#[derive(Subcommand)]
@@ -35,13 +31,10 @@ pub enum Command {
},
/// runs benchmarks.
Bench {
/// Configuration path. Defaults to bench.toml which contains
/// representative scenarios (cable, 5G, fiber) for quick performance
/// checks. Use bench_*_sweep.toml files for parametric
/// analysis.
/// Configuration path.
#[arg(short, long, default_value = "bench.toml")]
config: PathBuf,
/// Output CSV file path for detailed metrics and post-processing.
/// Output file path.
#[arg(short, long, default_value = "metrics.csv")]
output: PathBuf,
/// Number of samples to measure per benchmark. This is overridden by

View File

@@ -28,9 +28,6 @@ pub struct Executor {
ns: Namespace,
config: ExecutorConfig,
target: Target,
/// Display environment variables for headed mode (X11/Wayland).
/// Empty means headless mode.
display_env: Vec<String>,
state: State,
}
@@ -52,17 +49,11 @@ impl State {
}
impl Executor {
pub fn new(
ns: Namespace,
config: ExecutorConfig,
target: Target,
display_env: Vec<String>,
) -> Self {
pub fn new(ns: Namespace, config: ExecutorConfig, target: Target) -> Self {
Self {
ns,
config,
target,
display_env,
state: State::Init,
}
}
@@ -129,49 +120,23 @@ impl Executor {
let tmp = duct::cmd!("mktemp", "-d").read()?;
let tmp = tmp.trim();
let headed = !self.display_env.is_empty();
// Build command args based on headed/headless mode
let mut args: Vec<String> = vec![
"ip".into(),
"netns".into(),
"exec".into(),
self.ns.name().into(),
];
if headed {
// For headed mode: drop back to the current user and pass display env vars
// This allows the browser to connect to X11/Wayland while in the namespace
let user =
std::env::var("USER").context("USER environment variable not set")?;
args.extend(["sudo".into(), "-E".into(), "-u".into(), user, "env".into()]);
args.extend(self.display_env.clone());
}
args.push(chrome_path.to_string_lossy().into());
args.push(format!("--remote-debugging-port={PORT_BROWSER}"));
if headed {
// Headed mode: no headless, add flags to suppress first-run dialogs
args.extend(["--no-first-run".into(), "--no-default-browser-check".into()]);
} else {
// Headless mode: original flags
args.extend([
"--headless".into(),
"--disable-dev-shm-usage".into(),
"--disable-gpu".into(),
"--disable-cache".into(),
"--disable-application-cache".into(),
]);
}
args.extend([
"--no-sandbox".into(),
let process = duct::cmd!(
"sudo",
"ip",
"netns",
"exec",
self.ns.name(),
chrome_path,
format!("--remote-debugging-port={PORT_BROWSER}"),
"--headless",
"--disable-dev-shm-usage",
"--disable-gpu",
"--disable-cache",
"--disable-application-cache",
"--no-sandbox",
format!("--user-data-dir={tmp}"),
"--allowed-ips=10.250.0.1".into(),
]);
let process = duct::cmd("sudo", &args);
format!("--allowed-ips=10.250.0.1"),
);
let process = if !cfg!(feature = "debug") {
process.stderr_capture().stdout_capture().start()?

View File

@@ -9,7 +9,7 @@ mod ws_proxy;
#[cfg(feature = "debug")]
mod debug_prelude;
use std::{collections::HashMap, time::Duration};
use std::time::Duration;
use anyhow::Result;
use clap::Parser;
@@ -22,7 +22,6 @@ use harness_core::{
rpc::{BenchCmd, TestCmd},
test::TestStatus,
};
use indicatif::{ProgressBar, ProgressStyle};
use cli::{Cli, Command};
use executor::Executor;
@@ -33,60 +32,6 @@ use crate::debug_prelude::*;
use crate::{cli::Route, network::Network, wasm_server::WasmServer, ws_proxy::WsProxy};
/// Statistics for a benchmark configuration
#[derive(Debug, Clone)]
struct BenchStats {
group: Option<String>,
bandwidth: usize,
latency: usize,
upload_size: usize,
download_size: usize,
times: Vec<u64>,
}
impl BenchStats {
fn median(&self) -> f64 {
let mut sorted = self.times.clone();
sorted.sort();
let len = sorted.len();
if len == 0 {
return 0.0;
}
if len.is_multiple_of(2) {
(sorted[len / 2 - 1] + sorted[len / 2]) as f64 / 2.0
} else {
sorted[len / 2] as f64
}
}
}
/// Print summary table of benchmark results
fn print_bench_summary(stats: &[BenchStats]) {
if stats.is_empty() {
println!("\nNo benchmark results to display (only warmup was run).");
return;
}
println!("\n{}", "=".repeat(80));
println!("TLSNotary Benchmark Results");
println!("{}", "=".repeat(80));
println!();
for stat in stats {
let group_name = stat.group.as_deref().unwrap_or("unnamed");
println!(
"{} ({} Mbps, {}ms latency, {}KB↑ {}KB↓):",
group_name,
stat.bandwidth,
stat.latency,
stat.upload_size / 1024,
stat.download_size / 1024
);
println!(" Median: {:.2}s", stat.median() / 1000.0);
println!();
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum, Default)]
pub enum Target {
#[default]
@@ -105,46 +50,14 @@ struct Runner {
started: bool,
}
/// Collects display-related environment variables for headed browser mode.
/// Works with both X11 and Wayland by collecting whichever vars are present.
fn collect_display_env_vars() -> Vec<String> {
const DISPLAY_VARS: &[&str] = &[
"DISPLAY", // X11
"XAUTHORITY", // X11 auth
"WAYLAND_DISPLAY", // Wayland
"XDG_RUNTIME_DIR", // Wayland runtime dir
];
DISPLAY_VARS
.iter()
.filter_map(|&var| {
std::env::var(var)
.ok()
.map(|val| format!("{}={}", var, val))
})
.collect()
}
impl Runner {
fn new(cli: &Cli) -> Result<Self> {
let Cli {
target,
subnet,
headed,
..
} = cli;
let Cli { target, subnet, .. } = cli;
let current_path = std::env::current_exe().unwrap();
let fixture_path = current_path.parent().unwrap().join("server-fixture");
let network_config = NetworkConfig::new(*subnet);
let network = Network::new(network_config.clone())?;
// Collect display env vars once if headed mode is enabled
let display_env = if *headed {
collect_display_env_vars()
} else {
Vec::new()
};
let server_fixture =
ServerFixture::new(fixture_path, network.ns_app().clone(), network_config.app);
let wasm_server = WasmServer::new(
@@ -162,7 +75,6 @@ impl Runner {
.network_config(network_config.clone())
.build(),
*target,
display_env.clone(),
);
let exec_v = Executor::new(
network.ns_1().clone(),
@@ -172,7 +84,6 @@ impl Runner {
.network_config(network_config.clone())
.build(),
Target::Native,
Vec::new(), // Verifier doesn't need display env
);
Ok(Self {
@@ -207,12 +118,6 @@ pub async fn main() -> Result<()> {
tracing_subscriber::fmt::init();
let cli = Cli::parse();
// Validate --headed requires --target browser
if cli.headed && cli.target != Target::Browser {
anyhow::bail!("--headed can only be used with --target browser");
}
let mut runner = Runner::new(&cli)?;
let mut exit_code = 0;
@@ -301,12 +206,6 @@ pub async fn main() -> Result<()> {
samples_override,
skip_warmup,
} => {
// Print configuration info
println!("TLSNotary Benchmark Harness");
println!("Running benchmarks from: {}", config.display());
println!("Output will be written to: {}", output.display());
println!();
let items: BenchItems = toml::from_str(&std::fs::read_to_string(config)?)?;
let output_file = std::fs::File::create(output)?;
let mut writer = WriterBuilder::new().from_writer(output_file);
@@ -321,34 +220,7 @@ pub async fn main() -> Result<()> {
runner.exec_p.start().await?;
runner.exec_v.start().await?;
// Create progress bar
let pb = ProgressBar::new(benches.len() as u64);
pb.set_style(
ProgressStyle::default_bar()
.template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} {msg}")
.expect("valid template")
.progress_chars("█▓▒░ "),
);
// Collect measurements for stats
let mut measurements_by_config: HashMap<String, Vec<u64>> = HashMap::new();
let warmup_count = if skip_warmup { 0 } else { 3 };
for (idx, config) in benches.iter().enumerate() {
let is_warmup = idx < warmup_count;
let group_name = if is_warmup {
format!("Warmup {}/{}", idx + 1, warmup_count)
} else {
config.group.as_deref().unwrap_or("unnamed").to_string()
};
pb.set_message(format!(
"{} ({} Mbps, {}ms)",
group_name, config.bandwidth, config.protocol_latency
));
for config in benches {
runner
.network
.set_proto_config(config.bandwidth, config.protocol_latency.div_ceil(2))?;
@@ -377,73 +249,11 @@ pub async fn main() -> Result<()> {
panic!("expected prover output");
};
// Collect metrics for stats (skip warmup benches)
if !is_warmup {
let config_key = format!(
"{:?}|{}|{}|{}|{}",
config.group,
config.bandwidth,
config.protocol_latency,
config.upload_size,
config.download_size
);
measurements_by_config
.entry(config_key)
.or_default()
.push(metrics.time_total);
}
let measurement = Measurement::new(config.clone(), metrics);
let measurement = Measurement::new(config, metrics);
writer.serialize(measurement)?;
writer.flush()?;
pb.inc(1);
}
pb.finish_with_message("Benchmarks complete");
// Compute and print statistics
let mut all_stats: Vec<BenchStats> = Vec::new();
for (key, times) in measurements_by_config {
// Parse back the config from the key
let parts: Vec<&str> = key.split('|').collect();
if parts.len() >= 5 {
let group = if parts[0] == "None" {
None
} else {
Some(
parts[0]
.trim_start_matches("Some(\"")
.trim_end_matches("\")")
.to_string(),
)
};
let bandwidth: usize = parts[1].parse().unwrap_or(0);
let latency: usize = parts[2].parse().unwrap_or(0);
let upload_size: usize = parts[3].parse().unwrap_or(0);
let download_size: usize = parts[4].parse().unwrap_or(0);
all_stats.push(BenchStats {
group,
bandwidth,
latency,
upload_size,
download_size,
times,
});
}
}
// Sort stats by group name for consistent output
all_stats.sort_by(|a, b| {
a.group
.cmp(&b.group)
.then(a.latency.cmp(&b.latency))
.then(a.bandwidth.cmp(&b.bandwidth))
});
print_bench_summary(&all_stats);
}
Command::Serve {} => {
runner.start_services().await?;

View File

@@ -34,7 +34,6 @@ mpz-share-conversion = { workspace = true }
mpz-vm-core = { workspace = true }
mpz-memory-core = { workspace = true }
ludi = { git = "https://github.com/sinui0/ludi", rev = "e511c3b", default-features = false }
serio = { workspace = true }
async-trait = { workspace = true }
@@ -66,7 +65,6 @@ rand_chacha = { workspace = true }
rstest = { workspace = true }
tls-server-fixture = { workspace = true }
tlsn-tls-client = { workspace = true }
tlsn-tls-client-async = { workspace = true }
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }
tokio-util = { workspace = true, features = ["compat"] }
tracing-subscriber = { workspace = true }

View File

@@ -15,13 +15,6 @@ impl MpcTlsError {
Self(ErrorRepr::Peer(err.into()))
}
pub(crate) fn actor<E>(err: E) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
Self(ErrorRepr::Actor(err.into()))
}
pub(crate) fn state<E>(err: E) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
@@ -72,8 +65,6 @@ enum ErrorRepr {
Peer(Box<dyn std::error::Error + Send + Sync>),
#[error("I/O error: {0}")]
Io(std::io::Error),
#[error("actor error: {0}")]
Actor(Box<dyn std::error::Error + Send + Sync>),
#[error("state error: {0}")]
State(Box<dyn std::error::Error + Send + Sync>),
#[error("allocation error: {0}")]

View File

@@ -1,5 +1,3 @@
mod actor;
use crate::{
error::MpcTlsError,
msg::{
@@ -14,7 +12,6 @@ use async_trait::async_trait;
use hmac_sha256::{MpcPrf, PrfOutput};
use ke::KeyExchange;
use key_exchange::{self as ke, MpcKeyExchange};
use ludi::Context as LudiContext;
use mpz_common::{Context, Flush};
use mpz_core::{bitvec::BitVec, Block};
use mpz_memory_core::DecodeFutureTyped;
@@ -50,13 +47,9 @@ use tlsn_core::{
};
use tracing::{debug, instrument, trace, warn};
/// Controller for MPC-TLS leader.
pub type LeaderCtrl = actor::MpcTlsLeaderCtrl;
/// MPC-TLS leader.
#[derive(Debug)]
pub struct MpcTlsLeader {
self_handle: Option<LeaderCtrl>,
config: Config,
state: State,
@@ -114,7 +107,6 @@ impl MpcTlsLeader {
let is_decrypting = !config.defer_decryption;
Self {
self_handle: None,
config,
state: State::Init {
ctx,
@@ -378,18 +370,42 @@ impl MpcTlsLeader {
Ok(())
}
/// Defers decryption of any incoming messages.
/// Enables or disables the decryption of any incoming messages.
///
/// # Arguments
///
/// * `enable` - Whether to enable or disable decryption.
#[instrument(level = "debug", skip_all, err)]
pub async fn defer_decryption(&mut self) -> Result<(), MpcTlsError> {
self.is_decrypting = false;
self.notifier.clear();
pub fn enable_decryption(&mut self, enable: bool) -> Result<(), MpcTlsError> {
self.is_decrypting = enable;
if enable {
self.notifier.set();
} else {
self.notifier.clear();
}
Ok(())
}
/// Stops the actor.
pub fn stop(&mut self, ctx: &mut LudiContext<Self>) {
ctx.stop();
/// Returns if incoming messages are decrypted.
pub fn is_decrypting(&self) -> bool {
self.is_decrypting
}
/// Returns the context and transcript.
///
/// Should be called after a successful call to [`Backend::server_closed`].
pub fn finish(&mut self) -> Option<(Context, TlsTranscript)> {
match self.state.take() {
State::Closed {
ctx, transcript, ..
} => Some((ctx, transcript)),
state => {
self.state = state;
None
}
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -16,7 +16,7 @@ pub(crate) mod utils;
pub use config::{Config, ConfigBuilder, ConfigBuilderError};
pub use error::MpcTlsError;
pub use follower::MpcTlsFollower;
pub use leader::{LeaderCtrl, MpcTlsLeader};
pub use leader::MpcTlsLeader;
use std::{future::Future, pin::Pin, sync::Arc};

View File

@@ -1,160 +0,0 @@
use std::sync::Arc;
use futures::{AsyncReadExt, AsyncWriteExt};
use mpc_tls::{Config, MpcTlsFollower, MpcTlsLeader};
use mpz_common::context::test_mt_context;
use mpz_core::Block;
use mpz_ideal_vm::IdealVm;
use mpz_memory_core::correlated::Delta;
use mpz_ot::{
ideal::rcot::ideal_rcot,
rcot::shared::{SharedRCOTReceiver, SharedRCOTSender},
};
use rand::{rngs::StdRng, SeedableRng};
use rustls_pki_types::CertificateDer;
use tls_client::RootCertStore;
use tls_client_async::bind_client;
use tls_server_fixture::{bind_test_server_hyper, CA_CERT_DER, SERVER_DOMAIN};
use tokio::sync::Mutex;
use tokio_util::compat::TokioAsyncReadCompatExt;
use webpki::anchor_from_trusted_cert;
const CA_CERT: CertificateDer = CertificateDer::from_slice(CA_CERT_DER);
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn mpc_tls_test() {
tracing_subscriber::fmt::init();
let config = Config::builder()
.defer_decryption(false)
.max_sent(1 << 13)
.max_recv_online(1 << 13)
.max_recv(1 << 13)
.build()
.unwrap();
let (leader, follower) = build_pair(config);
tokio::try_join!(
tokio::spawn(leader_task(leader)),
tokio::spawn(follower_task(follower))
)
.unwrap();
}
async fn leader_task(mut leader: MpcTlsLeader) {
leader.alloc().unwrap();
leader.preprocess().await.unwrap();
let (leader_ctrl, leader_fut) = leader.run();
tokio::spawn(async { leader_fut.await.unwrap() });
let config = tls_client::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(RootCertStore {
roots: vec![anchor_from_trusted_cert(&CA_CERT).unwrap().to_owned()],
})
.with_no_client_auth();
let server_name = SERVER_DOMAIN.try_into().unwrap();
let client = tls_client::ClientConnection::new(
Arc::new(config),
Box::new(leader_ctrl.clone()),
server_name,
)
.unwrap();
let (client_socket, server_socket) = tokio::io::duplex(1 << 16);
tokio::spawn(bind_test_server_hyper(server_socket.compat()));
let (mut conn, conn_fut) = bind_client(client_socket.compat(), client);
let handle = tokio::spawn(async { conn_fut.await.unwrap() });
let msg = concat!(
"POST /echo HTTP/1.1\r\n",
"Host: test-server.io\r\n",
"Connection: keep-alive\r\n",
"Accept-Encoding: identity\r\n",
"Content-Length: 5\r\n",
"\r\n",
"hello",
"\r\n"
);
conn.write_all(msg.as_bytes()).await.unwrap();
let mut buf = vec![0u8; 48];
conn.read_exact(&mut buf).await.unwrap();
leader_ctrl.defer_decryption().await.unwrap();
let msg = concat!(
"POST /echo HTTP/1.1\r\n",
"Host: test-server.io\r\n",
"Connection: close\r\n",
"Accept-Encoding: identity\r\n",
"Content-Length: 5\r\n",
"\r\n",
"hello",
"\r\n"
);
conn.write_all(msg.as_bytes()).await.unwrap();
conn.close().await.unwrap();
let mut buf = vec![0u8; 1024];
conn.read_to_end(&mut buf).await.unwrap();
leader_ctrl.stop().await.unwrap();
handle.await.unwrap();
}
async fn follower_task(mut follower: MpcTlsFollower) {
follower.alloc().unwrap();
follower.preprocess().await.unwrap();
follower.run().await.unwrap();
}
fn build_pair(config: Config) -> (MpcTlsLeader, MpcTlsFollower) {
let mut rng = StdRng::seed_from_u64(0);
let (mut mt_a, mut mt_b) = test_mt_context(8);
let ctx_a = futures::executor::block_on(mt_a.new_context()).unwrap();
let ctx_b = futures::executor::block_on(mt_b.new_context()).unwrap();
let delta_a = Delta::new(Block::random(&mut rng));
let delta_b = Delta::new(Block::random(&mut rng));
let (rcot_send_a, rcot_recv_b) = ideal_rcot(Block::random(&mut rng), delta_a.into_inner());
let (rcot_send_b, rcot_recv_a) = ideal_rcot(Block::random(&mut rng), delta_b.into_inner());
let rcot_send_a = SharedRCOTSender::new(rcot_send_a);
let rcot_send_b = SharedRCOTSender::new(rcot_send_b);
let rcot_recv_a = SharedRCOTReceiver::new(rcot_recv_a);
let rcot_recv_b = SharedRCOTReceiver::new(rcot_recv_b);
let mpc_a = Arc::new(Mutex::new(IdealVm::new()));
let mpc_b = Arc::new(Mutex::new(IdealVm::new()));
let leader = MpcTlsLeader::new(
config.clone(),
ctx_a,
mpc_a,
(rcot_send_a.clone(), rcot_send_a.clone(), rcot_send_a),
rcot_recv_a,
);
let follower = MpcTlsFollower::new(
config,
ctx_b,
mpc_b,
rcot_send_b,
(rcot_recv_b.clone(), rcot_recv_b.clone(), rcot_recv_b),
);
(leader, follower)
}

View File

@@ -1,39 +0,0 @@
[package]
name = "tlsn-tls-client-async"
authors = ["TLSNotary Team"]
description = "An async TLS client for TLSNotary"
keywords = ["tls", "mpc", "2pc", "client", "async"]
categories = ["cryptography"]
license = "MIT OR Apache-2.0"
version = "0.1.0-alpha.14-pre"
edition = "2021"
[lints]
workspace = true
[lib]
name = "tls_client_async"
[features]
default = ["tracing"]
tracing = ["dep:tracing"]
[dependencies]
tlsn-tls-client = { workspace = true }
bytes = { workspace = true }
futures = { workspace = true }
thiserror = { workspace = true }
tokio-util = { workspace = true, features = ["io", "compat"] }
tracing = { workspace = true, optional = true }
[dev-dependencies]
tls-server-fixture = { workspace = true }
http-body-util = { workspace = true }
hyper = { workspace = true, features = ["client", "http1"] }
hyper-util = { workspace = true, features = ["full"] }
rstest = { workspace = true }
tokio = { workspace = true, features = ["rt", "rt-multi-thread", "macros"] }
rustls-webpki = { workspace = true }
rustls-pki-types = { workspace = true }

View File

@@ -1,89 +0,0 @@
use bytes::Bytes;
use futures::{
channel::mpsc::{Receiver, SendError, Sender},
sink::SinkMapErr,
AsyncRead, AsyncWrite, SinkExt,
};
use std::{
io::{Error as IoError, ErrorKind as IoErrorKind},
pin::Pin,
task::{Context, Poll},
};
use tokio_util::{
compat::{Compat, TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt},
io::{CopyToBytes, SinkWriter, StreamReader},
};
type CompatSinkWriter =
Compat<SinkWriter<CopyToBytes<SinkMapErr<Sender<Bytes>, fn(SendError) -> IoError>>>>;
/// A TLS connection to a server.
///
/// This type implements `AsyncRead` and `AsyncWrite` and can be used to
/// communicate with a server using TLS.
///
/// # Note
///
/// This connection is closed on a best-effort basis if this is dropped. To
/// ensure a clean close, you should call
/// [`AsyncWriteExt::close`](futures::io::AsyncWriteExt::close) to close the
/// connection.
#[derive(Debug)]
pub struct TlsConnection {
/// The data to be transmitted to the server is sent to this sink.
tx_sender: CompatSinkWriter,
/// The data to be received from the server is received from this stream.
rx_receiver: Compat<StreamReader<Receiver<Result<Bytes, IoError>>, Bytes>>,
}
impl TlsConnection {
/// Creates a new TLS connection.
pub(crate) fn new(
tx_sender: Sender<Bytes>,
rx_receiver: Receiver<Result<Bytes, IoError>>,
) -> Self {
fn convert_error(err: SendError) -> IoError {
if err.is_disconnected() {
IoErrorKind::BrokenPipe.into()
} else {
IoErrorKind::WouldBlock.into()
}
}
Self {
tx_sender: SinkWriter::new(CopyToBytes::new(
tx_sender.sink_map_err(convert_error as fn(SendError) -> IoError),
))
.compat_write(),
rx_receiver: StreamReader::new(rx_receiver).compat(),
}
}
}
impl AsyncRead for TlsConnection {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<Result<usize, IoError>> {
Pin::new(&mut self.rx_receiver).poll_read(cx, buf)
}
}
impl AsyncWrite for TlsConnection {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, IoError>> {
Pin::new(&mut self.tx_sender).poll_write(cx, buf)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
Pin::new(&mut self.tx_sender).poll_flush(cx)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
Pin::new(&mut self.tx_sender).poll_close(cx)
}
}

View File

@@ -1,269 +0,0 @@
//! Provides a TLS client which exposes an async socket.
//!
//! This library provides the [bind_client] function which attaches a TLS client
//! to a socket connection and then exposes a [TlsConnection] object, which
//! provides an async socket API for reading and writing cleartext. The TLS
//! client will then automatically encrypt and decrypt traffic and forward that
//! to the provided socket.
#![deny(missing_docs, unreachable_pub, unused_must_use)]
#![deny(clippy::all)]
#![forbid(unsafe_code)]
mod conn;
use bytes::{Buf, Bytes};
use futures::{
channel::mpsc, future::Fuse, select_biased, stream::Next, AsyncRead, AsyncReadExt, AsyncWrite,
AsyncWriteExt, Future, FutureExt, SinkExt, StreamExt,
};
use std::{
pin::Pin,
task::{Context, Poll},
};
#[cfg(feature = "tracing")]
use tracing::{debug, debug_span, trace, warn, Instrument};
use tls_client::ClientConnection;
pub use conn::TlsConnection;
const RX_TLS_BUF_SIZE: usize = 1 << 13; // 8 KiB
const RX_BUF_SIZE: usize = 1 << 13; // 8 KiB
/// An error that can occur during a TLS connection.
#[allow(missing_docs)]
#[derive(Debug, thiserror::Error)]
pub enum ConnectionError {
#[error(transparent)]
TlsError(#[from] tls_client::Error),
#[error(transparent)]
IOError(#[from] std::io::Error),
}
/// Closed connection data.
#[derive(Debug)]
pub struct ClosedConnection {
/// The connection for the client
pub client: ClientConnection,
/// Sent plaintext bytes
pub sent: Vec<u8>,
/// Received plaintext bytes
pub recv: Vec<u8>,
}
/// A future which runs the TLS connection to completion.
///
/// This future must be polled in order for the connection to make progress.
#[must_use = "futures do nothing unless polled"]
pub struct ConnectionFuture {
fut: Pin<Box<dyn Future<Output = Result<ClosedConnection, ConnectionError>> + Send>>,
}
impl Future for ConnectionFuture {
type Output = Result<ClosedConnection, ConnectionError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.fut.poll_unpin(cx)
}
}
/// Binds a client connection to the provided socket.
///
/// Returns a connection handle and a future which runs the connection to
/// completion.
///
/// # Errors
///
/// Any connection errors that occur will be returned from the future, not
/// [`TlsConnection`].
pub fn bind_client<T: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
socket: T,
mut client: ClientConnection,
) -> (TlsConnection, ConnectionFuture) {
let (tx_sender, mut tx_receiver) = mpsc::channel(1 << 14);
let (mut rx_sender, rx_receiver) = mpsc::channel(1 << 14);
let conn = TlsConnection::new(tx_sender, rx_receiver);
let fut = async move {
client.start().await?;
let mut notify = client.get_notify().await?;
let (mut server_rx, mut server_tx) = socket.split();
let mut rx_tls_buf = [0u8; RX_TLS_BUF_SIZE];
let mut rx_buf = [0u8; RX_BUF_SIZE];
let mut handshake_done = false;
let mut client_closed = false;
let mut server_closed = false;
let mut sent = Vec::with_capacity(1024);
let mut recv = Vec::with_capacity(1024);
let mut rx_tls_fut = server_rx.read(&mut rx_tls_buf).fuse();
// We don't start writing application data until the handshake is complete.
let mut tx_recv_fut: Fuse<Next<'_, mpsc::Receiver<Bytes>>> = Fuse::terminated();
// Runs both the tx and rx halves of the connection to completion.
// This loop does not terminate until the *SERVER* closes the connection and
// we've processed all received data. If an error occurs, the `TlsConnection`
// channels will be closed and the error will be returned from this future.
'conn: loop {
// Write all pending TLS data to the server.
if client.wants_write() && !client_closed {
#[cfg(feature = "tracing")]
trace!("client wants to write");
while client.wants_write() {
let _sent = client.write_tls_async(&mut server_tx).await?;
#[cfg(feature = "tracing")]
trace!("sent {} tls bytes to server", _sent);
}
server_tx.flush().await?;
}
// Forward received plaintext to `TlsConnection`.
while !client.plaintext_is_empty() {
let read = client.read_plaintext(&mut rx_buf)?;
recv.extend(&rx_buf[..read]);
// Ignore if the receiver has hung up.
_ = rx_sender
.send(Ok(Bytes::copy_from_slice(&rx_buf[..read])))
.await;
#[cfg(feature = "tracing")]
trace!("forwarded {} plaintext bytes to conn", read);
}
if !client.is_handshaking() && !handshake_done {
#[cfg(feature = "tracing")]
debug!("handshake complete");
handshake_done = true;
// Start reading application data that needs to be transmitted from the
// `TlsConnection`.
tx_recv_fut = tx_receiver.next().fuse();
}
if server_closed && client.plaintext_is_empty() && client.is_empty().await? {
break 'conn;
}
select_biased! {
// Reads TLS data from the server and writes it into the client.
received = &mut rx_tls_fut => {
let received = received?;
#[cfg(feature = "tracing")]
trace!("received {} tls bytes from server", received);
// Loop until we've processed all the data we received in this read.
// Note that we must make one iteration even if `received == 0`.
let mut processed = 0;
let mut reader = rx_tls_buf[..received].reader();
loop {
processed += client.read_tls(&mut reader)?;
client.process_new_packets().await?;
debug_assert!(processed <= received);
if processed >= received {
break;
}
}
#[cfg(feature = "tracing")]
trace!("processed {} tls bytes from server", processed);
// By convention if `AsyncRead::read` returns 0, it means EOF, i.e. the peer
// has closed the socket.
if received == 0 {
#[cfg(feature = "tracing")]
debug!("server closed connection");
server_closed = true;
client.server_closed().await?;
// Do not read from the socket again.
rx_tls_fut = Fuse::terminated();
} else {
// Reset the read future so next iteration we can read again.
rx_tls_fut = server_rx.read(&mut rx_tls_buf).fuse();
}
}
// If we receive None from `TlsConnection`, it has closed, so we
// send a close_notify to the server.
data = &mut tx_recv_fut => {
if let Some(data) = data {
#[cfg(feature = "tracing")]
trace!("writing {} plaintext bytes to client", data.len());
sent.extend(&data);
client
.write_all_plaintext(&data)
.await?;
tx_recv_fut = tx_receiver.next().fuse();
} else {
if !server_closed {
if let Err(e) = send_close_notify(&mut client, &mut server_tx).await {
#[cfg(feature = "tracing")]
warn!("failed to send close_notify to server: {}", e);
}
}
client_closed = true;
tx_recv_fut = Fuse::terminated();
}
}
// Waits for a notification from the backend that it is ready to decrypt data.
_ = &mut notify => {
#[cfg(feature = "tracing")]
trace!("backend is ready to decrypt");
client.process_new_packets().await?;
}
}
}
#[cfg(feature = "tracing")]
debug!("client shutdown");
_ = server_tx.close().await;
tx_receiver.close();
rx_sender.close_channel();
#[cfg(feature = "tracing")]
trace!(
"server close notify: {}, sent: {}, recv: {}",
client.received_close_notify(),
sent.len(),
recv.len()
);
Ok(ClosedConnection { client, sent, recv })
};
#[cfg(feature = "tracing")]
let fut = fut.instrument(debug_span!("tls_connection"));
let fut = ConnectionFuture { fut: Box::pin(fut) };
(conn, fut)
}
async fn send_close_notify(
client: &mut ClientConnection,
server_tx: &mut (impl AsyncWrite + Unpin),
) -> Result<(), ConnectionError> {
#[cfg(feature = "tracing")]
trace!("sending close_notify to server");
client.send_close_notify().await?;
client.process_new_packets().await?;
// Flush all remaining plaintext
while client.wants_write() {
client.write_tls_async(server_tx).await?;
}
server_tx.flush().await?;
Ok(())
}

View File

@@ -1,438 +0,0 @@
use std::{str, sync::Arc};
use core::future::Future;
use futures::{AsyncReadExt, AsyncWriteExt};
use http_body_util::{BodyExt as _, Full};
use hyper::{body::Bytes, Request, StatusCode};
use hyper_util::rt::TokioIo;
use rstest::{fixture, rstest};
use rustls_pki_types::CertificateDer;
use tls_client::{ClientConfig, ClientConnection, RustCryptoBackend, ServerName};
use tls_client_async::{bind_client, ClosedConnection, ConnectionError, TlsConnection};
use tls_server_fixture::{
bind_test_server, bind_test_server_hyper, APP_RECORD_LENGTH, CA_CERT_DER, CLOSE_DELAY,
SERVER_DOMAIN,
};
use tokio::task::JoinHandle;
use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
use webpki::anchor_from_trusted_cert;
const CA_CERT: CertificateDer = CertificateDer::from_slice(CA_CERT_DER);
// An established client TLS connection
struct TlsFixture {
client_tls_conn: TlsConnection,
// a handle that must be `.await`ed to get the result of a TLS connection
closed_tls_task: JoinHandle<Result<ClosedConnection, ConnectionError>>,
}
// Sets up a TLS connection between client and server and sends a hello message
#[fixture]
async fn set_up_tls() -> TlsFixture {
let (client_socket, server_socket) = tokio::io::duplex(1 << 16);
let _server_task = tokio::spawn(bind_test_server(server_socket.compat()));
let mut root_store = tls_client::RootCertStore::empty();
root_store
.roots
.push(anchor_from_trusted_cert(&CA_CERT).unwrap().to_owned());
let config = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_no_client_auth();
let client = ClientConnection::new(
Arc::new(config),
Box::new(RustCryptoBackend::new()),
ServerName::try_from(SERVER_DOMAIN).unwrap(),
)
.unwrap();
let (mut client_tls_conn, tls_fut) = bind_client(client_socket.compat(), client);
let closed_tls_task = tokio::spawn(tls_fut);
client_tls_conn
.write_all(&pad("expecting you to send back hello".to_string()))
.await
.unwrap();
// give the server some time to respond
std::thread::sleep(std::time::Duration::from_millis(10));
let mut plaintext = vec![0u8; 320];
let n = client_tls_conn.read(&mut plaintext).await.unwrap();
let s = str::from_utf8(&plaintext[0..n]).unwrap();
assert_eq!(s, "hello");
TlsFixture {
client_tls_conn,
closed_tls_task,
}
}
// Expect the async tls client wrapped in `hyper::client` to make a successful
// request and receive the expected response
#[tokio::test]
async fn test_hyper_ok() {
let (client_socket, server_socket) = tokio::io::duplex(1 << 16);
let server_task = tokio::spawn(bind_test_server_hyper(server_socket.compat()));
let mut root_store = tls_client::RootCertStore::empty();
root_store
.roots
.push(anchor_from_trusted_cert(&CA_CERT).unwrap().to_owned());
let config = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_no_client_auth();
let client = ClientConnection::new(
Arc::new(config),
Box::new(RustCryptoBackend::new()),
ServerName::try_from(SERVER_DOMAIN).unwrap(),
)
.unwrap();
let (conn, tls_fut) = bind_client(client_socket.compat(), client);
let closed_tls_task = tokio::spawn(tls_fut);
let (mut request_sender, connection) =
hyper::client::conn::http1::handshake(TokioIo::new(conn.compat()))
.await
.unwrap();
tokio::spawn(connection);
let request = Request::builder()
.uri(format!("https://{SERVER_DOMAIN}/echo"))
.header("Host", SERVER_DOMAIN)
.header("Connection", "close")
.method("POST")
.body(Full::<Bytes>::new("hello".into()))
.unwrap();
let response = request_sender.send_request(request).await.unwrap();
assert!(response.status() == StatusCode::OK);
// Process the response body
response.into_body().collect().await.unwrap().to_bytes();
let _ = server_task.await.unwrap();
let closed_conn = closed_tls_task.await.unwrap().unwrap();
assert!(closed_conn.client.received_close_notify());
}
// Expect a clean TLS connection closure when server responds to the client's
// close_notify but doesn't close the socket
#[rstest]
#[tokio::test]
async fn test_ok_server_no_socket_close(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
// instruct the server to send close_notify back to us after 10 ms
client_tls_conn
.write_all(&pad("send_close_notify".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// closing `client_tls_conn` will cause close_notify to be sent by the client;
client_tls_conn.close().await.unwrap();
let closed_conn = closed_tls_task.await.unwrap().unwrap();
assert!(closed_conn.client.received_close_notify());
}
// Expect a clean TLS connection closure when server responds to the client's
// close_notify AND also closes the socket
#[rstest]
#[tokio::test]
async fn test_ok_server_socket_close(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
// instruct the server to send close_notify back to us AND close the socket
// after 10 ms
client_tls_conn
.write_all(&pad("send_close_notify_and_close_socket".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// closing `client_tls_conn` will cause close_notify to be sent by the client;
client_tls_conn.close().await.unwrap();
let closed_conn = closed_tls_task.await.unwrap().unwrap();
assert!(closed_conn.client.received_close_notify());
}
// Expect a clean TLS connection closure when server sends close_notify first
// but doesn't close the socket
#[rstest]
#[tokio::test]
async fn test_ok_server_close_notify(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
// instruct the server to send close_notify back to us after 10 ms
client_tls_conn
.write_all(&pad("send_close_notify".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// give enough time for server's close_notify to arrive
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
client_tls_conn.close().await.unwrap();
let closed_conn = closed_tls_task.await.unwrap().unwrap();
assert!(closed_conn.client.received_close_notify());
}
// Expect a clean TLS connection closure when server sends close_notify first
// AND also closes the socket
#[rstest]
#[tokio::test]
async fn test_ok_server_close_notify_and_socket_close(
set_up_tls: impl Future<Output = TlsFixture>,
) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
// instruct the server to send close_notify back to us after 10 ms
client_tls_conn
.write_all(&pad("send_close_notify_and_close_socket".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// give enough time for server's close_notify to arrive
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
client_tls_conn.close().await.unwrap();
let closed_conn = closed_tls_task.await.unwrap().unwrap();
assert!(closed_conn.client.received_close_notify());
}
// Expect to be able to read the data after server closes the socket abruptly
#[rstest]
#[tokio::test]
async fn test_ok_read_after_close(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
..
} = set_up_tls.await;
// instruct the server to send us a hello message
client_tls_conn
.write_all(&pad("send a hello message".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// instruct the server to close the socket
client_tls_conn
.write_all(&pad("close_socket".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// give enough time to close the socket
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
// try to read some more data
let mut buf = vec![0u8; 10];
let n = client_tls_conn.read(&mut buf).await.unwrap();
assert_eq!(std::str::from_utf8(&buf[0..n]).unwrap(), "hello");
}
// Expect there to be no error when server DOES NOT send close_notify but just
// closes the socket
#[rstest]
#[tokio::test]
async fn test_ok_server_no_close_notify(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
// instruct the server to close the socket
client_tls_conn
.write_all(&pad("close_socket".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// give enough time to close the socket
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
client_tls_conn.close().await.unwrap();
let closed_conn = closed_tls_task.await.unwrap().unwrap();
assert!(!closed_conn.client.received_close_notify());
}
// Expect to register a delay when the server delays closing the socket
#[rstest]
#[tokio::test]
async fn test_ok_delay_close(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
client_tls_conn
.write_all(&pad("must_delay_when_closing".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// closing `client_tls_conn` will cause close_notify to be sent by the client
client_tls_conn.close().await.unwrap();
use std::time::Instant;
let now = Instant::now();
// this will resolve when the server stops delaying closing the socket
let closed_conn = closed_tls_task.await.unwrap().unwrap();
let elapsed = now.elapsed();
// the elapsed time must be roughly equal to the server's delay
// (give or take timing variations)
assert!(elapsed.as_millis() as u64 > CLOSE_DELAY - 50);
assert!(!closed_conn.client.received_close_notify());
}
// Expect client to error when server sends a corrupted message
#[rstest]
#[tokio::test]
async fn test_err_corrupted(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
// instruct the server to send a corrupted message
client_tls_conn
.write_all(&pad("send_corrupted_message".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
client_tls_conn.close().await.unwrap();
assert_eq!(
closed_tls_task.await.unwrap().err().unwrap().to_string(),
"received corrupt message"
);
}
// Expect client to error when server sends a TLS record with a bad MAC
#[rstest]
#[tokio::test]
async fn test_err_bad_mac(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
// instruct the server to send us a TLS record with a bad MAC
client_tls_conn
.write_all(&pad("send_record_with_bad_mac".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
client_tls_conn.close().await.unwrap();
assert_eq!(
closed_tls_task.await.unwrap().err().unwrap().to_string(),
"backend error: Decryption error: \"aead::Error\""
);
}
// Expect client to error when server sends a fatal alert
#[rstest]
#[tokio::test]
async fn test_err_alert(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
// instruct the server to send us a TLS record with a bad MAC
client_tls_conn
.write_all(&pad("send_alert".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
client_tls_conn.close().await.unwrap();
assert_eq!(
closed_tls_task.await.unwrap().err().unwrap().to_string(),
"received fatal alert: BadRecordMac"
);
}
// Expect an error when trying to write data to a connection which server closed
// abruptly
#[rstest]
#[tokio::test]
async fn test_err_write_after_close(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
..
} = set_up_tls.await;
// instruct the server to close the socket
client_tls_conn
.write_all(&pad("close_socket".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// give enough time to close the socket
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
// try to send some more data
let res = client_tls_conn
.write_all(&pad("more data".to_string()))
.await;
assert_eq!(res.err().unwrap().kind(), std::io::ErrorKind::BrokenPipe);
}
// Converts a string into a slice zero-padded to APP_RECORD_LENGTH
fn pad(s: String) -> Vec<u8> {
assert!(s.len() <= APP_RECORD_LENGTH);
let mut buf = vec![0u8; APP_RECORD_LENGTH];
buf[..s.len()].copy_from_slice(s.as_bytes());
buf
}

View File

@@ -227,6 +227,7 @@ impl ConnectionCommon {
/// Signals that the server has closed the connection.
pub async fn server_closed(&mut self) -> Result<(), Error> {
self.common_state.has_seen_eof = true;
self.common_state.backend.server_closed().await?;
Ok(())
}
@@ -457,6 +458,9 @@ impl ConnectionCommon {
return Err(Error::CorruptMessage);
}
// Process outgoing plaintext buffer and encrypt messages.
self.flush_plaintext().await?;
// Process new messages.
while let Some(msg) = self.message_deframer.frames.pop_front() {
// If we're not decrypting yet, we process it immediately. Otherwise it will be
@@ -508,25 +512,22 @@ impl ConnectionCommon {
Ok(state)
}
/// Write buffer into connection.
pub async fn write_plaintext(&mut self, buf: &[u8]) -> Result<usize, Error> {
if let Ok(st) = &mut self.state {
st.perhaps_write_key_update(&mut self.common_state).await;
/// Writes plaintext `buf` into an internal buffer. May not fully process the
/// whole buffer and returns the processed length.
pub fn write_plaintext(&mut self, buf: &[u8]) -> Result<usize, Error> {
if buf.is_empty() {
// Don't send empty fragments.
return Ok(0);
}
self.common_state.send_some_plaintext(buf).await
let len = self.sendable_plaintext.append_limited_copy(buf);
Ok(len)
}
/// Write entire buffer into connection.
pub async fn write_all_plaintext(&mut self, buf: &[u8]) -> Result<usize, Error> {
let mut pos = 0;
while pos < buf.len() {
pos += self.write_plaintext(&buf[pos..]).await?;
}
self.backend.flush().await?;
while let Some(msg) = self.backend.next_outgoing().await? {
self.queue_tls_message(msg);
}
Ok(pos)
/// Writes the entire plaintext `buf` into an internal buffer.
pub fn write_all_plaintext(&mut self, buf: &[u8]) -> Result<(), Error> {
self.sendable_plaintext.append(buf.to_vec());
Ok(())
}
/// Read TLS content from `rd`. This method does internal
@@ -690,6 +691,11 @@ impl CommonState {
self.received_plaintext.is_empty()
}
/// Returns true if the buffer for sendable plaintext is full.
pub fn sendable_plaintext_is_full(&self) -> bool {
self.sendable_plaintext.is_full()
}
/// Returns true if the connection is currently performing the TLS
/// handshake.
///
@@ -782,15 +788,6 @@ impl CommonState {
}
}
/// Send plaintext application data, fragmenting and
/// encrypting it as it goes out.
///
/// If internal buffers are too small, this function will not accept
/// all the data.
pub(crate) async fn send_some_plaintext(&mut self, data: &[u8]) -> Result<usize, Error> {
self.send_plain(data, Limit::Yes).await
}
// Changing the keys must not span any fragmented handshake
// messages. Otherwise the defragmented messages will have
// been protected with two different record layer protections,
@@ -931,32 +928,6 @@ impl CommonState {
self.sendable_tls.write_to_async(wr).await
}
/// Encrypt and send some plaintext `data`. `limit` controls
/// whether the per-connection buffer limits apply.
///
/// Returns the number of bytes written from `data`: this might
/// be less than `data.len()` if buffer limits were exceeded.
async fn send_plain(&mut self, data: &[u8], limit: Limit) -> Result<usize, Error> {
if !self.may_send_application_data {
// If we haven't completed handshaking, buffer
// plaintext to send once we do.
let len = match limit {
Limit::Yes => self.sendable_plaintext.append_limited_copy(data),
Limit::No => self.sendable_plaintext.append(data.to_vec()),
};
return Ok(len);
}
debug_assert!(self.record_layer.is_encrypting());
if data.is_empty() {
// Don't send empty fragments.
return Ok(0);
}
self.send_appdata_encrypt(data, limit).await
}
pub(crate) async fn start_outgoing_traffic(&mut self) -> Result<(), Error> {
self.may_send_application_data = true;
self.flush_plaintext().await
@@ -1012,15 +983,14 @@ impl CommonState {
self.sendable_tls.set_limit(limit);
}
/// Send any buffered plaintext. Plaintext is buffered if
/// written during handshake.
async fn flush_plaintext(&mut self) -> Result<(), Error> {
/// Send and encrypt any buffered plaintext. Does nothing during handshake.
pub async fn flush_plaintext(&mut self) -> Result<(), Error> {
if !self.may_send_application_data {
return Ok(());
}
while let Some(buf) = self.sendable_plaintext.pop() {
self.send_plain(&buf, Limit::No).await?;
self.send_appdata_encrypt(&buf, Limit::No).await?;
}
Ok(())

View File

@@ -35,6 +35,15 @@ impl ChunkVecBuffer {
self.chunks.is_empty()
}
/// If the buffer has reached limit.
pub(crate) fn is_full(&self) -> bool {
if let Some(limit) = self.limit {
self.len() >= limit
} else {
false
}
}
/// How many bytes we're storing
pub(crate) fn len(&self) -> usize {
let mut len = 0;

View File

@@ -247,7 +247,8 @@ async fn servered_client_data_sent() {
let (mut client, mut server) =
make_pair_for_arc_configs(&Arc::new(client_config), &server_config).await;
assert_eq!(5, client.write_plaintext(b"hello").await.unwrap());
assert_eq!(5, client.write_plaintext(b"hello").unwrap());
client.flush_plaintext().await.unwrap();
do_handshake(&mut client, &mut server).await;
send(&mut client, &mut server);
@@ -286,7 +287,7 @@ async fn servered_both_data_sent() {
make_pair_for_arc_configs(&Arc::new(client_config), &server_config).await;
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
do_handshake(&mut client, &mut server).await;
@@ -432,7 +433,7 @@ async fn server_close_notify() {
// check that alerts don't overtake appdata
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
server.send_close_notify();
receive(&mut server, &mut client);
@@ -460,7 +461,8 @@ async fn client_close_notify() {
// check that alerts don't overtake appdata
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
client.flush_plaintext().await.unwrap();
client.send_close_notify().await.unwrap();
send(&mut client, &mut server);
@@ -487,7 +489,7 @@ async fn server_closes_uncleanly() {
// check that unclean EOF reporting does not overtake appdata
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
receive(&mut server, &mut client);
transfer_eof(&mut client);
@@ -518,7 +520,7 @@ async fn client_closes_uncleanly() {
// check that unclean EOF reporting does not overtake appdata
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
client.process_new_packets().await.unwrap();
send(&mut client, &mut server);
@@ -900,20 +902,9 @@ async fn client_respects_buffer_limit_pre_handshake() {
client.set_buffer_limit(Some(32));
assert_eq!(
client
.write_plaintext(b"01234567890123456789")
.await
.unwrap(),
20
);
assert_eq!(
client
.write_plaintext(b"01234567890123456789")
.await
.unwrap(),
12
);
assert_eq!(client.write_plaintext(b"01234567890123456789").unwrap(), 20);
assert_eq!(client.write_plaintext(b"01234567890123456789").unwrap(), 12);
client.flush_plaintext().await.unwrap();
do_handshake(&mut client, &mut server).await;
send(&mut client, &mut server);
@@ -953,20 +944,9 @@ async fn client_respects_buffer_limit_post_handshake() {
do_handshake(&mut client, &mut server).await;
client.set_buffer_limit(Some(48));
assert_eq!(
client
.write_plaintext(b"01234567890123456789")
.await
.unwrap(),
20
);
assert_eq!(
client
.write_plaintext(b"01234567890123456789")
.await
.unwrap(),
6
);
assert_eq!(client.write_plaintext(b"01234567890123456789").unwrap(), 20);
assert_eq!(client.write_plaintext(b"01234567890123456789").unwrap(), 6);
client.flush_plaintext().await.unwrap();
send(&mut client, &mut server);
server.process_new_packets().unwrap();
@@ -1211,14 +1191,8 @@ async fn client_complete_io_for_write() {
do_handshake(&mut client, &mut server).await;
client
.write_plaintext(b"01234567890123456789")
.await
.unwrap();
client
.write_plaintext(b"01234567890123456789")
.await
.unwrap();
client.write_plaintext(b"01234567890123456789").unwrap();
client.write_plaintext(b"01234567890123456789").unwrap();
{
let mut pipe = ServerSession::new(&mut server);
let (rdlen, wrlen) = client
@@ -1350,7 +1324,8 @@ async fn server_stream_read() {
for kt in ALL_KEY_TYPES.iter() {
let (mut client, mut server) = make_pair(*kt).await;
client.write_all_plaintext(b"world").await.unwrap();
client.write_all_plaintext(b"world").unwrap();
client.process_new_packets().await.unwrap();
{
let mut pipe = ClientSession::new(&mut client);
@@ -1366,7 +1341,8 @@ async fn server_streamowned_read() {
for kt in ALL_KEY_TYPES.iter() {
let (mut client, server) = make_pair(*kt).await;
client.write_all_plaintext(b"world").await.unwrap();
client.write_all_plaintext(b"world").unwrap();
client.process_new_packets().await.unwrap();
{
let pipe = ClientSession::new(&mut client);
@@ -1385,7 +1361,9 @@ async fn server_streamowned_read() {
// errkind: io::ErrorKind::ConnectionAborted,
// after: 0,
// };
// client.write_all_plaintext(b"hello").await.unwrap();
// client.write_all_plaintext(b"hello").unwrap();
// client.process_new_packets().await.unwrap();
//
// let mut client_stream = Stream::new(&mut client, &mut pipe);
// let rc = client_stream.write(b"world");
// assert!(rc.is_err());
@@ -1402,7 +1380,9 @@ async fn server_streamowned_read() {
// errkind: io::ErrorKind::ConnectionAborted,
// after: 1,
// };
// client.write_all_plaintext(b"hello").await.unwrap();
// client.write_all_plaintext(b"hello").unwrap();
// client.process_new_packets().await.unwrap();
//
// let mut client_stream = Stream::new(&mut client, &mut pipe);
// let rc = client_stream.write(b"world");
// assert_eq!(format!("{:?}", rc), "Ok(5)");
@@ -1900,14 +1880,9 @@ async fn servered_write_for_client_appdata() {
let (mut client, mut server) = make_pair(KeyType::Rsa).await;
do_handshake(&mut client, &mut server).await;
client
.write_all_plaintext(b"01234567890123456789")
.await
.unwrap();
client
.write_all_plaintext(b"01234567890123456789")
.await
.unwrap();
client.write_all_plaintext(b"01234567890123456789").unwrap();
client.write_all_plaintext(b"01234567890123456789").unwrap();
client.process_new_packets().await.unwrap();
{
let mut pipe = ServerSession::new(&mut server);
let wrlen = client.write_tls(&mut pipe).unwrap();
@@ -2019,11 +1994,10 @@ async fn servered_write_for_server_handshake_no_half_rtt_by_default() {
async fn servered_write_for_client_handshake() {
let (mut client, mut server) = make_pair(KeyType::Rsa).await;
client
.write_all_plaintext(b"01234567890123456789")
.await
.unwrap();
client.write_all_plaintext(b"0123456789").await.unwrap();
client.write_all_plaintext(b"01234567890123456789").unwrap();
client.write_all_plaintext(b"0123456789").unwrap();
client.process_new_packets().await.unwrap();
{
let mut pipe = ServerSession::new(&mut server);
let wrlen = client.write_tls(&mut pipe).unwrap();

View File

@@ -21,11 +21,11 @@ tlsn-attestation = { workspace = true }
tlsn-core = { workspace = true }
tlsn-deap = { workspace = true }
tlsn-tls-client = { workspace = true }
tlsn-tls-client-async = { workspace = true }
tlsn-tls-core = { workspace = true }
tlsn-mpc-tls = { workspace = true }
tlsn-cipher = { workspace = true }
futures-plex = { workspace = true }
serio = { workspace = true, features = ["compat"] }
uid-mux = { workspace = true, features = ["serio"] }
web-spawn = { workspace = true, optional = true }

View File

@@ -22,6 +22,9 @@ use std::sync::LazyLock;
use semver::Version;
// Size for internal buffers.
const BUF_CAP: usize = 8 * 1024;
// Package version.
pub(crate) static VERSION: LazyLock<Version> = LazyLock::new(|| {
Version::parse(env!("CARGO_PKG_VERSION")).expect("cargo pkg version should be a valid semver")

View File

@@ -1,7 +1,7 @@
use std::ops::Range;
use mpz_memory_core::{Vector, binary::U8};
use rangeset::set::RangeSet;
use rangeset::RangeSet;
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct RangeMap<T> {
@@ -21,6 +21,20 @@ impl<T> RangeMap<T>
where
T: Item,
{
pub(crate) fn new(map: Vec<(usize, T)>) -> Self {
let mut pos = 0;
for (idx, item) in &map {
assert!(
*idx >= pos,
"items must be sorted by index and non-overlapping"
);
pos = *idx + item.length();
}
Self { map }
}
/// Returns `true` if the map is empty.
pub(crate) fn is_empty(&self) -> bool {
self.map.is_empty()
@@ -33,6 +47,11 @@ where
.map(|(idx, item)| *idx..*idx + item.length())
}
/// Returns the length of the map.
pub(crate) fn len(&self) -> usize {
self.map.iter().map(|(_, item)| item.length()).sum()
}
pub(crate) fn iter(&self) -> impl Iterator<Item = (Range<usize>, &T)> {
self.map
.iter()
@@ -58,7 +77,7 @@ where
pub(crate) fn index(&self, idx: &RangeSet<usize>) -> Option<Self> {
let mut map = Vec::new();
for idx in idx.iter() {
for idx in idx.iter_ranges() {
let pos = match self.map.binary_search_by(|(base, _)| base.cmp(&idx.start)) {
Ok(i) => i,
Err(0) => return None,

View File

@@ -6,6 +6,11 @@ use mpz_core::Block;
#[cfg(not(tlsn_insecure))]
use mpz_garble::protocol::semihonest::{Evaluator, Garbler};
use mpz_garble_core::Delta;
use mpz_memory_core::{
Vector,
binary::U8,
correlated::{Key, Mac},
};
#[cfg(not(tlsn_insecure))]
use mpz_ot::cot::{DerandCOTReceiver, DerandCOTSender};
use mpz_ot::{
@@ -19,6 +24,8 @@ use tlsn_core::config::tls_commit::mpc::{MpcTlsConfig, NetworkSetting};
use tlsn_deap::Deap;
use tokio::sync::Mutex;
use crate::transcript_internal::commit::encoding::{KeyStore, MacStore};
#[cfg(not(tlsn_insecure))]
pub(crate) type ProverMpc =
Garbler<DerandCOTSender<SharedRCOTSender<kos::Sender<co::Receiver>, Block>>>;
@@ -186,3 +193,41 @@ pub(crate) fn translate_keys<Mpc, Zk>(keys: &mut SessionKeys, vm: &Deap<Mpc, Zk>
.translate(keys.server_write_mac_key)
.expect("VM memory should be consistent");
}
impl<T> KeyStore for Verifier<T> {
fn delta(&self) -> &Delta {
self.delta()
}
fn get_keys(&self, data: Vector<U8>) -> Option<&[Key]> {
self.get_keys(data).ok()
}
}
impl<T> MacStore for Prover<T> {
fn get_macs(&self, data: Vector<U8>) -> Option<&[Mac]> {
self.get_macs(data).ok()
}
}
#[cfg(tlsn_insecure)]
mod insecure {
use super::*;
use mpz_ideal_vm::IdealVm;
impl KeyStore for IdealVm {
fn delta(&self) -> &Delta {
unimplemented!("encodings not supported in insecure mode")
}
fn get_keys(&self, _data: Vector<U8>) -> Option<&[Key]> {
unimplemented!("encodings not supported in insecure mode")
}
}
impl MacStore for IdealVm {
fn get_macs(&self, _data: Vector<U8>) -> Option<&[Mac]> {
unimplemented!("encodings not supported in insecure mode")
}
}
}

View File

@@ -1,31 +1,31 @@
//! Prover.
mod client;
mod error;
mod future;
mod prove;
pub mod state;
pub use error::ProverError;
pub use future::ProverFuture;
pub use tlsn_core::ProverOutput;
use crate::{
Role,
BUF_CAP, Role,
context::build_mt_context,
mpz::{ProverDeps, build_prover_deps, translate_keys},
msg::{ProveRequestMsg, Response, TlsCommitRequestMsg},
mux::attach_mux,
tag::verify_tags,
prover::client::{MpcTlsClient, TlsOutput},
};
use futures::{AsyncRead, AsyncWrite, TryFutureExt};
use mpc_tls::LeaderCtrl;
use mpz_vm_core::prelude::*;
use futures::{FutureExt, TryFutureExt};
use rustls_pki_types::CertificateDer;
use serio::{SinkExt, stream::IoStreamExt};
use std::sync::Arc;
use std::{
io::{Read, Write},
sync::Arc,
task::{Context, Poll},
};
use tls_client::{ClientConnection, ServerName as TlsServerName};
use tls_client_async::{TlsConnection, bind_client};
use tlsn_core::{
config::{
prove::ProveConfig,
@@ -36,10 +36,9 @@ use tlsn_core::{
connection::{HandshakeData, ServerName},
transcript::{TlsTranscript, Transcript},
};
use tracing::{Span, debug, info_span, instrument};
use webpki::anchor_from_trusted_cert;
use tracing::{Instrument, Span, debug, info, info_span, instrument};
/// A prover instance.
#[derive(Debug)]
pub struct Prover<T: state::ProverState = state::Initialized> {
@@ -71,15 +70,16 @@ impl Prover<state::Initialized> {
/// # Arguments
///
/// * `config` - The TLS commitment configuration.
/// * `socket` - The socket to the TLS verifier.
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
pub async fn commit<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
pub async fn commit(
self,
config: TlsCommitConfig,
socket: S,
) -> Result<Prover<state::CommitAccepted>, ProverError> {
let (mut mux_fut, mux_ctrl) = attach_mux(socket, Role::Prover);
let (duplex_a, duplex_b) = futures_plex::duplex(BUF_CAP);
let (mut mux_fut, mux_ctrl) = attach_mux(duplex_b, Role::Prover);
let mut mt = build_mt_context(mux_ctrl.clone());
let mut ctx = mux_fut.poll_with(mt.new_context()).await?;
// Sends protocol configuration to verifier for compatibility check.
@@ -118,47 +118,48 @@ impl Prover<state::Initialized> {
debug!("mpc-tls setup complete");
Ok(Prover {
let prover = Prover {
config: self.config,
span: self.span,
state: state::CommitAccepted {
mpc_duplex: duplex_a,
mux_ctrl,
mux_fut,
mpc_tls,
keys,
vm,
},
})
};
Ok(prover)
}
}
impl Prover<state::CommitAccepted> {
/// Connects to the server using the provided socket.
/// Connects the prover.
///
/// Returns a handle to the TLS connection, a future which returns the
/// prover once the connection is closed and the TLS transcript is
/// committed.
/// Returns a connected prover, which can be used to read and write from/to
/// the active TLS connection.
///
/// # Arguments
///
/// * `config` - The TLS client configuration.
/// * `socket` - The socket to the server.
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
pub async fn connect<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
pub async fn connect(
self,
config: TlsClientConfig,
socket: S,
) -> Result<(TlsConnection, ProverFuture), ProverError> {
) -> Result<Prover<state::Connected>, ProverError> {
let state::CommitAccepted {
mpc_duplex,
mux_ctrl,
mut mux_fut,
mux_fut,
mpc_tls,
keys,
vm,
..
} = self.state;
let (mpc_ctrl, mpc_fut) = mpc_tls.run();
let decrypt = mpc_tls.is_decrypting();
let ServerName::Dns(server_name) = config.server_name();
let server_name =
@@ -195,102 +196,194 @@ impl Prover<state::CommitAccepted> {
rustls_config.with_no_client_auth()
};
let client = ClientConnection::new(
Arc::new(rustls_config),
Box::new(mpc_ctrl.clone()),
server_name,
)
.map_err(ProverError::config)?;
let client = ClientConnection::new(Arc::new(rustls_config), Box::new(mpc_tls), server_name)
.map_err(ProverError::config)?;
let (conn, conn_fut) = bind_client(socket, client);
let span = self.span.clone();
let fut = Box::pin({
let span = self.span.clone();
let mpc_ctrl = mpc_ctrl.clone();
async move {
let conn_fut = async {
mux_fut
.poll_with(conn_fut.map_err(ProverError::from))
.await?;
let mpc_tls = MpcTlsClient::new(keys, vm, span, client, decrypt);
mpc_ctrl.stop().await?;
Ok::<_, ProverError>(())
};
info!("starting MPC-TLS");
let (_, (mut ctx, tls_transcript)) = futures::try_join!(
conn_fut,
mpc_fut.in_current_span().map_err(ProverError::from)
)?;
info!("finished MPC-TLS");
{
let mut vm = vm.try_lock().expect("VM should not be locked");
debug!("finalizing mpc");
// Finalize DEAP.
mux_fut
.poll_with(vm.finalize(&mut ctx))
.await
.map_err(ProverError::mpc)?;
debug!("mpc finalized");
}
// Pull out ZK VM.
let (_, mut vm) = Arc::into_inner(vm)
.expect("vm should have only 1 reference")
.into_inner()
.into_inner();
// Prove tag verification of received records.
// The prover drops the proof output.
let _ = verify_tags(
&mut vm,
(keys.server_write_key, keys.server_write_iv),
keys.server_write_mac_key,
*tls_transcript.version(),
tls_transcript.recv().to_vec(),
)
.map_err(ProverError::zk)?;
mux_fut
.poll_with(vm.execute_all(&mut ctx).map_err(ProverError::zk))
.await?;
let transcript = tls_transcript
.to_transcript()
.expect("transcript is complete");
Ok(Prover {
config: self.config,
span: self.span,
state: state::Committed {
mux_ctrl,
mux_fut,
ctx,
vm,
server_name: config.server_name().clone(),
keys,
tls_transcript,
transcript,
},
})
}
.instrument(span)
});
Ok((
conn,
ProverFuture {
fut,
ctrl: ProverControl { mpc_ctrl },
let prover = Prover {
config: self.config,
span: self.span,
state: state::Connected {
mpc_duplex,
mux_ctrl,
mux_fut,
server_name: config.server_name().clone(),
tls_client: Box::new(mpc_tls),
output: None,
},
))
};
Ok(prover)
}
/// Writes bytes for the verifier into a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn write_mpc(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.read(buf)
}
/// Reads bytes for the prover from a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn read_mpc(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.write(buf)
}
}
impl Prover<state::Connected> {
/// Returns `true` if the prover wants to read TLS data from the server.
pub fn wants_read_tls(&self) -> bool {
self.state.tls_client.wants_read_tls()
}
/// Returns `true` if the prover wants to write TLS data to the server.
pub fn wants_write_tls(&self) -> bool {
self.state.tls_client.wants_write_tls()
}
/// Reads TLS data from the server.
///
/// # Arguments
///
/// * `buf` - The buffer to read the TLS data from.
pub fn read_tls(&mut self, buf: &[u8]) -> Result<usize, ProverError> {
self.state.tls_client.read_tls(buf)
}
/// Writes TLS data for the server into the provided buffer.
///
/// # Arguments
///
/// * `buf` - The buffer to write the TLS data to.
pub fn write_tls(&mut self, buf: &mut [u8]) -> Result<usize, ProverError> {
self.state.tls_client.write_tls(buf)
}
/// Returns `true` if the prover wants to read plaintext data.
pub fn wants_read(&self) -> bool {
self.state.tls_client.wants_read()
}
/// Returns `true` if the prover wants to write plaintext data.
pub fn wants_write(&self) -> bool {
self.state.tls_client.wants_write()
}
/// Reads plaintext data from the server into the provided buffer.
///
/// # Arguments
///
/// * `buf` - The buffer where the plaintext data gets written to.
pub fn read(&mut self, buf: &mut [u8]) -> Result<usize, ProverError> {
self.state.tls_client.read(buf)
}
/// Writes plaintext data to be sent to the server.
///
/// # Arguments
///
/// * `buf` - The buffer to read the plaintext data from.
pub fn write(&mut self, buf: &[u8]) -> Result<usize, ProverError> {
self.state.tls_client.write(buf)
}
/// Writes bytes for the verifier into a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn write_mpc(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.read(buf)
}
/// Reads bytes for the prover from a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn read_mpc(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.write(buf)
}
/// Closes the connection from the client side.
pub fn client_close(&mut self) -> Result<(), ProverError> {
self.state.tls_client.client_close()
}
/// Closes the connection from the server side.
pub fn server_close(&mut self) -> Result<(), ProverError> {
self.state.tls_client.server_close()
}
/// Enables or disables the decryption of data from the server until the
/// server has closed the connection.
///
/// # Arguments
///
/// * `enable` - Whether to enable or disable decryption.
pub fn enable_decryption(&mut self, enable: bool) -> Result<(), ProverError> {
self.state.tls_client.enable_decryption(enable)
}
/// Returns `true` if decryption of TLS traffic from the server is active.
pub fn is_decrypting(&self) -> bool {
self.state.tls_client.is_decrypting()
}
/// Polls the prover to make progress.
///
/// # Arguments
///
/// * `cx` - The async context.
pub fn poll(&mut self, cx: &mut Context) -> Poll<Result<(), ProverError>> {
let _ = self.state.mux_fut.poll_unpin(cx)?;
match self.state.tls_client.poll(cx)? {
Poll::Ready(output) => {
let _ = self.state.mux_fut.poll_unpin(cx)?;
self.state.output = Some(output);
Poll::Ready(Ok(()))
}
Poll::Pending => Poll::Pending,
}
}
/// Returns a committed prover after the TLS session has completed.
pub fn finish(self) -> Result<Prover<state::Committed>, ProverError> {
let TlsOutput {
ctx,
vm,
keys,
tls_transcript,
transcript,
} = self.state.output.ok_or(ProverError::state(
"prover has not yet closed the connection",
))?;
let prover = Prover {
config: self.config,
span: self.span,
state: state::Committed {
mpc_duplex: self.state.mpc_duplex,
mux_ctrl: self.state.mux_ctrl,
mux_fut: self.state.mux_fut,
ctx,
vm,
server_name: self.state.server_name,
keys,
tls_transcript,
transcript,
},
};
Ok(prover)
}
}
@@ -305,6 +398,24 @@ impl Prover<state::Committed> {
&self.state.transcript
}
/// Writes bytes for the verifier into a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn write_mpc(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.read(buf)
}
/// Reads bytes for the prover from a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn read_mpc(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.write(buf)
}
/// Proves information to the verifier.
///
/// # Arguments
@@ -367,41 +478,19 @@ impl Prover<state::Committed> {
#[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn close(self) -> Result<(), ProverError> {
let state::Committed {
mux_ctrl, mux_fut, ..
mut mpc_duplex,
mux_ctrl,
mux_fut,
..
} = self.state;
// Wait for the verifier to correctly close the connection.
if !mux_fut.is_complete() {
mux_ctrl.close();
mux_fut.await?;
futures::AsyncWriteExt::close(&mut mpc_duplex).await?;
}
Ok(())
}
}
/// A controller for the prover.
#[derive(Clone)]
pub struct ProverControl {
mpc_ctrl: LeaderCtrl,
}
impl ProverControl {
/// Defers decryption of data from the server until the server has closed
/// the connection.
///
/// This is a performance optimization which will significantly reduce the
/// amount of upload bandwidth used by the prover.
///
/// # Notes
///
/// * The prover may need to close the connection to the server in order for
/// it to close the connection on its end. If neither the prover or server
/// close the connection this will cause a deadlock.
pub async fn defer_decryption(&self) -> Result<(), ProverError> {
self.mpc_ctrl
.defer_decryption()
.await
.map_err(ProverError::from)
}
}

View File

@@ -0,0 +1,63 @@
//! Provides a TLS client.
use crate::mpz::ProverZk;
use mpc_tls::SessionKeys;
use std::task::{Context, Poll};
use tlsn_core::transcript::{TlsTranscript, Transcript};
mod mpc;
pub(crate) use mpc::MpcTlsClient;
/// TLS client for MPC and proxy-based TLS implementations.
pub(crate) trait TlsClient {
type Error: std::error::Error + Send + Sync + Unpin + 'static;
/// Returns `true` if the client wants to read TLS data from the server.
fn wants_read_tls(&self) -> bool;
/// Returns `true` if the client wants to write TLS data to the server.
fn wants_write_tls(&self) -> bool;
/// Reads TLS data from the server.
fn read_tls(&mut self, buf: &[u8]) -> Result<usize, Self::Error>;
/// Writes TLS data for the server into the provided buffer.
fn write_tls(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error>;
/// Returns `true` if the client wants to read plaintext data.
fn wants_read(&self) -> bool;
/// Returns `true` if the client wants to write plaintext data.
fn wants_write(&self) -> bool;
/// Reads plaintext data from the server into the provided buffer.
fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error>;
/// Writes plaintext data to be sent to the server.
fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error>;
/// Client closes the connection.
fn client_close(&mut self) -> Result<(), Self::Error>;
/// Server closes the connection.
fn server_close(&mut self) -> Result<(), Self::Error>;
/// Enables or disables decryption of TLS traffic sent by the server.
fn enable_decryption(&mut self, enable: bool) -> Result<(), Self::Error>;
/// Returns `true` if decryption of TLS traffic from the server is active.
fn is_decrypting(&self) -> bool;
/// Polls the client to make progress.
fn poll(&mut self, cx: &mut Context) -> Poll<Result<TlsOutput, Self::Error>>;
}
/// Output of a TLS session.
pub(crate) struct TlsOutput {
pub(crate) ctx: mpz_common::Context,
pub(crate) vm: ProverZk,
pub(crate) keys: SessionKeys,
pub(crate) tls_transcript: TlsTranscript,
pub(crate) transcript: Transcript,
}

View File

@@ -0,0 +1,391 @@
//! Implementation of an MPC-TLS client.
use crate::{
mpz::{ProverMpc, ProverZk},
prover::{
ProverError,
client::{TlsClient, TlsOutput},
},
tag::verify_tags,
};
use futures::{Future, FutureExt};
use mpc_tls::{MpcTlsLeader, SessionKeys};
use mpz_common::Context;
use mpz_vm_core::Execute;
use std::{collections::VecDeque, pin::Pin, sync::Arc, task::Poll};
use tls_client::ClientConnection;
use tlsn_core::transcript::TlsTranscript;
use tlsn_deap::Deap;
use tokio::sync::Mutex;
use tracing::{Span, debug, instrument, trace, warn};
pub(crate) type MpcFuture =
Box<dyn Future<Output = Result<(Context, TlsTranscript), ProverError>> + Send>;
type FinalizeFuture =
Box<dyn Future<Output = Result<(InnerState, Context, TlsTranscript), ProverError>> + Send>;
pub(crate) struct MpcTlsClient {
state: State,
decrypt: bool,
cmds: VecDeque<Command>,
}
#[derive(Debug, Clone, Copy)]
pub(crate) enum Command {
ClientClose,
ServerClose,
Decrypt(bool),
}
enum State {
Start {
inner: Box<InnerState>,
},
Active {
inner: Box<InnerState>,
},
Busy {
fut: Pin<Box<dyn Future<Output = Result<Box<InnerState>, ProverError>> + Send>>,
},
CloseActive {
inner: Box<InnerState>,
},
CloseBusy {
fut: Pin<Box<dyn Future<Output = Result<Box<InnerState>, ProverError>> + Send>>,
},
Finalizing {
fut: Pin<FinalizeFuture>,
},
Finished,
Error,
}
impl MpcTlsClient {
pub(crate) fn new(
keys: SessionKeys,
vm: Arc<Mutex<Deap<ProverMpc, ProverZk>>>,
span: Span,
tls: ClientConnection,
) -> Self {
let inner = InnerState {
span,
tls,
vm,
keys,
mpc_stopped: false,
};
let decrypt = tls.backend().is_decrypting();
Self {
state: State::Start {
inner: Box::new(inner),
},
decrypt,
cmds: VecDeque::default(),
}
}
fn inner_client_mut(&mut self) -> Option<&mut ClientConnection> {
if let State::Active { inner, .. } | State::CloseActive { inner, .. } = &mut self.state {
Some(&mut inner.tls)
} else {
None
}
}
fn inner_client(&self) -> Option<&ClientConnection> {
if let State::Active { inner, .. } | State::CloseActive { inner, .. } = &self.state {
Some(&inner.tls)
} else {
None
}
}
}
impl TlsClient for MpcTlsClient {
type Error = ProverError;
fn wants_read_tls(&self) -> bool {
if let Some(client) = self.inner_client() {
client.wants_read()
} else {
false
}
}
fn wants_write_tls(&self) -> bool {
if let Some(client) = self.inner_client() {
client.wants_write()
} else {
false
}
}
fn read_tls(&mut self, mut buf: &[u8]) -> Result<usize, Self::Error> {
if let Some(client) = self.inner_client_mut()
&& client.wants_read()
{
client.read_tls(&mut buf).map_err(ProverError::from)
} else {
Ok(0)
}
}
fn write_tls(&mut self, mut buf: &mut [u8]) -> Result<usize, Self::Error> {
if let Some(client) = self.inner_client_mut()
&& client.wants_write()
{
client.write_tls(&mut buf).map_err(ProverError::from)
} else {
Ok(0)
}
}
fn wants_read(&self) -> bool {
if let Some(client) = self.inner_client() {
!client.plaintext_is_empty()
} else {
false
}
}
fn wants_write(&self) -> bool {
if let Some(client) = self.inner_client() {
!client.sendable_plaintext_is_full()
} else {
false
}
}
fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
if let Some(client) = self.inner_client_mut()
&& !client.plaintext_is_empty()
{
client.read_plaintext(buf).map_err(ProverError::from)
} else {
Ok(0)
}
}
fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
if let Some(client) = self.inner_client_mut()
&& !client.sendable_plaintext_is_full()
{
client.write_plaintext(buf).map_err(ProverError::from)
} else {
Ok(0)
}
}
fn client_close(&mut self) -> Result<(), Self::Error> {
self.cmds.push_back(Command::ClientClose);
Ok(())
}
fn server_close(&mut self) -> Result<(), Self::Error> {
self.cmds.push_back(Command::ServerClose);
Ok(())
}
fn enable_decryption(&mut self, enable: bool) -> Result<(), Self::Error> {
self.cmds.push_back(Command::Decrypt(enable));
Ok(())
}
fn is_decrypting(&self) -> bool {
self.decrypt
}
fn poll(&mut self, cx: &mut std::task::Context) -> Poll<Result<TlsOutput, Self::Error>> {
match std::mem::replace(&mut self.state, State::Error) {
State::Start { inner } => {
trace!("inner client is starting");
self.state = State::Busy {
fut: Box::pin(inner.start()),
};
self.poll(cx)
}
State::Active { mut inner } => {
trace!("inner client is active");
if !inner.tls.is_handshaking()
&& let Some(cmd) = self.cmds.pop_front()
{
match cmd {
Command::ClientClose => {
self.state = State::Busy {
fut: Box::pin(inner.client_close()),
};
}
Command::ServerClose => {
self.state = State::CloseBusy {
fut: Box::pin(inner.server_close()),
};
}
Command::Decrypt(enable) => {
inner.tls.backend_mut().enable_decryption(enable)?;
self.decrypt = enable;
self.state = State::Busy {
fut: Box::pin(inner.run()),
};
}
}
} else {
self.state = State::Busy {
fut: Box::pin(inner.run()),
};
}
self.poll(cx)
}
State::Busy { mut fut } => {
trace!("inner client is busy");
match fut.as_mut().poll(cx)? {
Poll::Ready(inner) => {
self.state = State::Active { inner };
}
Poll::Pending => self.state = State::Busy { fut },
}
Poll::Pending
}
State::CloseActive { mut inner } => {
trace!("inner client is close active");
if let Some((ctx, transcript)) = inner.tls.backend_mut().finish() {
self.state = State::Finalizing {
fut: Box::pin(inner.finalize(ctx, transcript)),
};
} else {
self.state = State::CloseBusy {
fut: Box::pin(inner.server_close()),
};
}
self.poll(cx)
}
State::CloseBusy { mut fut } => {
trace!("inner client is busy closing");
match fut.as_mut().poll(cx)? {
Poll::Ready(inner) => {
self.state = State::CloseActive { inner };
}
Poll::Pending => self.state = State::CloseBusy { fut },
}
Poll::Pending
}
State::Finalizing { mut fut } => match fut.poll_unpin(cx) {
Poll::Ready(output) => {
let (inner, ctx, tls_transcript) = output?;
let InnerState { vm, keys, .. } = inner;
let transcript = tls_transcript
.to_transcript()
.expect("transcript is complete");
let (_, vm) = Arc::into_inner(vm)
.expect("vm should have only 1 reference")
.into_inner()
.into_inner();
let output = TlsOutput {
ctx,
vm,
keys,
tls_transcript,
transcript,
};
self.state = State::Finished;
Poll::Ready(Ok(output))
}
Poll::Pending => {
self.state = State::Finalizing { fut };
Poll::Pending
}
},
State::Finished => Poll::Ready(Err(ProverError::state(
"mpc tls client polled again in finished state",
))),
State::Error => {
Poll::Ready(Err(ProverError::state("mpc tls client is in error state")))
}
}
}
}
struct InnerState {
span: Span,
tls: ClientConnection,
vm: Arc<Mutex<Deap<ProverMpc, ProverZk>>>,
keys: SessionKeys,
mpc_stopped: bool,
}
impl InnerState {
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
async fn start(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
self.tls.start().await?;
Ok(self)
}
#[instrument(parent = &self.span, level = "trace", skip_all, err)]
async fn run(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
self.tls.process_new_packets().await?;
Ok(self)
}
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
async fn client_close(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
debug!("sending close notify");
if let Err(e) = self.tls.send_close_notify().await {
warn!("failed to send close_notify to server: {}", e);
}
Ok(self)
}
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
async fn server_close(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
self.tls.process_new_packets().await?;
if !self.mpc_stopped && self.tls.plaintext_is_empty() && self.tls.is_empty().await? {
self.tls.server_closed().await?;
self.mpc_stopped = true;
debug!("closed connection serverside");
}
Ok(self)
}
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
async fn finalize(
self,
mut ctx: Context,
transcript: TlsTranscript,
) -> Result<(Self, Context, TlsTranscript), ProverError> {
{
let mut vm = self.vm.try_lock().expect("VM should not be locked");
// Finalize DEAP.
vm.finalize(&mut ctx).await.map_err(ProverError::mpc)?;
debug!("mpc finalized");
// Pull out ZK VM.
let mut zk = vm.zk();
// Prove tag verification of received records.
// The prover drops the proof output.
let _ = verify_tags(
&mut *zk,
(self.keys.server_write_key, self.keys.server_write_iv),
self.keys.server_write_mac_key,
*transcript.version(),
transcript.recv().to_vec(),
)
.map_err(ProverError::zk)?;
debug!("verified tags from server");
zk.execute_all(&mut ctx).await.map_err(ProverError::zk)?
}
debug!("MPC-TLS done");
Ok((self, ctx, transcript))
}
}

View File

@@ -2,6 +2,8 @@ use std::{error::Error, fmt};
use mpc_tls::MpcTlsError;
use crate::transcript_internal::commit::encoding::EncodingError;
/// Error for [`Prover`](crate::prover::Prover).
#[derive(Debug, thiserror::Error)]
pub struct ProverError {
@@ -47,6 +49,13 @@ impl ProverError {
{
Self::new(ErrorKind::Commit, source)
}
pub(crate) fn state<E>(source: E) -> Self
where
E: Into<Box<dyn Error + Send + Sync + 'static>>,
{
Self::new(ErrorKind::State, source)
}
}
#[derive(Debug)]
@@ -56,6 +65,7 @@ enum ErrorKind {
Zk,
Config,
Commit,
State,
}
impl fmt::Display for ProverError {
@@ -68,6 +78,7 @@ impl fmt::Display for ProverError {
ErrorKind::Zk => f.write_str("zk error")?,
ErrorKind::Config => f.write_str("config error")?,
ErrorKind::Commit => f.write_str("commit error")?,
ErrorKind::State => f.write_str("state error")?,
}
if let Some(source) = &self.source {
@@ -84,8 +95,8 @@ impl From<std::io::Error> for ProverError {
}
}
impl From<tls_client_async::ConnectionError> for ProverError {
fn from(e: tls_client_async::ConnectionError) -> Self {
impl From<tls_client::Error> for ProverError {
fn from(e: tls_client::Error) -> Self {
Self::new(ErrorKind::Io, e)
}
}
@@ -107,3 +118,15 @@ impl From<MpcTlsError> for ProverError {
Self::new(ErrorKind::Mpc, e)
}
}
impl From<EncodingError> for ProverError {
fn from(e: EncodingError) -> Self {
Self::new(ErrorKind::Commit, e)
}
}
impl From<ProverError> for std::io::Error {
fn from(value: ProverError) -> Self {
Self::other(value)
}
}

View File

@@ -1,32 +0,0 @@
//! This module collects futures which are used by the [Prover].
use super::{Prover, ProverControl, ProverError, state};
use futures::Future;
use std::pin::Pin;
/// Prover future which must be polled for the TLS connection to make progress.
pub struct ProverFuture {
#[allow(clippy::type_complexity)]
pub(crate) fut: Pin<
Box<dyn Future<Output = Result<Prover<state::Committed>, ProverError>> + Send + 'static>,
>,
pub(crate) ctrl: ProverControl,
}
impl ProverFuture {
/// Returns a controller for the prover for advanced functionality.
pub fn control(&self) -> ProverControl {
self.ctrl.clone()
}
}
impl Future for ProverFuture {
type Output = Result<Prover<state::Committed>, ProverError>;
fn poll(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
self.fut.as_mut().poll(cx)
}
}

View File

@@ -2,7 +2,7 @@ use mpc_tls::SessionKeys;
use mpz_common::Context;
use mpz_memory_core::binary::Binary;
use mpz_vm_core::Vm;
use rangeset::set::RangeSet;
use rangeset::{RangeSet, UnionMut};
use tlsn_core::{
ProverOutput,
config::prove::ProveConfig,
@@ -13,10 +13,17 @@ use tlsn_core::{
use crate::{
prover::ProverError,
transcript_internal::{TranscriptRefs, auth::prove_plaintext, commit::hash::prove_hash},
transcript_internal::{
TranscriptRefs,
auth::prove_plaintext,
commit::{
encoding::{self, MacStore},
hash::prove_hash,
},
},
};
pub(crate) async fn prove<T: Vm<Binary> + Send + Sync>(
pub(crate) async fn prove<T: Vm<Binary> + MacStore + Send + Sync>(
ctx: &mut Context,
vm: &mut T,
keys: &SessionKeys,
@@ -38,6 +45,13 @@ pub(crate) async fn prove<T: Vm<Binary> + Send + Sync>(
Direction::Sent => commit_sent.union_mut(idx),
Direction::Received => commit_recv.union_mut(idx),
});
commit_config
.iter_encoding()
.for_each(|(direction, idx)| match direction {
Direction::Sent => commit_sent.union_mut(idx),
Direction::Received => commit_recv.union_mut(idx),
});
}
let transcript_refs = TranscriptRefs {
@@ -88,6 +102,45 @@ pub(crate) async fn prove<T: Vm<Binary> + Send + Sync>(
vm.execute_all(ctx).await.map_err(ProverError::zk)?;
if let Some(commit_config) = config.transcript_commit()
&& commit_config.has_encoding()
{
let mut sent_ranges = RangeSet::default();
let mut recv_ranges = RangeSet::default();
for (dir, idx) in commit_config.iter_encoding() {
match dir {
Direction::Sent => sent_ranges.union_mut(idx),
Direction::Received => recv_ranges.union_mut(idx),
}
}
let sent_map = transcript_refs
.sent
.index(&sent_ranges)
.expect("indices are valid");
let recv_map = transcript_refs
.recv
.index(&recv_ranges)
.expect("indices are valid");
let (commitment, tree) = encoding::receive(
ctx,
vm,
*commit_config.encoding_hash_alg(),
&sent_map,
&recv_map,
commit_config.iter_encoding(),
)
.await?;
output
.transcript_commitments
.push(TranscriptCommitment::Encoding(commitment));
output
.transcript_secrets
.push(TranscriptSecret::Encoding(tree));
}
if let Some((hash_fut, hash_secrets)) = hash_commitments {
let hash_commitments = hash_fut.try_recv().map_err(ProverError::commit)?;
for (commitment, secret) in hash_commitments.into_iter().zip(hash_secrets) {

View File

@@ -2,6 +2,7 @@
use std::sync::Arc;
use futures_plex::DuplexStream;
use mpc_tls::{MpcTlsLeader, SessionKeys};
use mpz_common::Context;
use tlsn_core::{
@@ -14,6 +15,10 @@ use tokio::sync::Mutex;
use crate::{
mpz::{ProverMpc, ProverZk},
mux::{MuxControl, MuxFuture},
prover::{
ProverError,
client::{TlsClient, TlsOutput},
},
};
/// Entry state
@@ -24,6 +29,7 @@ opaque_debug::implement!(Initialized);
/// State after the verifier has accepted the proposed TLS commitment protocol
/// configuration and preprocessing has completed.
pub struct CommitAccepted {
pub(crate) mpc_duplex: DuplexStream,
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) mpc_tls: MpcTlsLeader,
@@ -33,8 +39,21 @@ pub struct CommitAccepted {
opaque_debug::implement!(CommitAccepted);
/// State during the MPC-TLS connection.
pub struct Connected {
pub(crate) mpc_duplex: DuplexStream,
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) server_name: ServerName,
pub(crate) tls_client: Box<dyn TlsClient<Error = ProverError> + Send>,
pub(crate) output: Option<TlsOutput>,
}
opaque_debug::implement!(Connected);
/// State after the TLS transcript has been committed.
pub struct Committed {
pub(crate) mpc_duplex: DuplexStream,
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) ctx: Context,
@@ -52,11 +71,13 @@ pub trait ProverState: sealed::Sealed {}
impl ProverState for Initialized {}
impl ProverState for CommitAccepted {}
impl ProverState for Connected {}
impl ProverState for Committed {}
mod sealed {
pub trait Sealed {}
impl Sealed for super::Initialized {}
impl Sealed for super::CommitAccepted {}
impl Sealed for super::Connected {}
impl Sealed for super::Committed {}
}

View File

@@ -12,7 +12,7 @@ use mpz_memory_core::{
binary::{Binary, U8},
};
use mpz_vm_core::{Call, CallableExt, Vm};
use rangeset::{iter::RangeIterator, ops::Set, set::RangeSet};
use rangeset::{Difference, RangeSet, Union};
use tlsn_core::transcript::Record;
use crate::transcript_internal::ReferenceMap;
@@ -32,7 +32,7 @@ pub(crate) fn prove_plaintext<'a>(
commit.clone()
} else {
// The plaintext is only partially revealed, so we need to authenticate in ZK.
commit.union(reveal).into_set()
commit.union(reveal)
};
let plaintext_refs = alloc_plaintext(vm, &alloc_ranges)?;
@@ -49,7 +49,7 @@ pub(crate) fn prove_plaintext<'a>(
vm.commit(*slice).map_err(PlaintextAuthError::vm)?;
}
} else {
let private = commit.difference(reveal).into_set();
let private = commit.difference(reveal);
for (_, slice) in plaintext_refs
.index(&private)
.expect("all ranges are allocated")
@@ -98,7 +98,7 @@ pub(crate) fn verify_plaintext<'a>(
commit.clone()
} else {
// The plaintext is only partially revealed, so we need to authenticate in ZK.
commit.union(reveal).into_set()
commit.union(reveal)
};
let plaintext_refs = alloc_plaintext(vm, &alloc_ranges)?;
@@ -123,7 +123,7 @@ pub(crate) fn verify_plaintext<'a>(
ciphertext,
})
} else {
let private = commit.difference(reveal).into_set();
let private = commit.difference(reveal);
for (_, slice) in plaintext_refs
.index(&private)
.expect("all ranges are allocated")
@@ -175,13 +175,15 @@ fn alloc_plaintext(
let plaintext = vm.alloc_vec::<U8>(len).map_err(PlaintextAuthError::vm)?;
let mut pos = 0;
Ok(ReferenceMap::from_iter(ranges.iter().map(move |range| {
let chunk = plaintext
.get(pos..pos + range.len())
.expect("length was checked");
pos += range.len();
(range.start, chunk)
})))
Ok(ReferenceMap::from_iter(ranges.iter_ranges().map(
move |range| {
let chunk = plaintext
.get(pos..pos + range.len())
.expect("length was checked");
pos += range.len();
(range.start, chunk)
},
)))
}
fn alloc_ciphertext<'a>(
@@ -210,13 +212,15 @@ fn alloc_ciphertext<'a>(
let ciphertext: Vector<U8> = vm.call(call).map_err(PlaintextAuthError::vm)?;
let mut pos = 0;
Ok(ReferenceMap::from_iter(ranges.iter().map(move |range| {
let chunk = ciphertext
.get(pos..pos + range.len())
.expect("length was checked");
pos += range.len();
(range.start, chunk)
})))
Ok(ReferenceMap::from_iter(ranges.iter_ranges().map(
move |range| {
let chunk = ciphertext
.get(pos..pos + range.len())
.expect("length was checked");
pos += range.len();
(range.start, chunk)
},
)))
}
fn alloc_keystream<'a>(
@@ -229,7 +233,7 @@ fn alloc_keystream<'a>(
let mut keystream = Vec::new();
let mut pos = 0;
let mut range_iter = ranges.iter();
let mut range_iter = ranges.iter_ranges();
let mut current_range = range_iter.next();
for record in records {
let mut explicit_nonce = None;
@@ -504,7 +508,7 @@ mod tests {
for record in records {
let mut record_keystream = vec![0u8; record.len];
aes_ctr_apply_keystream(&key, &iv, &record.explicit_nonce, &mut record_keystream);
for mut range in ranges.iter() {
for mut range in ranges.iter_ranges() {
range.start = range.start.max(pos);
range.end = range.end.min(pos + record.len);
if range.start < range.end {

View File

@@ -1,3 +1,4 @@
//! Plaintext commitment and proof of encryption.
pub(crate) mod encoding;
pub(crate) mod hash;

View File

@@ -0,0 +1,267 @@
//! Encoding commitment protocol.
use std::ops::Range;
use mpz_common::Context;
use mpz_memory_core::{
Vector,
binary::U8,
correlated::{Delta, Key, Mac},
};
use rand::Rng;
use rangeset::RangeSet;
use serde::{Deserialize, Serialize};
use serio::{SinkExt, stream::IoStreamExt};
use tlsn_core::{
hash::{Blake3, HashAlgId, HashAlgorithm, Keccak256, Sha256},
transcript::{
Direction,
encoding::{
Encoder, EncoderSecret, EncodingCommitment, EncodingProvider, EncodingProviderError,
EncodingTree, EncodingTreeError, new_encoder,
},
},
};
use crate::{
map::{Item, RangeMap},
transcript_internal::ReferenceMap,
};
/// Bytes of encoding, per byte.
const ENCODING_SIZE: usize = 128;
#[derive(Debug, Serialize, Deserialize)]
struct Encodings {
sent: Vec<u8>,
recv: Vec<u8>,
}
/// Transfers encodings for the provided plaintext ranges.
pub(crate) async fn transfer<K: KeyStore>(
ctx: &mut Context,
store: &K,
sent: &ReferenceMap,
recv: &ReferenceMap,
) -> Result<(EncoderSecret, EncodingCommitment), EncodingError> {
let secret = EncoderSecret::new(rand::rng().random(), store.delta().as_block().to_bytes());
let encoder = new_encoder(&secret);
// Collects the encodings for the provided plaintext ranges.
fn collect_encodings(
encoder: &impl Encoder,
store: &impl KeyStore,
direction: Direction,
map: &ReferenceMap,
) -> Vec<u8> {
let mut encodings = Vec::with_capacity(map.len() * ENCODING_SIZE);
for (range, chunk) in map.iter() {
let start = encodings.len();
encoder.encode_range(direction, range, &mut encodings);
let keys = store
.get_keys(*chunk)
.expect("keys are present for provided plaintext ranges");
encodings[start..]
.iter_mut()
.zip(keys.iter().flat_map(|key| key.as_block().as_bytes()))
.for_each(|(encoding, key)| {
*encoding ^= *key;
});
}
encodings
}
let encodings = Encodings {
sent: collect_encodings(&encoder, store, Direction::Sent, sent),
recv: collect_encodings(&encoder, store, Direction::Received, recv),
};
let frame_limit = ctx
.io()
.limit()
.saturating_add(encodings.sent.len() + encodings.recv.len());
ctx.io_mut().with_limit(frame_limit).send(encodings).await?;
let root = ctx.io_mut().expect_next().await?;
Ok((secret, EncodingCommitment { root }))
}
/// Receives and commits to the encodings for the provided plaintext ranges.
pub(crate) async fn receive<M: MacStore>(
ctx: &mut Context,
store: &M,
hash_alg: HashAlgId,
sent: &ReferenceMap,
recv: &ReferenceMap,
idxs: impl IntoIterator<Item = &(Direction, RangeSet<usize>)>,
) -> Result<(EncodingCommitment, EncodingTree), EncodingError> {
let hasher: &(dyn HashAlgorithm + Send + Sync) = match hash_alg {
HashAlgId::SHA256 => &Sha256::default(),
HashAlgId::KECCAK256 => &Keccak256::default(),
HashAlgId::BLAKE3 => &Blake3::default(),
alg => {
return Err(ErrorRepr::UnsupportedHashAlgorithm(alg).into());
}
};
let (sent_len, recv_len) = (sent.len(), recv.len());
let frame_limit = ctx
.io()
.limit()
.saturating_add(ENCODING_SIZE * (sent_len + recv_len));
let encodings: Encodings = ctx.io_mut().with_limit(frame_limit).expect_next().await?;
if encodings.sent.len() != sent_len * ENCODING_SIZE {
return Err(ErrorRepr::IncorrectMacCount {
direction: Direction::Sent,
expected: sent_len,
got: encodings.sent.len() / ENCODING_SIZE,
}
.into());
}
if encodings.recv.len() != recv_len * ENCODING_SIZE {
return Err(ErrorRepr::IncorrectMacCount {
direction: Direction::Received,
expected: recv_len,
got: encodings.recv.len() / ENCODING_SIZE,
}
.into());
}
// Collects a map of plaintext ranges to their encodings.
fn collect_map(
store: &impl MacStore,
mut encodings: Vec<u8>,
map: &ReferenceMap,
) -> RangeMap<EncodingSlice> {
let mut encoding_map = Vec::new();
let mut pos = 0;
for (range, chunk) in map.iter() {
let macs = store
.get_macs(*chunk)
.expect("MACs are present for provided plaintext ranges");
let encoding = &mut encodings[pos..pos + range.len() * ENCODING_SIZE];
encoding
.iter_mut()
.zip(macs.iter().flat_map(|mac| mac.as_bytes()))
.for_each(|(encoding, mac)| {
*encoding ^= *mac;
});
encoding_map.push((range.start, EncodingSlice::from(&(*encoding))));
pos += range.len() * ENCODING_SIZE;
}
RangeMap::new(encoding_map)
}
let provider = Provider {
sent: collect_map(store, encodings.sent, sent),
recv: collect_map(store, encodings.recv, recv),
};
let tree = EncodingTree::new(hasher, idxs, &provider)?;
let root = tree.root();
ctx.io_mut().send(root.clone()).await?;
let commitment = EncodingCommitment { root };
Ok((commitment, tree))
}
pub(crate) trait KeyStore {
fn delta(&self) -> &Delta;
fn get_keys(&self, data: Vector<U8>) -> Option<&[Key]>;
}
pub(crate) trait MacStore {
fn get_macs(&self, data: Vector<U8>) -> Option<&[Mac]>;
}
#[derive(Debug)]
struct Provider {
sent: RangeMap<EncodingSlice>,
recv: RangeMap<EncodingSlice>,
}
impl EncodingProvider for Provider {
fn provide_encoding(
&self,
direction: Direction,
range: Range<usize>,
dest: &mut Vec<u8>,
) -> Result<(), EncodingProviderError> {
let encodings = match direction {
Direction::Sent => &self.sent,
Direction::Received => &self.recv,
};
let encoding = encodings.get(range).ok_or(EncodingProviderError)?;
dest.extend_from_slice(encoding);
Ok(())
}
}
#[derive(Debug)]
struct EncodingSlice(Vec<u8>);
impl From<&[u8]> for EncodingSlice {
fn from(value: &[u8]) -> Self {
Self(value.to_vec())
}
}
impl Item for EncodingSlice {
type Slice<'a>
= &'a [u8]
where
Self: 'a;
fn length(&self) -> usize {
self.0.len() / ENCODING_SIZE
}
fn slice<'a>(&'a self, range: Range<usize>) -> Option<Self::Slice<'a>> {
self.0
.get(range.start * ENCODING_SIZE..range.end * ENCODING_SIZE)
}
}
/// Encoding protocol error.
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct EncodingError(#[from] ErrorRepr);
#[derive(Debug, thiserror::Error)]
#[error("encoding protocol error: {0}")]
enum ErrorRepr {
#[error("I/O error: {0}")]
Io(std::io::Error),
#[error("incorrect MAC count for {direction}: expected {expected}, got {got}")]
IncorrectMacCount {
direction: Direction,
expected: usize,
got: usize,
},
#[error("encoding tree error: {0}")]
EncodingTree(EncodingTreeError),
#[error("unsupported hash algorithm: {0}")]
UnsupportedHashAlgorithm(HashAlgId),
}
impl From<std::io::Error> for EncodingError {
fn from(value: std::io::Error) -> Self {
Self(ErrorRepr::Io(value))
}
}
impl From<EncodingTreeError> for EncodingError {
fn from(value: EncodingTreeError) -> Self {
Self(ErrorRepr::EncodingTree(value))
}
}

View File

@@ -9,7 +9,7 @@ use mpz_memory_core::{
binary::{Binary, U8},
};
use mpz_vm_core::{Vm, VmError, prelude::*};
use rangeset::set::RangeSet;
use rangeset::RangeSet;
use tlsn_core::{
hash::{Blinder, Hash, HashAlgId, TypedHash},
transcript::{
@@ -155,7 +155,7 @@ fn hash_commit_inner(
Direction::Received => &refs.recv,
};
for range in idx.iter() {
for range in idx.iter_ranges() {
hasher.update(&refs.get(range).expect("plaintext refs are valid"));
}
@@ -176,7 +176,7 @@ fn hash_commit_inner(
Direction::Received => &refs.recv,
};
for range in idx.iter() {
for range in idx.iter_ranges() {
hasher
.update(vm, &refs.get(range).expect("plaintext refs are valid"))
.map_err(HashCommitError::hasher)?;
@@ -201,7 +201,7 @@ fn hash_commit_inner(
Direction::Received => &refs.recv,
};
for range in idx.iter() {
for range in idx.iter_ranges() {
hasher
.update(vm, &refs.get(range).expect("plaintext refs are valid"))
.map_err(HashCommitError::hasher)?;

View File

@@ -10,16 +10,17 @@ pub use error::VerifierError;
pub use tlsn_core::{VerifierOutput, webpki::ServerCertVerifier};
use crate::{
Role,
BUF_CAP, Role,
context::build_mt_context,
mpz::{VerifierDeps, build_verifier_deps, translate_keys},
msg::{ProveRequestMsg, Response, TlsCommitRequestMsg},
mux::attach_mux,
tag::verify_tags,
};
use futures::{AsyncRead, AsyncWrite, TryFutureExt};
use futures::TryFutureExt;
use mpz_vm_core::prelude::*;
use serio::{SinkExt, stream::IoStreamExt};
use std::io::{Read, Write};
use tlsn_core::{
config::{
prove::ProveRequest,
@@ -68,11 +69,10 @@ impl Verifier<state::Initialized> {
///
/// * `socket` - The socket to the prover.
#[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn commit<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
self,
socket: S,
) -> Result<Verifier<state::CommitStart>, VerifierError> {
let (mut mux_fut, mux_ctrl) = attach_mux(socket, Role::Verifier);
pub async fn commit(self) -> Result<Verifier<state::CommitStart>, VerifierError> {
let (duplex_a, duplex_b) = futures_plex::duplex(BUF_CAP);
let (mut mux_fut, mux_ctrl) = attach_mux(duplex_b, Role::Verifier);
let mut mt = build_mt_context(mux_ctrl.clone());
let mut ctx = mux_fut.poll_with(mt.new_context()).await?;
@@ -102,6 +102,7 @@ impl Verifier<state::Initialized> {
config: self.config,
span: self.span,
state: state::CommitStart {
mpc_duplex: duplex_a,
mux_ctrl,
mux_fut,
ctx,
@@ -121,6 +122,7 @@ impl Verifier<state::CommitStart> {
#[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn accept(self) -> Result<Verifier<state::CommitAccepted>, VerifierError> {
let state::CommitStart {
mpc_duplex,
mux_ctrl,
mut mux_fut,
mut ctx,
@@ -151,6 +153,7 @@ impl Verifier<state::CommitStart> {
config: self.config,
span: self.span,
state: state::CommitAccepted {
mpc_duplex,
mux_ctrl,
mux_fut,
mpc_tls,
@@ -182,6 +185,24 @@ impl Verifier<state::CommitStart> {
Ok(())
}
/// Writes bytes for the prover into a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn write_mpc(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.read(buf)
}
/// Reads bytes for the verifier from a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn read_mpc(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.write(buf)
}
}
impl Verifier<state::CommitAccepted> {
@@ -189,6 +210,7 @@ impl Verifier<state::CommitAccepted> {
#[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn run(self) -> Result<Verifier<state::Committed>, VerifierError> {
let state::CommitAccepted {
mpc_duplex,
mux_ctrl,
mut mux_fut,
mpc_tls,
@@ -245,6 +267,7 @@ impl Verifier<state::CommitAccepted> {
config: self.config,
span: self.span,
state: state::Committed {
mpc_duplex,
mux_ctrl,
mux_fut,
ctx,
@@ -254,6 +277,24 @@ impl Verifier<state::CommitAccepted> {
},
})
}
/// Writes bytes for the prover into a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn write_mpc(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.read(buf)
}
/// Reads bytes for the verifier from a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn read_mpc(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.write(buf)
}
}
impl Verifier<state::Committed> {
@@ -266,6 +307,7 @@ impl Verifier<state::Committed> {
#[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn verify(self) -> Result<Verifier<state::Verify>, VerifierError> {
let state::Committed {
mpc_duplex,
mux_ctrl,
mut mux_fut,
mut ctx,
@@ -286,6 +328,7 @@ impl Verifier<state::Committed> {
config: self.config,
span: self.span,
state: state::Verify {
mpc_duplex,
mux_ctrl,
mux_fut,
ctx,
@@ -299,17 +342,39 @@ impl Verifier<state::Committed> {
})
}
/// Writes bytes for the prover into a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn write_mpc(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.read(buf)
}
/// Reads bytes for the verifier from a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn read_mpc(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.write(buf)
}
/// Closes the connection with the prover.
#[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn close(self) -> Result<(), VerifierError> {
let state::Committed {
mux_ctrl, mux_fut, ..
mut mpc_duplex,
mux_ctrl,
mux_fut,
..
} = self.state;
// Wait for the prover to correctly close the connection.
if !mux_fut.is_complete() {
mux_ctrl.close();
mux_fut.await?;
futures::AsyncWriteExt::close(&mut mpc_duplex).await?;
}
Ok(())
@@ -327,6 +392,7 @@ impl Verifier<state::Verify> {
self,
) -> Result<(VerifierOutput, Verifier<state::Committed>), VerifierError> {
let state::Verify {
mpc_duplex,
mux_ctrl,
mut mux_fut,
mut ctx,
@@ -362,6 +428,7 @@ impl Verifier<state::Verify> {
config: self.config,
span: self.span,
state: state::Committed {
mpc_duplex,
mux_ctrl,
mux_fut,
ctx,
@@ -379,6 +446,7 @@ impl Verifier<state::Verify> {
msg: Option<&str>,
) -> Result<Verifier<state::Committed>, VerifierError> {
let state::Verify {
mpc_duplex,
mux_ctrl,
mut mux_fut,
mut ctx,
@@ -396,6 +464,7 @@ impl Verifier<state::Verify> {
config: self.config,
span: self.span,
state: state::Committed {
mpc_duplex,
mux_ctrl,
mux_fut,
ctx,
@@ -405,4 +474,22 @@ impl Verifier<state::Verify> {
},
})
}
/// Writes bytes for the prover into a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn write_mpc(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.read(buf)
}
/// Reads bytes for the verifier from a buffer.
///
/// # Arguments
///
/// * `buf` - The buffer.
pub fn read_mpc(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.state.mpc_duplex.write(buf)
}
}

View File

@@ -2,6 +2,8 @@ use std::{error::Error, fmt};
use mpc_tls::MpcTlsError;
use crate::transcript_internal::commit::encoding::EncodingError;
/// Error for [`Verifier`](crate::verifier::Verifier).
#[derive(Debug, thiserror::Error)]
pub struct VerifierError {
@@ -55,6 +57,7 @@ enum ErrorKind {
Config,
Mpc,
Zk,
Commit,
Verify,
}
@@ -67,6 +70,7 @@ impl fmt::Display for VerifierError {
ErrorKind::Config => f.write_str("config error")?,
ErrorKind::Mpc => f.write_str("mpc error")?,
ErrorKind::Zk => f.write_str("zk error")?,
ErrorKind::Commit => f.write_str("commit error")?,
ErrorKind::Verify => f.write_str("verification error")?,
}
@@ -101,3 +105,9 @@ impl From<MpcTlsError> for VerifierError {
Self::new(ErrorKind::Mpc, e)
}
}
impl From<EncodingError> for VerifierError {
fn from(e: EncodingError) -> Self {
Self::new(ErrorKind::Commit, e)
}
}

View File

@@ -3,6 +3,7 @@
use std::sync::Arc;
use crate::mux::{MuxControl, MuxFuture};
use futures_plex::DuplexStream;
use mpc_tls::{MpcTlsFollower, SessionKeys};
use mpz_common::Context;
use tlsn_core::{
@@ -25,6 +26,7 @@ opaque_debug::implement!(Initialized);
/// State after receiving protocol configuration from the prover.
pub struct CommitStart {
pub(crate) mpc_duplex: DuplexStream,
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) ctx: Context,
@@ -36,6 +38,7 @@ opaque_debug::implement!(CommitStart);
/// State after accepting the proposed TLS commitment protocol configuration and
/// performing preprocessing.
pub struct CommitAccepted {
pub(crate) mpc_duplex: DuplexStream,
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) mpc_tls: MpcTlsFollower,
@@ -47,6 +50,7 @@ opaque_debug::implement!(CommitAccepted);
/// State after the TLS transcript has been committed.
pub struct Committed {
pub(crate) mpc_duplex: DuplexStream,
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) ctx: Context,
@@ -59,6 +63,7 @@ opaque_debug::implement!(Committed);
/// State after receiving a proving request.
pub struct Verify {
pub(crate) mpc_duplex: DuplexStream,
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) ctx: Context,

View File

@@ -2,7 +2,7 @@ use mpc_tls::SessionKeys;
use mpz_common::Context;
use mpz_memory_core::binary::Binary;
use mpz_vm_core::Vm;
use rangeset::set::RangeSet;
use rangeset::{RangeSet, UnionMut};
use tlsn_core::{
VerifierOutput,
config::prove::ProveRequest,
@@ -14,12 +14,19 @@ use tlsn_core::{
};
use crate::{
transcript_internal::{TranscriptRefs, auth::verify_plaintext, commit::hash::verify_hash},
transcript_internal::{
TranscriptRefs,
auth::verify_plaintext,
commit::{
encoding::{self, KeyStore},
hash::verify_hash,
},
},
verifier::VerifierError,
};
#[allow(clippy::too_many_arguments)]
pub(crate) async fn verify<T: Vm<Binary> + Send + Sync>(
pub(crate) async fn verify<T: Vm<Binary> + KeyStore + Send + Sync>(
ctx: &mut Context,
vm: &mut T,
keys: &SessionKeys,
@@ -87,6 +94,11 @@ pub(crate) async fn verify<T: Vm<Binary> + Send + Sync>(
Direction::Sent => commit_sent.union_mut(idx),
Direction::Received => commit_recv.union_mut(idx),
});
if let Some((sent, recv)) = commit_config.encoding() {
commit_sent.union_mut(sent);
commit_recv.union_mut(recv);
}
}
let (sent_refs, sent_proof) = verify_plaintext(
@@ -139,6 +151,24 @@ pub(crate) async fn verify<T: Vm<Binary> + Send + Sync>(
sent_proof.verify().map_err(VerifierError::verify)?;
recv_proof.verify().map_err(VerifierError::verify)?;
let mut encoder_secret = None;
if let Some(commit_config) = request.transcript_commit()
&& let Some((sent, recv)) = commit_config.encoding()
{
let sent_map = transcript_refs
.sent
.index(sent)
.expect("ranges were authenticated");
let recv_map = transcript_refs
.recv
.index(recv)
.expect("ranges were authenticated");
let (secret, commitment) = encoding::transfer(ctx, vm, &sent_map, &recv_map).await?;
encoder_secret = Some(secret);
transcript_commitments.push(TranscriptCommitment::Encoding(commitment));
}
if let Some(hash_commitments) = hash_commitments {
for commitment in hash_commitments.try_recv().map_err(VerifierError::verify)? {
transcript_commitments.push(TranscriptCommitment::Hash(commitment));
@@ -148,6 +178,7 @@ pub(crate) async fn verify<T: Vm<Binary> + Send + Sync>(
Ok(VerifierOutput {
server_name,
transcript: request.reveal().is_some().then_some(transcript),
encoder_secret,
transcript_commitments,
})
}

View File

@@ -1,4 +1,5 @@
use futures::{AsyncReadExt, AsyncWriteExt};
use rangeset::RangeSet;
use tlsn::{
config::{
prove::ProveConfig,
@@ -8,9 +9,12 @@ use tlsn::{
verifier::VerifierConfig,
},
connection::ServerName,
hash::HashAlgId,
hash::{HashAlgId, HashProvider},
prover::Prover,
transcript::{Direction, Transcript, TranscriptCommitConfig, TranscriptCommitmentKind},
transcript::{
Direction, Transcript, TranscriptCommitConfig, TranscriptCommitment,
TranscriptCommitmentKind, TranscriptSecret,
},
verifier::{Verifier, VerifierOutput},
webpki::{CertificateDer, RootCertStore},
};
@@ -38,7 +42,7 @@ async fn test() {
let (socket_0, socket_1) = tokio::io::duplex(2 << 23);
let ((_full_transcript, _prover_output), verifier_output) =
let ((full_transcript, prover_output), verifier_output) =
tokio::join!(prover(socket_0), verifier(socket_1));
let partial_transcript = verifier_output.transcript.unwrap();
@@ -47,13 +51,65 @@ async fn test() {
assert_eq!(server_name.as_str(), SERVER_DOMAIN);
assert!(!partial_transcript.is_complete());
assert_eq!(
partial_transcript.sent_authed().iter().next().unwrap(),
partial_transcript
.sent_authed()
.iter_ranges()
.next()
.unwrap(),
0..10
);
assert_eq!(
partial_transcript.received_authed().iter().next().unwrap(),
partial_transcript
.received_authed()
.iter_ranges()
.next()
.unwrap(),
0..10
);
let encoding_tree = prover_output
.transcript_secrets
.iter()
.find_map(|secret| {
if let TranscriptSecret::Encoding(tree) = secret {
Some(tree)
} else {
None
}
})
.unwrap();
let encoding_commitment = prover_output
.transcript_commitments
.iter()
.find_map(|commitment| {
if let TranscriptCommitment::Encoding(commitment) = commitment {
Some(commitment)
} else {
None
}
})
.unwrap();
let prove_sent = RangeSet::from(1..full_transcript.sent().len() - 1);
let prove_recv = RangeSet::from(1..full_transcript.received().len() - 1);
let idxs = [
(Direction::Sent, prove_sent.clone()),
(Direction::Received, prove_recv.clone()),
];
let proof = encoding_tree.proof(idxs.iter()).unwrap();
let (auth_sent, auth_recv) = proof
.verify_with_provider(
&HashProvider::default(),
&verifier_output.encoder_secret.unwrap(),
encoding_commitment,
full_transcript.sent(),
full_transcript.received(),
)
.unwrap();
assert_eq!(auth_sent, prove_sent);
assert_eq!(auth_recv, prove_recv);
}
#[instrument(skip(verifier_socket))]
@@ -115,21 +171,25 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
let mut builder = TranscriptCommitConfig::builder(prover.transcript());
let kind = TranscriptCommitmentKind::Hash {
alg: HashAlgId::SHA256,
};
builder
.commit_with_kind(&(0..sent_tx_len), Direction::Sent, kind)
.unwrap();
builder
.commit_with_kind(&(0..recv_tx_len), Direction::Received, kind)
.unwrap();
builder
.commit_with_kind(&(1..sent_tx_len - 1), Direction::Sent, kind)
.unwrap();
builder
.commit_with_kind(&(1..recv_tx_len - 1), Direction::Received, kind)
.unwrap();
for kind in [
TranscriptCommitmentKind::Encoding,
TranscriptCommitmentKind::Hash {
alg: HashAlgId::SHA256,
},
] {
builder
.commit_with_kind(&(0..sent_tx_len), Direction::Sent, kind)
.unwrap();
builder
.commit_with_kind(&(0..recv_tx_len), Direction::Received, kind)
.unwrap();
builder
.commit_with_kind(&(1..sent_tx_len - 1), Direction::Sent, kind)
.unwrap();
builder
.commit_with_kind(&(1..recv_tx_len - 1), Direction::Received, kind)
.unwrap();
}
let transcript_commit = builder.build().unwrap();

View File

@@ -23,7 +23,6 @@ no-bundler = ["web-spawn/no-bundler"]
tlsn-core = { workspace = true }
tlsn = { workspace = true, features = ["web", "mozilla-certs"] }
tlsn-server-fixture-certs = { workspace = true }
tlsn-tls-client-async = { workspace = true }
tlsn-tls-core = { workspace = true }
bincode = { workspace = true }

View File

@@ -151,9 +151,9 @@ impl From<tlsn::transcript::PartialTranscript> for PartialTranscript {
fn from(value: tlsn::transcript::PartialTranscript) -> Self {
Self {
sent: value.sent_unsafe().to_vec(),
sent_authed: value.sent_authed().iter().collect(),
sent_authed: value.sent_authed().iter_ranges().collect(),
recv: value.received_unsafe().to_vec(),
recv_authed: value.received_authed().iter().collect(),
recv_authed: value.received_authed().iter_ranges().collect(),
}
}
}